add completion method
[libmarkov.git] / markov.myr
blob70d90bda9e52664f7bca3687d96391a37379f572
1 use std;
3 pkg markov =
4         type token = uint32
5         const Start : token = 0
6         const End : token = 0
8         type token_prediction = struct
9                 /* Total sample size */
10                 total_seen : int
12                 /* ("t", relative probability of "t") */
13                 possibilities : (token, int)[:]
14         ;;
17         type chain = struct
18                 /* What tokens are */
19                 tok_from_str : std.htab(byte[:], token)#
20                 str_from_tok : std.htab(token, byte[:])#
21                 next_tok : token
23                 /* Possible followups for a token sequence */
24                 options : std.htab(token[:], token_prediction#)#
26                 /* When generating, how many past tokens to keep */
27                 lookback_num : int
29                 /*
30                    Technical detail: when tokens are passed in, we
31                    slice them up in many, many different ways. For
32                    memory purposes, we want to use many slices of
33                    a single slice. Therefore, we (mostly) copy the
34                    input slice. But at clean-up time, we have to
35                    have something to delete. These are those. They
36                    are here to free at the end.
37                  */
38                 full_storage : token[:][:]
39         ;;
41         type mciter = struct
42                 c : chain#
43                 past : token[:]
44                 rng : std.option(std.rng#)
45         ;;
46         impl iterable mciter -> byte[:]
47         impl disposable mciter
49         /* First call mk() */
50         const mk : (lookback_num : int -> chain#)
52         /* Then call seed() once for each complete string of tokens you have */
53         const seed : (c : chain#, s : byte[:][:] -> void)
55         /* Call generate() to get an iterator */
56         const generate : (c : chain#, rng_seed : std.option(uint32) -> mciter)
58         /* complete() is like generate(), but from a given starting point */
59         const complete : (c : chain#, start : byte[:][:], rng_seed : std.option(uint32) -> mciter)
61         /* Finally, call cleanup() to free memory */
62         const cleanup : (c : chain# -> void)
65 impl iterable mciter -> byte[:] =
66         __iternext__ = {mci, valp
67                 /*
68                    First, using the tokens stored in mci.past, look
69                    up what to return
70                  */
71                 var wanted_token = End
72                 match std.htget(mci.c.options, mci.past)
73                 | `std.None:
74                         /* This should be impossible. */
75                         -> false
76                 | `std.Some tp:
77                         var k
78                         match mci.rng
79                         | `std.None: k = std.randnum() % tp.total_seen
80                         | `std.Some rng: k = std.rngrandnum(rng) % tp.total_seen
81                         ;;
82                         k = (k + tp.total_seen) % tp.total_seen
84                         /* Figure out which bucket we picked */
85                         var l = 0
86                         for (p_t, p_n) : tp.possibilities
87                                 l += p_n
88                                 if l > k
89                                         wanted_token = p_t
90                                         break
91                                 ;;
92                         ;;
94                         /* Now apply it */
95                         if wanted_token == Start || wanted_token == End
96                                 -> false
97                         ;;
99                         match std.htget(mci.c.str_from_tok, wanted_token)
100                         | `std.None: -> false
101                         | `std.Some s: valp# = s
102                         ;;
103                 ;;
105                 /* Second, cycle mci.past */
106                 if mci.past.len >= mci.c.lookback_num
107                         for var j : std.size = 0; j + 1< mci.past.len; j++
108                                 mci.past[j] = mci.past[j + 1]
109                         ;;
110                         mci.past[mci.past.len - 1] = wanted_token
111                 else
112                         std.slpush(&mci.past, wanted_token)
113                 ;;
115                 -> true
116         }
118         __iterfin__ = {mci, valp
119         }
122 impl disposable mciter =
123         __dispose__ = {mc
124                 std.slfree(mc.past)
125                 match mc.rng
126                 | `std.Some rng: std.freerng(rng)
127                 | `std.None:
128                 ;;
129         }
132 const mk = {lookback_num : int
133         lookback_num = std.max(lookback_num, 1)
134         var c : chain = [
135                 .options = std.mkht(),
136                 .tok_from_str = std.mkht(),
137                 .str_from_tok = std.mkht(),
138                 .full_storage = std.slalloc(0),
139                 .lookback_num = lookback_num,
140                 .next_tok = 2,
141         ]
142         -> std.mk(c)
145 const seed = {c : chain#, s : byte[:][:] -> void
146         /* Translate s to tokens (otherwise memory explodes) */
147         var toks : token[:] = std.slalloc(0)
148         std.slpush(&toks, Start)
149         for ss : s
150                 var this_t : token = c.next_tok
151                 var dup : byte[:] = std.sldup(ss)
152                 match std.htget(c.tok_from_str, dup)
153                 | `std.None:
154                         std.htput(c.tok_from_str, dup, this_t)
155                         std.htput(c.str_from_tok, this_t, dup)
156                         c.next_tok++
157                 | `std.Some t:
158                         this_t = t
159                 ;;
160                 std.slpush(&toks, this_t)
161         ;;
162         std.slpush(&toks, End)
163         std.slpush(&c.full_storage, toks)
165         /* Record having seen tokens */
167         /*
168            First, the beginning sequence, which is special because
169            we might not have enough tokens yet
170          */
171         for var len = 2; len < std.min(c.lookback_num + 1, toks.len); ++len
172                 see_sequence(c, toks[:len])
173         ;;
175         /* Now the intermediate sequences */
176         for var j = 0; j + c.lookback_num + 1 <= toks.len; ++j
177                 see_sequence(c, toks[j:j + c.lookback_num + 1])
178         ;;
181 const generate = {c : chain#, rng_seed : std.option(uint32)
182         var rng : std.option(std.rng#)
183         match rng_seed
184         | `std.None: rng = `std.None
185         | `std.Some i: rng = `std.Some std.mksrng(i)
186         ;;
187         -> [.c = c, .rng = rng, .past = std.sldup([Start][:])]
190 const complete = {c : chain#, start : byte[:][:], rng_seed : std.option(uint32)
191         var rng : std.option(std.rng#)
192         match rng_seed
193         | `std.None: rng = `std.None
194         | `std.Some i: rng = `std.Some std.mksrng(i)
195         ;;
196         var past : token[:] = [][:]
197         if start.len >= c.lookback_num
198                 for var j = start.len - c.lookback_num; j < start.len; ++j
199                         var t : token = -1
200                         match std.htget(c.tok_from_str, start[j])
201                         | `std.Some tt: t = tt
202                         | `std.None: /* well, this is going to fail */
203                         ;;
204                         std.slpush(&past, t)
205                 ;;
206         else
207                 std.slpush(&past, Start)
208                 for s : start
209                         var t : token = -1
210                         match std.htget(c.tok_from_str, s)
211                         | `std.Some tt: t = tt
212                         | `std.None: /* well, this is going to fail */
213                         ;;
214                         std.slpush(&past, t)
215                 ;;
216         ;;
217         -> [.c = c, .rng = rng, .past = past]
220 const see_sequence = {c : chain#, t : token[:]
221         var k : token[:] = t[:t.len - 1]
222         var next_t : token = t[t.len - 1]
223         match std.htget(c.options, k)
224         | `std.Some tp:
225                 tp.total_seen++
226                 for var j : std.size = 0; j < tp.possibilities.len; ++j
227                         var tt
228                         var n
229                         (tt, n) = tp.possibilities[j]
230                         if std.eq(tt, next_t)
231                                 tp.possibilities[j] = (tt, n + 1)
232                                 -> void
233                         ;;
234                 ;;
235                 std.slpush(&tp.possibilities, (next_t, 1))
236         | `std.None:
237                 var p : (token, int)[:] = std.slalloc(1)
238                 p[0] = (next_t, 1)
239                 var tp : token_prediction# = std.mk([.total_seen = 1, .possibilities = p])
240                 std.htput(c.options, k, tp)
241         ;;
244 const cleanup = {c : chain#
246         /* Delete options */
247         for (k, v) : std.byhtkeyvals(c.options)
248                 std.slfree(v.possibilities)
249                 std.free(v)
250         ;;
251         std.htfree(c.options)
253         /* The byte are owned by str_from_tok */
254         std.htfree(c.tok_from_str)
255         for (k, v) : std.byhtkeyvals(c.str_from_tok)
256                 std.slfree(v)
257         ;;
258         std.htfree(c.str_from_tok)
260         /* Delete the base memory that the options' keys are in */
261         for f : c.full_storage
262                 std.slfree(f)
263         ;;
264         std.slfree(c.full_storage)