not needed
[prop.git] / prop-src / funmap.pcc.old
blob73d4ca776ddc6594ac842e386f89b77b5077066a
1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 //  This file implements the FunctorMap data structure 
4 //
5 ///////////////////////////////////////////////////////////////////////////////
6 #include <iostream.h>
7 #include <strstream.h>
8 #include <AD/automata/treegram.ph>
9 #include <AD/automata/treegen.h>
10 #include <AD/rewrite/burs_gen.h>
11 #include <AD/strings/quark.h>
12 #include "funmap.ph"
13 #include "ir.ph"
14 #include "ast.ph"
15 #include "matchcom.ph"
16 #include "type.h"
17 #include "hashtab.h"
18 #include "datagen.h"
19 #include "config.h"
20 #include "rwgen.h"
21 #include "options.h"
23 ///////////////////////////////////////////////////////////////////////////////
25 //  Import some type definitions from the tree grammar and hash table
26 //  classes.
28 ///////////////////////////////////////////////////////////////////////////////
29 typedef TreeGrammar::TreeProduction TreeProduction;
30 typedef TreeGrammar::Cost           TreeCost;
31 typedef HashTable::Key              Key;
32 typedef HashTable::Value            Value;
34 ///////////////////////////////////////////////////////////////////////////////
36 //  Instantiate the vector id type
38 ///////////////////////////////////////////////////////////////////////////////
39 instantiate datatype VectorId;
41 ///////////////////////////////////////////////////////////////////////////////
43 //  Hashing and equality on vector id's
45 ///////////////////////////////////////////////////////////////////////////////
46 unsigned int vector_id_hash(HashTable::Key k)
47 {  VectorId id = (VectorId)k;
48    return (unsigned int)id->cons + ty_hash(id->ty) + id->arity;
50 Bool vector_id_equal(HashTable::Key a, HashTable::Key b)
51 {  VectorId x = (VectorId)a;
52    VectorId y = (VectorId)b;
53    return x->cons == y->cons && x->arity == y->arity && ty_equal(x->ty,y->ty); 
56 ///////////////////////////////////////////////////////////////////////////////
58 //  Method to decorate cost expression and attribute bindings for
59 //  a pattern.
61 ///////////////////////////////////////////////////////////////////////////////
62 int decor_rewrite (Pat pat, int rule, int kid, PatternVarEnv& E)
63 {  match while (pat)
64    {  NOpat || LITERALpat _ || CONSpat _: { return kid; }
65    |  ASpat (_,p,_,_):           { pat = p; } 
66    |  MARKEDpat(_,p):            { pat = p; }
67    |  TYPEDpat(p,_):             { pat = p; } 
68    |  GUARDpat(p,_):             { pat = p; }
69    |  APPpat (_,p):              { pat = p; }
70    |  CONTEXTpat(_,p):           { pat = p; }
71    |  TUPLEpat ps:               
72       {  return decor_rewrite(ps, rule, kid, E); }
73    |  LISTpat{nil,cons,head=ps,tail=rest}: 
74       {  kid = decor_rewrite(ps, rule, kid, E); pat = rest; }
75    |  VECTORpat { len, array, elements, head_flex, tail_flex ... }:
76       {  kid = decor_rewrite(elements, rule, 
77                   decor_rewrite(array, rule, kid, E), E); 
78          if (head_flex || tail_flex)
79             error("%Lflexible vector pattern currently not supported in rewriting: %p\n", pat); 
80          pat = len;
81       }
82    |  IDpat (id,_,_):
83       {  Id attrib_name = #"#" + id;
84          Id cost_name   = #"$" + id;
85          Ty ty;
86          E.add (attrib_name, SYNexp(kid, rule, mkvar()), ty, ISpositive);
87          E.add (cost_name, COSTexp(kid),  ty, ISpositive);
88          return kid+1;
89       }
90    |  _:  { return kid; }
91    }
94 ///////////////////////////////////////////////////////////////////////////////
96 //  Decorate a pattern list.
98 ///////////////////////////////////////////////////////////////////////////////
99 int decor_rewrite (Pats pats, int rule, int kid, PatternVarEnv& E)
100 {  for_each (Pat, p, pats) kid = decor_rewrite(p, rule, kid, E);
101    return kid;
104 ///////////////////////////////////////////////////////////////////////////////
106 //  Decorate rewriting patterns.
108 ///////////////////////////////////////////////////////////////////////////////
109 int decor_rewrite (Pat pat, int rule, PatternVarEnv& E)
110 {  Ty ty;
111    E.add (#"##", THISSYNexp(rule,mkvar()), ty, ISpositive);
112    E.add (#"$$", THISCOSTexp, ty, ISpositive);
113    return decor_rewrite (pat, rule, 0, E);
116 ///////////////////////////////////////////////////////////////////////////////
118 //  Mapping from rewrite class to protocols.
120 ///////////////////////////////////////////////////////////////////////////////
121 HashTable rewrite_env(string_hash, string_equal);
122 HashTable rewrite_qual(string_hash, string_equal);
124 ///////////////////////////////////////////////////////////////////////////////
126 //  Enter a rewrite class
128 ///////////////////////////////////////////////////////////////////////////////
129 void add_rewrite_class(Id id, Protocols protocols, TyQual qual)
130 {  if (rewrite_env.contains(id)) {
131       error ("%Lrewrite class %s has already been defined\n", id);
132    } else {
133       rewrite_env.insert(id, protocols);
134       rewrite_qual.insert(id, (HashTable::Value)qual);
135       debug_msg ("[Rewriting class %s declared]\n", id);
136    }
139 ///////////////////////////////////////////////////////////////////////////////
141 //  Lookup a rewrite class 
143 ///////////////////////////////////////////////////////////////////////////////
144 Protocols lookup_rewrite_class(Id id)
145 {  HashTable::Entry * e = rewrite_env.lookup(id);
146    if (e == 0) {
147       error ("%Lrewrite class %s is undefined\n", id);
148       return #[];
149    } else {
150       return (Protocols)rewrite_env.value(e);
151    } 
154 ///////////////////////////////////////////////////////////////////////////////
156 //  Constructor of the functor mapping table.
158 ///////////////////////////////////////////////////////////////////////////////
159 FunctorMap::FunctorMap(int n, Id name) 
160                         : class_name(name), N(n),
161                           literal_map(literal_hash,literal_equal,129), 
162                           var_map    (string_hash,string_equal),
163                           type_map   (ty_hash,ty_equal),
164                           vector_map (vector_id_hash,vector_id_equal),
165                           rule_map   (ty_hash,ty_equal),
166                           protocols  (ty_hash,ty_equal),
167                           nonterm_map(string_hash,string_equal),
168                           functors   (0),
169                           variables  (0),
170                           tree_gen   (0),
171                           use_compression(true),
172                           has_guard(false),
173                           has_cost(false),
174                           has_cost_exp(false),
175                           has_syn_attrib(false),
176                           use_stack(false),
177                           iso_tree(false),
178                           gen_traversal(false),
179                           max_arity(1)
180                           {}
182 ///////////////////////////////////////////////////////////////////////////////
184 //  Check whether we know of the type
186 ///////////////////////////////////////////////////////////////////////////////
187 Bool FunctorMap::is_known_type(Ty ty)
188 {  return type_map.contains(ty)      ||
189           ty_equal(ty, integer_ty)   ||
190           ty_equal(ty, bool_ty)      ||
191           ty_equal(ty, real_ty)      ||
192           ty_equal(ty, string_ty)    ||
193           ty_equal(ty, character_ty)
194    ;
197 ///////////////////////////////////////////////////////////////////////////////
199 //  Check whether we the type is rewritable.
201 ///////////////////////////////////////////////////////////////////////////////
202 Bool FunctorMap::is_rewritable_type(Ty ty) { return type_map.contains(ty); } 
204 ///////////////////////////////////////////////////////////////////////////////
206 //  Method to assign variable encoding to a non-terminal
208 ///////////////////////////////////////////////////////////////////////////////
209 void FunctorMap::encode (Id id)
210 {  if (! var_map.contains(id))
211    {  ++variables;
212       var_map.insert(id,(HashTable::Value)(variables));
213    }
216 ///////////////////////////////////////////////////////////////////////////////
218 //  Method to assign functor encoding to a type
220 ///////////////////////////////////////////////////////////////////////////////
221 void FunctorMap::encode (Ty ty)
222 {  match (deref_all(ty))
223    {  ty as TYCONty(DATATYPEtycon { unit, arg ... }, _):
224       {  if (! type_map.contains(ty)) 
225          {  type_map.insert(ty, (HashTable::Value)functors);
226             functors += unit + arg;
227          }
228       }
229    |  TYCONty(_,tys): {  for_each(Ty, ty, tys) encode(ty); }
230    |  _:  // skip
231    }
234 ///////////////////////////////////////////////////////////////////////////////
236 //  Method to assign functor encoding to a pattern.
237 //  Assign a functor value to each distinct literal and pattern constructor.
239 ///////////////////////////////////////////////////////////////////////////////
240 void FunctorMap::encode(Pat pat)
241 {  match while (pat)
242    {  NOpat || WILDpat _ || IDpat _: { return; }
243    |  ASpat(_,p,_,_):  { pat = p; }
244    |  TYPEDpat(p,_):   { pat = p; }
245    |  MARKEDpat(_,p):  { pat = p; }
246    |  TUPLEpat ps:     
247       {  int i = 0; 
248          for_each (Pat, p, ps) { i++; encode(p); }
249          if (max_arity < i) max_arity = i;
250          return; 
251       }
252    |  RECORDpat(lab_pats,_): 
253       {  for_each (LabPat, p, lab_pats) { encode(p.pat); }
254          int arity = arity_of(pat->ty);
255          if (max_arity < arity) max_arity = arity;
256          return;
257       }
258    |  LITERALpat l:                
259       {  if (! literal_map.contains(l)) 
260          {  literal_map.insert(l,(HashTable::Value)functors); 
261             functors++; 
262          }
263          return;
264       }
265    |  CONSpat(ONEcons 
266               { alg_ty = alg_ty as 
267                    TYCONty(DATATYPEtycon { unit, arg, terms ... },_) 
268                 ... }):
269       {  if (pat->ty != NOty && ! type_map.contains(pat->ty)) 
270          {  type_map.insert(pat->ty, (HashTable::Value)functors);
271             functors += unit + arg;
272          }
273          return;
274       }  
275    |  APPpat(a,b):  { encode(pat->ty); pat = b; }
276    |  LISTpat{cons,nil,head=ps,tail=p}:
277       {  Pat new_pat = CONSpat(nil);
278          new_pat->ty = pat->ty;
279          encode(new_pat);
280          for_each (Pat, i, ps) encode(i);
281          if (max_arity < 2) max_arity = 2;
282          pat = p;
283       }
284    |  VECTORpat { cons, elements ... }:
285       {  Pat new_pat = CONSpat(cons);
286          new_pat->ty = pat->ty;
287          encode(new_pat);
288          for_each (Pat, p, elements) encode(p); 
289          int l = length(elements);
290          if (max_arity < l) max_arity = l;
291          if (pat->ty != NOty)
292          {  VectorId vec_id = vector_id(cons,pat->ty,l);
293             if ( ! vector_map.contains(vec_id))
294             {  vector_map.insert(vec_id, (HashTable::Value)functors);
295                functors++;
296             }
297          }
298          return;
299       }
300    |  _: { error ("%LSorry: pattern not supported in rewriting: %p\n", pat); return; }
301    }
304 ///////////////////////////////////////////////////////////////////////////////
306 //  Method to translate a pattern into a term.
308 ///////////////////////////////////////////////////////////////////////////////
309 TreeTerm FunctorMap::trans(Pat pat)
310 {  match while (pat)
311    {  NOpat || WILDpat _: { return wild_term; }
312    |  ASpat(_,p,_,_):     { pat = p; }
313    |  TYPEDpat(p,_):      { pat = p; }
314    |  MARKEDpat(_,p):     { pat = p; }
315    |  LITERALpat l:     
316       {  return new_term(mem_pool,(Functor)literal_map[l]); } 
317    |  IDpat (id,_,_):                
318       {  return var_map.contains(id) ? 
319             var_term((Variable)var_map[id]) : wild_term; 
320       }
321    |  TUPLEpat pats: 
322       {  int arity = length (pats);
323          TreeTerm * subterms = 
324             (TreeTerm *)mem_pool.c_alloc(sizeof(TreeTerm) * arity);
325          int i = 0; 
326          for_each (Pat, p, pats)
327          {  subterms[i++] = trans(p); }
328          return new_term(mem_pool,0,arity,subterms);
329       }
330    |  RECORDpat (lab_pats,_):
331       {  match (deref(pat->ty))
332          {  RECORDty (labels,_,tys):
333             {  Bool relevant[256]; int i; int arity;
334                arity = 0;
335                for_each(Ty, t, tys) 
336                {  if (relevant[i++] = is_known_type(t)) arity++; }
337                TreeTerm * subterms = 
338                   (TreeTerm *)mem_pool.c_alloc(sizeof(TreeTerm) * arity);
339                for (i = 0; i < arity; i++)
340                   subterms[i] = wild_term;
341                for_each (LabPat, p, lab_pats)
342                {  Ids labs; Tys ts;
343                   for (i = 0, labs = labels, ts = tys; 
344                        labs && ts; labs = labs->_2, ts = ts->_2)
345                   {  if (p.label == labs->_1)
346                      {  subterms[i] = trans(p.pat); break; }
347                      if (is_known_type(ts->_1)) i++;
348                   }
349                }
350                return new_term(mem_pool,0,arity,subterms);
351             }
352          |  _: { bug("%Lillegal record pattern %p\n", pat); }
353          }
354       }
355    |  APPpat(CONSpat(ONEcons 
356               { ty = arg_ty, tag, 
357                 alg_ty = TYCONty(DATATYPEtycon { unit ... },_) ... 
358               }), p):
359       {  TreeTerm a = trans(p);
360          match (arity_of(arg_ty)) and (a)
361          {  1, _: 
362             {  return new_term(mem_pool, 
363                   (Functor)type_map[pat->ty]+unit+tag,1,&a); 
364             }
365          |  _, tree_term(f,_,_):
366             {  f = (Functor)type_map[pat->ty]+unit+tag; return a; }
367          |  n, _:
368             {  return new_term(mem_pool,
369                   (Functor)type_map[pat->ty]+unit+tag, n);
370             }
371          }
372       }
373    |  CONSpat(ONEcons { tag ... }):
374       {  return new_term(mem_pool, (Functor)type_map[pat->ty]+tag); }
375    |  LISTpat{ nil, head = #[], tail = NOpat ... }: 
376       {  Pat p = CONSpat(nil); p->ty = pat->ty; pat = p; }
377    |  LISTpat{ head = #[], tail ... }:   {  pat = tail; }
378    |  LISTpat{ cons, nil, head = #[h ... t], tail }: 
379       {  Pat new_tail = LISTpat'{cons=cons,nil=nil,head=t,tail=tail};
380          Pat new_p    = APPpat(CONSpat(cons),TUPLEpat(#[h, new_tail]));
381          new_p->ty    = new_tail->ty = pat->ty;
382          pat = new_p;
383       }
384    |  VECTORpat { cons, elements ... }:
385       {  TreeTerm a     = trans(TUPLEpat(elements));
386          int      arity = length(elements);
387          match (a)
388          {  tree_term(f,_,_):
389             {  f = (Functor)vector_map[vector_id(cons,pat->ty,arity)]; 
390                return a; 
391             }
392          |  _: 
393             { bug ("%Lillegal pattern: %p\n", pat); return wild_term; }
394          }
395       }
396    |  _: { error ("%LSorry: pattern not supported: %p\n", pat); return wild_term; }
397    }
400 ///////////////////////////////////////////////////////////////////////////////
402 //  Method to partition the set of rules according to the types of the
403 //  patterns.  Also encode the patterns in the process.
405 ///////////////////////////////////////////////////////////////////////////////
406 void FunctorMap::partition_rules (MatchRules rules)
407 {  // First, we assign a new type variable for each lhs non-terminal.
408    {  for_each (MatchRule, r, rules)
409       {  match (r)
410          {  MATCHrule(lhs,_,_,_,_):
411             {  if (lhs)
412                {  HashTable::Entry * lhs_entry = nonterm_map.lookup(lhs);
413                   if (! lhs_entry) nonterm_map.insert(lhs,mkvar());
414                   encode(lhs);  // compute encoding for the variable
415                }
416             }
417          }
418       }
419    }
421    // Type check all the rules next.
422    // We have to also compute the type map for each lhs non-terminal.
423    // Of course, a non-terminal but have only one single type.
424    // This is done by unifying all occurances of a non-terminal.
426    patvar_typemap = &nonterm_map; // set the pattern variable type map
428    for_each (MatchRule, r, rules)
429    {  match (r)
430       {  MATCHrule(lhs,pat,_,_,_):
431          {  r->set_loc();
432             Ty ty = r->ty = type_of(pat); 
434             // Check the type of the non-terminal (if any).
435             if (lhs)
436             {  Ty lhs_ty = Ty(nonterm_map.lookup(lhs)->v);
437                if (! unify(lhs_ty, ty))
438                {  error("%!type mismatch between nonterminal %s(type %T) and rule %r(type %T)\n",
439                         r->loc(),lhs,lhs_ty,r,ty);
440                }
441             }
442             
443             if (! is_datatype(ty))
444                error ("%!rule %r is of a non datatype: %T\n",r->loc(),r,ty); 
445          }
446       }
447    }
449    patvar_typemap = 0; // reset the pattern variable type map
451    // Now partition rules by type and assign functor encoding.
452    // Since we have also typed the rules, this is quite simple: just
453    // another pass.  We have to make sure that after the type inference
454    // we don't have any more polymorphic types inside the patterns.
455    int rule_num = 0;
456    for_each (MatchRule, R, rules)
457    {  match (R)
458       {  MATCHrule(_,pat,_,_,_):
459          {  if (! is_ground(R->ty))
460                error ("%!rule %r has incomplete type %T\n",R->loc(),R,R->ty); 
461             HashTable::Entry * e = rule_map.lookup(R->ty);
462             if (e) e->v = #[ R ... (MatchRules)e->v ];
463             else rule_map.insert(R->ty,#[ R ]);
464             // assign functor encoding
465             encode(pat);
466             R->rule_number = rule_num++;
467          }
468       }
469    }
472 ///////////////////////////////////////////////////////////////////////////////
474 //  Method to compute the functor and variable table.
476 ///////////////////////////////////////////////////////////////////////////////
477 void FunctorMap::compute_names (Id fun_names[], Id var_names [])
478 {  functor_names  = fun_names;
479    variable_names = var_names;
480    {  for (int i = N + variables - 1; i >= 0; i--) variable_names[i] = 0; }
481    {  for (int i = functors - 1; i >= 0; i--)  functor_names[i] = "???"; }
482    variable_names[0] = "_";
484    // Compute variable names
485    {  foreach_entry (i,var_map) 
486         variable_names[(Variable)var_map.value(i)] = (Id)var_map.key(i);
487    }
489    // Compute literal names
490    {  foreach_entry (i,literal_map) 
491       {  Literal l = (Literal)literal_map.key(i); 
492          Functor f = (Functor)literal_map.value(i);
493          char buf[1024];
494          ostrstream b(buf,sizeof(buf));
495          ostream& s = b;
496          s << l << ends;
497          functor_names[f] = Quark(buf); 
498       }
499    }
501    // Compute constructor names
502    {  foreach_entry (i,type_map) 
503       {  Ty      t = (Ty)type_map.key(i); 
504          Functor f = (Functor)type_map.value(i);
505          match (deref(t))
506          {  TYCONty(DATATYPEtycon { unit, arg, terms ... },_):
507             {  int arity = unit + arg;
508                for (int j = 0; j < arity; j++)
509                {  match (terms[j])
510                   {  ONEcons { name, ty, tag ... }:
511                      {  functor_names[f + (ty == NOty ? tag : tag + unit)] =
512                            name;
513                      }
514                   |  _: // skip
515                   }
516                }
517             }
518          |  _: { bug ("compute_names()"); }
519          }
520       }
521    }
523    // Compute vector constructor names
524    {  foreach_entry (i, vector_map)
525       {  VectorId id = (VectorId)vector_map.key(i);
526          Functor  f  = (Functor)vector_map.value(i);
527          if (id->cons) functor_names[f] = id->cons->name;
528       }
529    }
532 ///////////////////////////////////////////////////////////////////////////////
534 //  Method to print a report detailing the functor/variable encoding,
535 //  the tree grammar and the generated table size.
537 ///////////////////////////////////////////////////////////////////////////////
538 void FunctorMap::print_report (ostream& log)
540    if (var_map.size() > 0) 
541    {  log << "Variable encoding:\n";
542       foreach_entry (e, var_map)
543       {  log << "\tnon-terminal \"" << (Id)var_map.key(e) << "\"\t=\t"
544              << (Variable)var_map.value(e) << '\n';
545       } 
546    }
548    if (literal_map.size() > 0) 
549    {  log << "\nFunctor encoding for literals:\n";
550       foreach_entry (e, literal_map)
551       {  log << "literal " << (Literal)literal_map.key(e) << "\t=\t"
552              << (Functor)literal_map.value(e) << '\n';
553       }
554    }
556    log << "\nFunctor encoding for constructors:\n";
558    {  foreach_entry (e, type_map)
559       {  Ty      t = (Ty)type_map.key(e);
560          Functor f = (Functor)type_map.value(e);
561          log << "datatype " << t << ":\n"; 
562          match (deref(t))
563          {  TYCONty(DATATYPEtycon { unit, arg, terms ... },_):
564             {  int arity = unit + arg;
565                for (int i = 0; i < arity; i++)
566                {  match (terms[i])
567                   {  ONEcons { name, ty, tag ... }:
568                      {  log << '\t' << name << "\t=\t" 
569                             << f + (ty == NOty ? tag : tag + unit) << '\n';
570                      }
571                   |  _: // skip
572                   }
573                }
574             }
575          |  _: // skip
576          }
577       }
578    }
580    log << "\nIndex compression is " 
581        << (use_compression ? "enabled" : "disabled")
582        << "\nAlgorithm is " << tree_gen->algorithm();
584    if (tree_gen) tree_gen->print_report(log);