Initial packaging
[pkg-ocaml-deriving-ocsigen.git] / syntax / functor_class.ml
blob06fced402a41a037ab17bcee5443eb0391484796
1 (* Copyright Jeremy Yallop 2007.
2 This file is free software, distributed under the MIT license.
3 See the file COPYING for details.
4 *)
6 open Pa_deriving_common.Defs
7 open Camlp4.PreCast
(* Static description of the Functor class: its name, the runtime library
   module it refers to, and the predefined type constructors whose [map]
   functions are looked up under [runtimename] instead of being derived. *)
module Description : ClassDescription = struct
  let classname = "Functor"
  let runtimename = "Deriving_Functor"
  let default_module = None
  let allow_private = false
  (* Qualified type name -> name resolved under [runtimename]. *)
  let predefs = [
    ["list"], "list";
    ["ref"], "ref";
    ["option"], "option";
  ]
  let depends = []
end
22 module InContext (C : sig val _loc : Camlp4.PreCast.Loc.t end) =
23 struct
24 open C
25 open Pa_deriving_common.Type
26 open Pa_deriving_common.Utils
27 open Pa_deriving_common.Base
29 module Helpers = Pa_deriving_common.Base.InContext(C)(Description)
30 open Helpers
31 open Description
(* For every type parameter [a] of the declarations being derived, the name
   ["f_" ^ a] of the lambda-bound function that maps values of that
   parameter. *)
let param_map context : string NameMap.t =
  List.fold_left
    (fun acc (param, _) -> NameMap.add param ("f_" ^ param) acc)
    NameMap.empty
    context.params
(* [tdec context name] and [sigdec context name] both untranslate the
   declaration [type ('a1, ..., 'an) f = ('a1, ..., 'an) name], one in
   structure form ([Untranslate.decl]) and one in signature form
   ([Untranslate.sigdecl]).  Every generated module re-exports the derived
   type under the fixed name [f]. *)
let tdec, sigdec =
  let dec context name =
    ("f", context.params,
     `Expr (`Constr ([name], List.map (fun p -> `Param p) context.params)),
     [], false)
  in
  (fun context name -> Untranslate.decl (dec context name)),
  (fun context name -> Untranslate.sigdecl (dec context name))
(* Wrap the mapping expression [expr] in a structure that declares
   [type f = ...] and binds [map] to [expr], abstracted over one mapping
   function per type parameter (outermost lambda = first parameter). *)
let wrapper context name expr =
  let names = param_map context in
  let binders : Ast.patt list =
    List.map
      (fun (param, _) -> <:patt< $lid:NameMap.find param names$ >>)
      context.params
  in
  let body =
    List.fold_right
      (fun binder acc -> <:expr< fun $binder$ -> $acc$ >>)
      binders
      expr
  in
  <:module_expr< struct type $tdec context name$ let map = $body$ end >>
(*
   Translation scheme.  [[t]] is the generated expression that maps a value
   of type t, given one function f_i : a_i -> b_i per type parameter a_i.

   prototype: [[t]] : t -> t[b_i/a_i]

   [[t]] = fun x -> x                       (t has no type variables)
   [[a_i]] = f_i                                                   parameter

   [[C1|...|CN]] = function [[C1]] ... [[CN]]                      sum
   [[`C1|...|`CN]] = function [[`C1]] ... [[`CN]]                  variant

   [[(t1,...,tn)]] = fun (x1,...,xn) -> ([[t1]] x1,...,[[tn]] xn)  tuple
   [[{l1:t1; ... ln:tn}]] =
     fun {l1=x1;...;ln=xn} -> {l1=[[t1]] x1;...;ln=[[tn]] xn}      record

   [[(t1,...,tn) c]] = c_map [[t1]]...[[tn]]                       constructor

   [[a -> b]] = fun g -> [[b]] . g   (where a_i \notin fv(a))      function

   [[C0]] = C0 -> C0                                  nullary constructor
   [[C1 (t1,...,tn)]] = C1 (x1,...,xn) -> C1 ([[t1]] x1,...,[[tn]] xn)
                                                      n-ary constructor
   [[`C0]] = `C0 -> `C0                               nullary tag
   [[`C1 t]] = `C1 x -> `C1 ([[t]] x)                 unary tag
*)

(* One match case of [[`C1|...|`CN]]. *)
let rec polycase context = function
  | Tag (name, []) -> <:match_case< `$name$ -> `$name$ >>
  | Tag (name, es) ->
    <:match_case< `$name$ x -> `$name$ ($expr context (`Tuple es)$ x) >>
  | Extends t ->
    (* Inherited variant type: match with the pattern/guard returned by
       [cast_pattern], then map the matched value with [[t]]. *)
    let patt, guard, exp = cast_pattern context.argmap t in
    <:match_case< $patt$ when $guard$ -> $expr context t$ $exp$ >>

(* [[t]] for a type expression; raises [Underivable] for unsupported forms. *)
and expr context : Pa_deriving_common.Type.expr -> Ast.expr = function
  | t when not (contains_tvars t) -> <:expr< fun x -> x >>
  | `Param (p, _) -> <:expr< $lid:NameMap.find p (param_map context)$ >>
  | `Function (arg, res) when not (contains_tvars arg) ->
    (* Only covariant function types are functorial: post-compose the
       result mapper.  A type variable in argument position cannot be
       mapped, so such types fall through to [Underivable] below.  The
       inserted [[res]] is closed apart from the f_<param> names, so the
       generated binders [f] and [x] cannot capture anything. *)
    <:expr< fun f x -> $expr context res$ (f x) >>
  | `Constr (qname, ts) ->
    (* Constructors listed in [predefs] are resolved under [runtimename]. *)
    let qname =
      try [runtimename; List.assoc qname predefs]
      with Not_found -> qname in
    List.fold_left
      (fun fn arg -> <:expr< $fn$ $expr context arg$ >>)
      <:expr< $id:modname_from_qname ~qname ~classname$.map >>
      ts
  | `Tuple ts -> tup context ts
  | _ -> raise (Underivable "Functor cannot be derived for this type")

(* [[(t1,...,tn)]]; a one-element list is just [[t1]]. *)
and tup context = function
  | [t] -> expr context t
  | ts ->
    (* Bind each component to a fresh variable tN and map it in place. *)
    let args, exps =
      List.fold_right2
        (fun t n (p, e) ->
           let v = Printf.sprintf "t%d" n in
           Ast.PaCom (_loc, <:patt< $lid:v$ >>, p),
           Ast.ExCom (_loc, <:expr< $expr context t$ $lid:v$ >>, e))
        ts
        (List.range 0 (List.length ts))
        (<:patt< >>, <:expr< >>) in
    let pat, exp = Ast.PaTup (_loc, args), Ast.ExTup (_loc, exps) in
    <:expr< fun $pat$ -> $exp$ >>

(* One match case of [[C1|...|CN]]. *)
and case context = function
  | (name, []) -> <:match_case< $uid:name$ -> $uid:name$ >>
  | (name, args) ->
    let f = tup context args
    and _, tpatt, texp = tuple (List.length args) in
    <:match_case< $uid:name$ $tpatt$ -> let $tpatt$ = ($f$ $texp$) in $uid:name$ ($texp$) >>

(* [[t]] applied to the variable bound to record field [name]. *)
and field context (name, (_, t), _) : Ast.expr =
  <:expr< $expr context t$ $lid:name$ >>
(* Mapping expression for the right-hand side of a type declaration.
   Private types and the empty type are rejected with [Underivable]. *)
let rhs context = function
  | `Fresh (_, _, `Private) ->
    raise (Underivable "Functor cannot be derived for private types")
  | `Fresh (_, Sum summands, _) ->
    let cases = List.map (case context) summands in
    <:expr< function $list:cases$ >>
  | `Fresh (_, Record fields, _) ->
    let mapped =
      List.map (fun ((label, _, _) as fld) -> (label, field context fld)) fields
    in
    <:expr< fun $record_pattern fields$ -> $record_expr mapped$ >>
  | `Expr e -> expr context e
  | `Variant (_, tags) ->
    let cases = List.map (polycase context) tags in
    <:expr< function $list:cases$ | _ -> assert false >>
  | `Nothing ->
    raise (Underivable "Cannot generate functor instance for the empty type")
(* Type of the generated [map]:
     ('a1 -> 'f_a1) -> ... -> ('an -> 'f_an) -> ('a1,...,'an) name
       -> ('f_a1,...,'f_an) name *)
let maptype context name =
  let names = param_map context in
  let source = `Constr ([name], List.map (fun p -> `Param p) context.params) in
  let target = substitute names source in
  let base = Untranslate.expr (`Function (source, target)) in
  let arrow (param, _) acc =
    let image = NameMap.find param names in
    <:ctyp< ('$lid:param$ -> '$lid:image$) -> $acc$ >>
  in
  List.fold_right arrow context.params base
(* Signature items of a generated Functor module: the [f] type and [map]. *)
let signature context name : Ast.sig_item list =
  let map_item = <:sig_item< val map : $maptype context name$ >> in
  let type_item = <:sig_item< type $list:sigdec context name$ >> in
  type_item :: map_item :: []
(* Recursive-module binding [Functor_<name> : sig ... end = struct ... end]
   for one type declaration.  A type named [f] is rejected because every
   generated module already declares its own type [f]. *)
let decl context (name, _, r, _, _) : Camlp4.PreCast.Ast.module_binding =
  match name with
  | "f" ->
    raise (Underivable ("deriving: Functor cannot be derived for types called `f'.\n"
                        ^"Please change the name of your type and try again."))
  | _ ->
    let modname = classname ^ "_" ^ name in
    let impl = wrapper context name (rhs context r) in
    <:module_binding< $uid:modname$ : sig $list:signature context name$ end = $impl$ >>
(* Signature item [module Functor_<tname> : sig ... end] for one type
   declaration, or an empty item when [generated] is set.  It reuses
   [signature], the same items [decl] uses to annotate the generated module,
   so the exported signature and the module type cannot drift apart.
   Raises [Underivable] for a type named [f] (it would clash with the
   module's own type [f]). *)
let gen_sig context (tname, _, _, _, generated) =
  if tname = "f" then
    raise (Underivable ("deriving: Functor cannot be derived for types called `f'.\n"
                        ^"Please change the name of your type and try again."))
  else if generated then
    <:sig_item< >>
  else
    <:sig_item< module $uid:classname ^ "_" ^ tname$ :
                  sig $list:signature context tname$ end >>
(* Structure item binding one [Functor_<name>] module per declaration in a
   single [module rec]. *)
let generate decls =
  let context = setup_context decls in
  let bindings = List.map (decl context) decls in
  <:str_item< module rec $list:bindings$ >>
(* Signature counterpart of [generate]: one [gen_sig] item per declaration. *)
let generate_sigs decls =
  let context = setup_context decls in
  let items = List.map (gen_sig context) decls in
  <:sig_item< $list:items$>>
(* This class does not implement expression-level generation; any call
   fails with [Assert_failure]. *)
let generate_expr = fun _ -> assert false
(* Instantiate the generic class machinery with this description and
   generator (see Pa_deriving_common.Base.Register). *)
module Functor =
  Pa_deriving_common.Base.Register (Description) (InContext)