1 ///////////////////////////////////////////////////////////////////////////////
3 // This file implements the inference rules compiler of Prop.
5 ///////////////////////////////////////////////////////////////////////////////
7 #include <AD/memory/mempool.h>
8 #include <AD/generic/ordering.h>
19 ///////////////////////////////////////////////////////////////////////////////
21 // Constructor and destructor for the inference compiler
23 ///////////////////////////////////////////////////////////////////////////////
// Construction and destruction of the inference compiler need no work of
// their own: no member initialisers are given, so every member receives its
// default initialisation, and the destructor body is empty.
InferenceCompiler::InferenceCompiler()
{
}

InferenceCompiler::~InferenceCompiler()
{
}
27 ///////////////////////////////////////////////////////////////////////////////
29 // Import some type definitions.
31 ///////////////////////////////////////////////////////////////////////////////
// File-local shorthands for the key and value types of HashTable.
32 typedef HashTable::Key Key;
33 typedef HashTable::Value Value;
35 ///////////////////////////////////////////////////////////////////////////////
37 // Method to create an inference class.
39 ///////////////////////////////////////////////////////////////////////////////
// Build the class definition for an inference class named `id`.
// The definition kind is INFERENCE_CLASS and the third argument is an empty
// list. "Rete" is added to the user's inheritance list `i` via add_inherit
// (presumably making the Rete engine a base of every generated inference
// class -- confirm add_inherit's semantics). Qualifiers and body are
// forwarded unchanged. The destructor has nothing to release.
InferenceClass::InferenceClass(Id id, Inherits i, TyQual qual, Decls body)
   : ClassDefinition(INFERENCE_CLASS,
                     id,
                     #[],
                     add_inherit("Rete", #[], i),
                     qual,
                     body)
{
}

InferenceClass::~InferenceClass()
{
}
45 ///////////////////////////////////////////////////////////////////////////////
47 // Method to generate the interface for an inference class.
49 ///////////////////////////////////////////////////////////////////////////////
// Emit the extra member declarations that every generated inference class
// carries:
//  - a `void operator = (const X&)` declaration;
//  - the static network_table and relation_table arrays;
//  - name_of() and initialise_axioms();
//  - the alpha_test / beta_test / action virtuals that the Rete engine
//    dispatches to.
// All of these are defined by the InferenceCompiler::gen_* routines in this
// file.
// NOTE(review): the head of the C.pr(...) call and some format lines are not
// visible here. The four class_name arguments presumably fill the %s
// directives of the full format string -- confirm the counts match.
50 void InferenceClass::gen_class_interface (CodeGen& C)
54 "%^void operator = (const %s&);%-"
56 "%^static const Node network_table[];"
57 "%^static const RelationTable relation_table[];%-"
60 "%^virtual const char * name_of() const;"
61 "%^void initialise_axioms();%-"
63 "%^virtual void alpha_test (int, int, Fact *);"
64 "%^virtual int beta_test (Join, Fact * []);"
65 "%^virtual void action (RuleId, Fact * []);%-"
67 class_name, class_name, class_name, class_name);
70 ///////////////////////////////////////////////////////////////////////////////
72 // Forward declarations.
74 ///////////////////////////////////////////////////////////////////////////////
// Forward declarations for helpers that are called before their definitions
// appear below. For example, partition(Exp,int&) calls the (Exp,Exp) and
// Exps overloads, and gen_inference_rules calls partition_inference_rules.
75 Bool partition(Exp e1, Exp e2, int &obj);
76 Bool partition(Exps exp, int &obj);
77 Bool partition(LabExps exp, int &obj);
78 int partition_inference_rules
79 (int n, InferenceRule Rs[], HashTable& rule_map, HashTable& join_map);
81 ///////////////////////////////////////////////////////////////////////////////
83 // Flatten an expression into a list of conjuncts, pushing negations inward
84 // as far as possible. Note that C++'s short-circuiting && and || are
85 // associative but not commutative, so conjunct order can matter (e.g. `p && p->x`).
87 ///////////////////////////////////////////////////////////////////////////////
// flatten: store the conjuncts of `exp` (or of !exp when `neg` is true)
// into the array starting at `cnf`. Returns a pointer one past the last
// conjunct written.
//  - MARKEDexp wrappers are stripped.
//  - A prefix `!` toggles `neg`.
//  - `a && b` (not negated) and `!(a || b)` (De Morgan) are split: the left
//    operand is flattened recursively and the right operand is examined next.
//  - A negated relational comparison is replaced by its complement,
//    e.g. !(a > b) becomes a <= b.
//  - Any other negated expression is emitted as `!e`; anything else as is.
// The arms that only reassign `exp`/`neg` rely on the enclosing match (its
// header is not visible here) re-examining `exp`. This is presumably a
// `match while` loop -- confirm.
// NOTE(review): nothing bounds the writes through `cnf`. The caller must
// supply an array large enough for every conjunct, or this writes past it.
// NOTE(review): the comparison complements are exact for integers and
// pointers but not for IEEE floating point. When either operand is NaN,
// !(x > y) is not equivalent to x <= y. Confirm that guards never compare
// floating-point values.
88 Exp * flatten (Exp exp, Exp * cnf, Bool neg)
90 { MARKEDexp(_,e): { exp = e; }
91 | PREFIXexp("!",e): { exp = e; neg = !neg; }
92 | BINOPexp("&&",e1,e2) | !neg:{ cnf = flatten(e1,cnf,neg); exp = e2; }
93 | BINOPexp("||",e1,e2) | neg: { cnf = flatten(e1,cnf,neg); exp = e2; }
94 | BINOPexp(">",e1,e2) | neg: { *cnf = BINOPexp("<=",e1,e2); return cnf+1;}
95 | BINOPexp(">=",e1,e2) | neg: { *cnf = BINOPexp("<",e1,e2); return cnf+1;}
96 | BINOPexp("<",e1,e2) | neg: { *cnf = BINOPexp(">=",e1,e2); return cnf+1;}
97 | BINOPexp("<=",e1,e2) | neg: { *cnf = BINOPexp(">",e1,e2); return cnf+1;}
98 | BINOPexp("==",e1,e2) | neg: { *cnf = BINOPexp("!=",e1,e2); return cnf+1;}
99 | BINOPexp("!=",e1,e2) | neg: { *cnf = BINOPexp("==",e1,e2); return cnf+1;}
100 | e | neg: { *cnf = PREFIXexp("!",e); return cnf+1;}
101 | e: { *cnf = e; return cnf+1; }
105 ///////////////////////////////////////////////////////////////////////////////
107 // Returns true if an expression involves only one object.
108 // Also returns the highest numbered object involved.
109 // As a convention, the object number is -1 if the expression is a constant.
111 ///////////////////////////////////////////////////////////////////////////////
// partition(Exp): decide whether `exp` refers to at most one matched object.
// `obj` receives the object number: RELexp i yields i, while identifiers,
// literals and NOexp yield -1 ("constant").
//  - Wrapper and unary forms recurse on their single operand.
//  - Binary and n-ary forms delegate to partition(Exps, int&).
//  - Any unexpected expression kind is reported with bug() and yields false.
// Returns true when the expression depends on at most one object.
// NOTE(review): IDexp is treated as a constant, so a free identifier never
// makes a conjunct a join.
112 Bool partition(Exp exp, int &obj)
114 { MARKEDexp(_,e): { return partition(e,obj); }
115 | PREFIXexp(_,e): { return partition(e,obj); }
116 | POSTFIXexp(_,e): { return partition(e,obj); }
117 | DEREFexp e: { return partition(e,obj); }
118 | SELECTORexp(e,_,_): { return partition(e,obj); }
119 | DOTexp (e,_): { return partition(e,obj); }
120 | ARROWexp(e,_): { return partition(e,obj); }
121 | CONSexp(_,_,e): { return partition(e,obj); }
122 | HASHexp(_,e): { return partition(e,obj); }
123 | CASTexp(_,e): { return partition(e,obj); }
124 | RELexp i: { obj = i; return true; }
125 | APPexp(e1,e2): { return partition(e1,e2,obj); }
126 | INDEXexp(e1,e2): { return partition(e1,e2,obj); }
127 | ASSIGNexp(e1,e2): { return partition(e1,e2,obj); }
128 | BINOPexp(_,e1,e2): { return partition(e1,e2,obj); }
129 | EQexp(_,e1,e2): { return partition(e1,e2,obj); }
130 | UNIFYexp(_,e1,e2): { return partition(e1,e2,obj); }
131 | LTexp(_,e1,e2): { return partition(e1,e2,obj); }
132 | TUPLEexp es: { return partition(es,obj); }
133 | RECORDexp es: { return partition(es,obj); }
134 | SENDexp(_, es): { return partition(es,obj); }
135 | VECTORexp(_,es): { return partition(es,obj); }
136 | LISTexp(_,_,es,e): { return partition(#[e ... es], obj); }
137 | SETLexp(_,es): { return partition(es,obj); }
138 | IFexp (e1,e2,e3): { return partition(#[e1,e2,e3],obj); }
139 | IDexp _ || LITERALexp _ || NOexp: { obj = -1; return true; }
140 | _: { bug("partition: %e", exp); return false; }
144 ///////////////////////////////////////////////////////////////////////////////
146 // Categorize two expressions.
148 ///////////////////////////////////////////////////////////////////////////////
// Two expressions are categorized exactly like the two-element list
// [e1, e2]; see partition(Exps, int&).
Bool partition(Exp e1, Exp e2, int &obj)
{  Exps both = #[e1, e2];
   return partition(both, obj);
}
151 ///////////////////////////////////////////////////////////////////////////////
153 // Categorize a list of expressions.
155 ///////////////////////////////////////////////////////////////////////////////
// partition(Exps): succeeds only if every element succeeds and all elements
// that reference an object reference the same one. Constants (-1) are
// compatible with anything. `obj` tracks the largest object number seen.
// NOTE(review): most of this body is not visible. If the failing branch
// returns right after the test below, `obj` holds the maximum over the
// elements examined *before* the conflicting one. For `_0.x == _1.y` that
// would be 0 rather than 1, contradicting the "highest numbered object"
// contract stated for partition(Exp,int&). decompose() uses this value to
// choose where a join is attached, so update `obj` with obj1 before
// returning false -- verify.
156 Bool partition(Exps es, int &obj)
161 if (! partition(e,obj1) || (obj >= 0 && obj1 >= 0 && obj1 != obj))
163 if (obj1 > obj) obj = obj1;
168 ///////////////////////////////////////////////////////////////////////////////
170 // Categorize a labeled list of expressions.
172 ///////////////////////////////////////////////////////////////////////////////
// partition(LabExps): same rule as the Exps overload, applied to the
// expression part (e.exp) of each labeled element.
// NOTE(review): same caveat as partition(Exps, int&). If the failing branch
// returns early, `obj` may not be the highest object involved -- verify.
173 Bool partition(LabExps es, int &obj)
176 for_each(LabExp, e, es) {
178 if (! partition(e.exp,obj1) || (obj >= 0 && obj1 >= 0 && obj1 != obj))
180 if (obj1 > obj) obj = obj1;
185 ///////////////////////////////////////////////////////////////////////////////
186 // Create an && expression; a NOexp operand means "no condition" and is dropped.
187 ///////////////////////////////////////////////////////////////////////////////
// Conjoin two guard expressions with `&&`. NOexp stands for "no condition"
// and acts as the identity: if the left side is NOexp the right side is
// returned as is, otherwise if the right side is NOexp the left side is
// returned. Only when both are present is a new && node built.
Exp mkandexp(Exp a, Exp b)
{  return (a == NOexp) ? b
        : (b == NOexp) ? a
        :                BINOPexp("&&", a, b);
}
194 ///////////////////////////////////////////////////////////////////////////////
196 // Decompose guard expressions in conjunctive normal form into
197 // selections and (theta) joins.
199 // We are given $n$ objects and $n$ boolean expressions.
200 // We want to decompose these $n$ expressions into (at most)
201 // $n$ single object selects and (at most) $n$ joins.
203 ///////////////////////////////////////////////////////////////////////////////
// decompose: split each guard exps[0..n-1] into conjuncts and distribute
// them by the object number k reported by partition(); a constant (-1) is
// treated as k = 0.
//  - A single-object conjunct is and-ed into selects[k].
//  - A multi-object conjunct is and-ed into joins[k].
// The first n entries of selects/joins are reset to NOexp first. max_object
// records the largest slot used; it is presumably the return value, since
// the return is not visible.
// NOTE(review): `cnf` is a fixed stack array, but flatten() writes into it
// without a bound. The `conjuncts > MAX_CONJUNCTS` test therefore detects an
// overflow only after the stack has been overwritten. flatten() should be
// given the array limit. The "256" in the comment also duplicates
// MAX_CONJUNCTS' value and will go stale.
// NOTE(review): mkandexp() already treats NOexp as identity, so the
// `== NOexp` tests before each mkandexp call are redundant.
204 int decompose (int n, Exp exps[], Exp selects[], Exp joins[])
207 for (i = 0; i < n; i++) selects[i] = joins[i] = NOexp;
208 for (i = 0; i < n; i++) {
209 debug_msg ("decomposing: %e\n", exps[i]);
210 Exp cnf[MAX_CONJUNCTS]; // assume we don't have more than 256 conjuncts.
211 Exp * last = flatten(exps[i], cnf, false); // flatten expression.
212 int conjuncts = last - cnf;
213 if (conjuncts > MAX_CONJUNCTS)
214 bug ("Conjuncts exceeded %i in decompose()", MAX_CONJUNCTS);
215 for (int j = 0; j < conjuncts; j++) {
217 //////////////////////////////////////////////////////////////////////
218 // Checks whether the conjunct depends on only one variable.
219 // If so it can be executed as a guard during pattern matching.
220 // Otherwise, it is a join and must be executed by the
221 // RETE engine. In any case hoist the conjunct as far up as possible
222 // to minimize the sizes of intermediate relations.
223 //////////////////////////////////////////////////////////////////////
224 debug_msg ("partitioning: %e\n", cnf[j]);
225 Bool depends_on_one_variable = partition(cnf[j],obj);
226 if (obj < 0) obj = 0;
227 if (depends_on_one_variable) { // expression is a select
228 if (selects[obj] == NOexp) selects[obj] = cnf[j];
229 else selects[obj] = mkandexp(selects[obj],cnf[j]);
230 debug_msg ("select: %e\n", cnf[j]);
231 } else { // expression is a join
232 if (joins[obj] == NOexp) joins[obj] = cnf[j];
233 else joins[obj] = mkandexp(joins[obj],cnf[j]);
234 debug_msg ("join: %e\n", cnf[j]);
236 if (obj > max_object) max_object = obj;
242 ///////////////////////////////////////////////////////////////////////////////
244 // Decompose a set of pattern matching rules and extract out the joins
245 // from each of the guards.
247 ///////////////////////////////////////////////////////////////////////////////
// decompose (rule form): collect the guard of every match rule in `rules`,
// followed by the rule-level `guard_exp`, and decompose them with
// decompose(int, Exp[], Exp[], Exp[]). Each match rule's guard is then
// replaced by the single-object selection computed for its position.
// Multi-object (join) tests are left in `joins`, which is indexed up to
// MAX_INFERENCE_RULE_ARITY by the caller.
// NOTE(review): the arity check runs only after guards[n++] has been
// written, so a rule with too many premises overflows `guards` before bug()
// is reached. The check belongs before each store. `n >= MAX_...` also
// rejects n == MAX_INFERENCE_RULE_ARITY, which still fits in the arrays.
// NOTE(review): `g = selects[i]` assigns the pattern variable bound to the
// rule's guard. This presumably writes through to the MatchRule under Prop's
// binding semantics -- confirm.
248 int decompose (MatchRules rules, Exp joins[], Exp guard_exp)
249 { Exp guards [MAX_INFERENCE_RULE_ARITY];
250 Exp selects [MAX_INFERENCE_RULE_ARITY];
253 { for_each (MatchRule, r, rules)
255 { MATCHrule(_,_,g,_,_): { guards[n++] = g; } }
258 guards[n++] = guard_exp;
260 if (n >= MAX_INFERENCE_RULE_ARITY)
261 bug ("%Linference rule arity exceeds %i in decompose()",
262 MAX_INFERENCE_RULE_ARITY);
264 // take all the guard expressions and decompose them.
265 int max_object = decompose(n, guards, selects, joins);
267 // rebuild the guards. Now they must all involve at most one
270 for_each (MatchRule, r, rules)
272 { MATCHrule(_,_,g,_,_): { g = selects[i]; i++; } }
279 ///////////////////////////////////////////////////////////////////////////////
281 // Top level method to compile a set of inference rules.
283 ///////////////////////////////////////////////////////////////////////////////
// Top-level driver: compile the inference rules of class `id`.
//  1. Copy the rules into a temporary array allocated from mem_pool. The pool
//     mark saved on entry is restored at the end, reclaiming that memory.
//  2. partition_inference_rules() groups the premises by relation type name
//     (rule_map, string-keyed) and collects join tests by network node number
//     (join_map, integer-keyed).
//  3. Emit each part of the generated class in turn: name_of(), alpha tests,
//     beta tests, actions, the relation dispatch table, the network table,
//     axioms, and the constructor. The constructor needs the network-table
//     size `m`.
284 void InferenceCompiler::gen_inference_rules(Id id, InferenceRules rules)
285 { MemPoolMark marker = mem_pool.getMark(); // get heap marker
287 ////////////////////////////////////////////////////////////////////////////
288 // Mapping from type id to list of rules.
289 ////////////////////////////////////////////////////////////////////////////
290 HashTable rule_map(string_hash, string_equal, 129);
291 HashTable join_map(integer_hash, integer_equal);
293 ////////////////////////////////////////////////////////////////////////////
294 // Map the rules into an array
295 ////////////////////////////////////////////////////////////////////////////
296 int n = length(rules);
298 InferenceRule * Rs = (InferenceRule *)mem_pool[n * sizeof(InferenceRule)];
299 { int i = 0; for_each (InferenceRule,r,rules) Rs[i++] = r; }
301 max_arity = partition_inference_rules(n, Rs, rule_map, join_map);
302 pr ("const char * %s::name_of() const { return \"%s\"; }\n\n", id, id);
303 gen_alpha_tests (id, max_arity, rule_map);
304 gen_beta_tests (id, n, Rs, join_map);
305 gen_inference_actions (id, rules);
306 gen_dispatch_table (id, rule_map);
307 int m = gen_network_table (id, n, Rs, join_map);
308 gen_inference_axioms (id, rules);
309 gen_inference_constructor (id, m, rule_map);
311 ////////////////////////////////////////////////////////////////////////////
313 ////////////////////////////////////////////////////////////////////////////
314 mem_pool.setMark(marker); // reclaim memory
317 ///////////////////////////////////////////////////////////////////////////////
319 // Method to partition the left hand side of each inference rule according
320 // to the type of the pattern. Patterns of the same type are grouped
321 // together and compiled using the pattern matching compiler. The steps are:
322 // (a) Decompose the guard expression into selections (predicates on a
323 // single object) and joins (predicates on 2 or more objects).
324 // Both of these are hoisted upward as much as possible.
325 // (b) Perform type inference on the pattern.
326 // (c) Enter all the rules of the same type in the same entry of the rule
329 ///////////////////////////////////////////////////////////////////////////////
// For each inference rule, first split its guards into selections and joins
// with decompose(). Then, for each premise (match rule):
//  - infer the pattern's type and store it in r->ty. The type must be a
//    datatype qualified as a relation; otherwise an error is reported;
//  - add the premise to rule_map under that type name, prepending it to any
//    existing list;
//  - give it the current network node number, and and-in its join test
//    (if any) in join_map under that node number;
//  - replace its action with an INJECTdecl into that node: LEFT for the first
//    positive premise, RIGHT for every later or negated premise.
// object_count presumably counts a rule's premises. Its maximum is kept in
// max_arity, which callers receive as the result. The declarations and the
// return are not visible here.
// NOTE(review): `int n = decompose(...)` shadows the parameter `n`, and the
// inner `HashTable::Entry * e` shadows the outer `e`. Both are legal but easy
// to misread; consider renaming.
// NOTE(review): the join is combined with a raw BINOPexp("&&", ...). The rest
// of this file uses mkandexp(), which would also tolerate a NOexp already
// stored in join_map.
330 int partition_inference_rules
331 (int n, InferenceRule Rs[], HashTable& rule_map, HashTable& join_map)
333 int node_number = 0; // node number in network table.
334 for (int rule_no = 0; rule_no < n; rule_no++)
335 { int positive_clauses = 0;
336 int negative_clauses = 0;
338 { INFERENCErule(As, guard_exp, _):
339 { int object_count = 0;
340 Exp joins[MAX_INFERENCE_RULE_ARITY];
341 // decompose multi-object test
342 int n = decompose (As, joins, guard_exp);
344 for_each (MatchRule, r, As)
347 { MATCHrule (_,pat,_,_,action):
348 { match (r->ty = type_of(pat))
349 { TYCONty(DATATYPEtycon { id, qualifiers ... },_)
350 where qualifiers & QUALrelation:
352 HashTable::Entry * e = rule_map.lookup(id);
353 if (e) rule_map.insert(id,#[ r ... MatchRules(rule_map.value(e))]);
354 else rule_map.insert(id,#[ r ]);
355 r->rule_number = node_number;
356 if (joins[i] != NOexp) {
357 HashTable::Entry * e =
358 join_map.lookup((HashTable::Key)node_number);
360 join_map.insert((HashTable::Key)node_number,
361 BINOPexp("&&",(Exp)join_map.value(e),joins[i]));
363 join_map.insert((HashTable::Key)node_number,
367 if (positive_clauses == 0 && !r->negated) {
368 action = #[ INJECTdecl(node_number,LEFTdirection) ];
370 action = #[ INJECTdecl(node_number,RIGHTdirection) ];
373 if (r->negated) negative_clauses++;
374 else positive_clauses++;
378 { error("%Lnon-relation type %T in pattern: %p\n", ty, pat); }
386 if (max_arity < object_count) max_arity = object_count;
393 ///////////////////////////////////////////////////////////////////////////////
395 // Method to generate the alpha (single object) tests.
397 ///////////////////////////////////////////////////////////////////////////////
// Emit `id::alpha_test(predicate__, i__, fact__)`, which switches on
// predicate__. For every relation type in rule_map, the generated code binds
// the incoming fact to `_0` (cast to that relation type, and also stored in
// f__[0]). It then runs all premises of that type through the pattern-match
// compiler with MATCHall | MATCHnocheck. same_selectors is forced to true
// while these cases are generated and restored afterwards.
// NOTE(review): type_number (its initialisation and increment are not
// visible) must agree with the numbering written by gen_dispatch_table. That
// function walks rule_map in the same foreach_entry order, so keep the two
// loops in sync.
398 void InferenceCompiler::gen_alpha_tests
399 (Id id, int max_arity, HashTable& rule_map)
400 { pr ("%^%/%^// Single object tests for inference class %s%^%/"
401 "%^void %s::alpha_test(int predicate__, int i__, Fact * fact__)"
404 "%^switch (predicate__) {%+", id, id, max_arity);
405 { Bool save = same_selectors;
406 same_selectors = true;
408 foreach_entry(e, rule_map)
409 { Id ty_name = (Id)rule_map.key(e);
410 MatchRules rules = (MatchRules)rule_map.value(e);
412 "%^%s _0 = (%s)(f__[0] = fact__);",
413 type_number, ty_name, ty_name);
414 gen_match_stmt (#[], rules, MATCHall | MATCHnocheck);
418 same_selectors = save;
424 ///////////////////////////////////////////////////////////////////////////////
426 // Method to generate the beta tests(joins)
428 ///////////////////////////////////////////////////////////////////////////////
// Emit `id::beta_test(join__, f__)`, which switches on join__.
//  - Within each inference rule, only the last premise of each run of
//    premises sharing a rule_number is considered.
//  - If join_map holds a test for that node number, a case is emitted. It
//    casts the fact f__[j] of every premise of the rule to the premise's
//    inferred type as `_j`, then returns the join expression.
//  - Unknown join numbers return 0.
429 void InferenceCompiler::gen_beta_tests
430 (Id id, int n, InferenceRule rules[], HashTable& join_map)
431 { pr ("%^%/%^// Joins for inference class %s%^%/"
432 "%^int %s::beta_test(Join join__, Fact * f__[])"
434 "%^switch (join__) {%+", id, id);
436 for (int i = 0; i < n; i++)
438 { INFERENCErule(As,_,_):
439 { for(MatchRules rs = As; rs; rs = rs->#2)
440 { if (rs->#2 == #[] ||
441 rs->#1->rule_number != rs->#2->#1->rule_number)
442 { MatchRule r = rs->#1;
443 int rule_no = r->rule_number;
444 HashTable::Entry * e =
445 join_map.lookup((HashTable::Key)rule_no);
447 { Exp join = (Exp)join_map.value(e);
449 pr ("%^case %i: {%+", rule_no);
450 for_each (MatchRule, mr, As)
451 { pr ("%^%t _%i = (%t)f__[%i];",
452 mr->ty, "", j, mr->ty, "", j);
456 pr ("%^return %e;%-%^}", join);
464 pr ("%^default: return 0;"
469 ///////////////////////////////////////////////////////////////////////////////
471 // Generate the dispatch table.
473 ///////////////////////////////////////////////////////////////////////////////
// Emit the static `id::relation_table`. It holds one entry
// { &a_T::relation_tag, number, "T" } for each relation type T in rule_map,
// in foreach_entry order, with numbering starting at 1. The increment is not
// visible here; presumably it is one per entry.
// NOTE(review): these numbers must match the case numbering produced by
// gen_alpha_tests, which iterates rule_map the same way.
474 void InferenceCompiler::gen_dispatch_table(Id id, HashTable& rule_map)
475 { pr ("%^%/%^// Dispatch table for inference class %s%^%/"
476 "%^const %s::RelationTable %s::relation_table[] = {%+",
478 { int type_number = 1;
480 foreach_entry(e, rule_map)
481 { Id ty_name = (Id)rule_map.key(e);
483 pr ("%^{ &a_%s::relation_tag, %i, \"%s\" }",
484 ty_name, type_number, ty_name);
492 ///////////////////////////////////////////////////////////////////////////////
494 // Generate the network table.
496 ///////////////////////////////////////////////////////////////////////////////
// Emit the static `id::network_table`. The caller (gen_inference_rules) uses
// the returned count as the number of entries for the constructor. For each
// inference rule:
//  - For each new rule_number among its premises, write one And node, or a
//    Not node for a negated premise. These nodes are skipped when the rule
//    has a single premise that is not negated.
//  - Then write a Bot node carrying the rule index i.
// A node's join field is its rule_number when join_table holds a test for
// that node, and 0 otherwise.
// NOTE(review): 0 doubles as "no join" and as a valid node number
// (node_number starts at 0 in partition_inference_rules). A join test stored
// under node 0 would therefore be emitted as "no join". Confirm that node 0
// can never carry a join.
497 int InferenceCompiler::gen_network_table
498 (Id id, int n, InferenceRule rules[], HashTable& join_table)
499 { pr ("%^%/%^// Network table for inference class %s%^%/"
500 "%^const %s::Node %s::network_table[] = {%+",
505 for (int i = 0; i < n; i++)
507 { INFERENCErule(As,_,_):
508 { int max_arity = length(As);
509 int last_rule_number = -1;
511 for_each (MatchRule, r, As)
512 { if (last_rule_number != r->rule_number
513 && (max_arity > 1 || r->negated) ) {
515 Id typ = r->negated ? "Not" : "And";
517 join_table.contains((HashTable::Key)r->rule_number)
518 ? r->rule_number : 0;
519 pr ("%^{ %i, %i, ReteNet::Node::%s, %i, %i } /* %i */",
520 arity, max_arity, typ, join, r->rule_number + 1, entries);
525 last_rule_number = r->rule_number;
528 pr ("%^{ 0, %i, ReteNet::Node::Bot, %i, %i } /* %i */",
529 max_arity, 0, i, entries);
541 ///////////////////////////////////////////////////////////////////////////////
543 // Generate the axioms of the inference class.
545 ///////////////////////////////////////////////////////////////////////////////
// Emit `id::initialise_axioms()`. The conclusions of every rule with no
// premises (#[]) are generated straight into its body, so they run when the
// generated constructor calls initialise_axioms().
546 void InferenceCompiler::gen_inference_axioms(Id id, InferenceRules rules)
548 pr ("%^%/%^// Axioms for inference class %s%^%/"
549 "%^void %s::initialise_axioms()"
552 for_each(InferenceRule, r, rules)
554 { INFERENCErule(#[], _, conclusions): { gen_conclusions(conclusions); }
561 ///////////////////////////////////////////////////////////////////////////////
563 // Generate the action routine of the inference class
565 ///////////////////////////////////////////////////////////////////////////////
// Emit `id::action(r__, f__)`, which switches on the rule id r__. There is
// one case per inference rule that has at least one premise. Each case casts
// the matched facts f__[i] to the premises' inferred types as `_i`, then
// emits the rule's conclusions.
// NOTE(review): the case label rule_no is maintained on lines not visible
// here. It must agree with the rule index that gen_network_table stores in
// each Bot node (the loop index over all rules, including axioms).
566 void InferenceCompiler::gen_inference_actions(Id id, InferenceRules rules)
568 pr ("%^%/%^// Actions for inference class %s%^%/"
569 "%^void %s::action(%s::RuleId r__, Fact * f__[])"
571 "%^switch (r__) {%+",
575 for_each(InferenceRule, r, rules)
577 { INFERENCErule(mrs as ! #[], _, conclusions):
578 { pr ("%^case %i: {%+", rule_no);
580 for_each (MatchRule, mr, mrs)
581 { pr ("%^%t _%i = (%t)f__[%i];", mr->ty, "", i, mr->ty, "", i);
584 gen_conclusions(conclusions);
595 ///////////////////////////////////////////////////////////////////////////////
597 // Generate the conclusions of the inference class.
599 ///////////////////////////////////////////////////////////////////////////////
// Emit code for every conclusion of a rule, in list order, by walking the
// list cells directly (#1 = head, #2 = tail).
void InferenceCompiler::gen_conclusions(Conclusions cs)
{  for (Conclusions rest = cs; rest; rest = rest->#2)
      gen_conclusion(rest->#1);
}
603 ///////////////////////////////////////////////////////////////////////////////
605 // Generate one conclusion of the inference class.
607 ///////////////////////////////////////////////////////////////////////////////
// Emit one conclusion:
//  - ASSERTaction becomes a call to assert_fact(e);
//  - RETRACTaction becomes a call to retract_fact(e);
//  - STMTaction declarations are printed verbatim.
608 void InferenceCompiler::gen_conclusion(Conclusion c)
610 { ASSERTaction e: { pr ("%^assert_fact(%e);\n", e); }
611 | RETRACTaction e: { pr ("%^retract_fact(%e);\n", e); }
612 | STMTaction decls: { pr ("%^%&", decls); }
616 ///////////////////////////////////////////////////////////////////////////////
618 // Generate the constructor of the inference class.
620 ///////////////////////////////////////////////////////////////////////////////
// Emit the constructor of inference class `id`. It initialises the Rete base
// with:
//  - the number of network-table entries (`entries`);
//  - the network table;
//  - the number of relation types (rule_map.size());
//  - the relation table.
// The constructor body then calls initialise_axioms().
621 void InferenceCompiler::gen_inference_constructor
622 (Id id, int entries, HashTable& rule_map)
624 pr ("%^%/%^// Constructor for inference class %s%^%/"
626 "%^: Rete(%i,%s::network_table,%i,%s::relation_table)%+"
627 "%^{ initialise_axioms(); }%-%-\n\n",
629 entries, id, rule_map.size(), id);
632 ///////////////////////////////////////////////////////////////////////////////
634 // Generate the interface of a relation object.
636 ///////////////////////////////////////////////////////////////////////////////
// For a relation datatype, emit the relation-tag members into the class
// interface: a static relation_tag and a virtual get_tag(). Only the root
// class of the datatype emits them; every other class returns immediately.
637 void DatatypeClass::generate_inference_interface(CodeGen& C)
639 if (this != root) return;
643 "%^// Inference methods"
646 "%^static RelTag relation_tag;"
647 "%^virtual RelTag get_tag() const;"
651 ///////////////////////////////////////////////////////////////////////////////
653 // Generate the implementation of a relation object.
655 ///////////////////////////////////////////////////////////////////////////////
// For a relation datatype, emit (from the root class only):
//  - the definition of the static relation_tag, initialised to 0;
//  - a file-static InitialiseFact object, named with a fresh label and
//    constructed with relation_tag. It presumably assigns the real tag during
//    static initialisation -- confirm.
//  - get_tag(), which returns relation_tag.
656 void DatatypeClass::generate_inference_implementation
657 (CodeGen& C, Tys tys, DefKind k)
659 if (this != root) return;
663 "%^// Relation datatype %s%P"
666 "%^Fact::RelTag %s%P::relation_tag = 0;"
667 "%^static InitialiseFact %s_dummy__(%s%P::relation_tag);"
668 "%^Fact::RelTag %s%P::get_tag() const"
669 " { return %s%P::relation_tag; }\n \n",
670 root->datatype_name, tys,
672 DatatypeCompiler::temp_vars.new_label(), class_name, tys,