migrate to 12.0
[ta-parkour.git] / parkour / node.lua
blob29541e6a2e8f7e10078df9d28cc58af314e9170e
1 -- SPDX-License-Identifier: GPL-3.0-or-later
2 -- © 2020 Georgi Kirilov
4 require'lpeg'
5 local l = lpeg
6 local P, S, V, Cc, Cmt, Cp, Ct = l.P, l.S, l.V, l.Cc, l.Cmt, l.Cp, l.Ct
-- Character-set patterns: horizontal whitespace and line terminators.
local hspace = S(" \t")
local newline = S("\n\r")
-- Predicate with the argument shape of an LPeg match-time callback:
-- true while the current match `position` has not gone beyond `pos`.
local function past(_, position, pos)
  return pos >= position
end
12 local M = {}
-- Key extractor: the `start` field of a node.
local function startof(node)
  local offset = node.start
  return offset
end
-- Key extractor: the `finish` field of a node.
local function finishof(node)
  local offset = node.finish
  return offset
end
17 function M.new(read)
-- Match-time check used when looking for the word under a range.
-- `start`/`finish` are the 1-based positions captured before and after a
-- candidate word; `range` holds the wanted `start`/`finish`.
-- On success returns the new match position followed by the two positions
-- shifted down by one; otherwise returns nothing.
local function at_pos(_, position, start, finish, range)
  local from = range.start
  local inside = from + 1 >= start and from < finish and range.finish < finish
  if not inside then return end
  return position, start - 1, finish - 1
end
-- `pos > val`, except that a missing (nil/false) `val` is passed through
-- unchanged so the result is simply falsy.
local function gt(pos, val)
  if not val then return val end
  return pos > val
end
-- `pos <= val`, except that a missing (nil/false) `val` is passed through
-- unchanged so the result is simply falsy.
local function le(pos, val)
  if not val then return val end
  return pos <= val
end
-- Binary search of list `t` for the nearest node before `pos`.
-- `key(node)` yields the offset to compare against. The hit is the entry
-- whose key is below `pos` while the following entry (if any) has a key at
-- or above `pos`. When `skip` is given, entries for which it returns true
-- are stepped over towards the front of the list.
-- Returns the node and its index (the node may be nil if everything was
-- skipped), or nothing when no entry qualifies.
local function before(t, pos, key, skip)
  local lo, hi = 1, #t
  while lo <= hi do
    local mid = math.floor(lo + (hi - lo) / 2)
    local mid_val = key(t[mid])
    if gt(pos, mid_val) and (mid == #t or le(pos, key(t[mid + 1]))) then
      if skip then
        while t[mid] and skip(t[mid]) do
          mid = mid - 1
        end
      end
      return t[mid], mid
    end
    if le(pos, mid_val) then
      hi = mid - 1
    else
      lo = mid + 1
    end
  end
end
-- Binary search of list `t` for the nearest node after `pos`.
-- `key(node)` yields the offset to compare against. The hit is the entry
-- whose key is at or above `pos` while the preceding entry (if any) has a
-- key below `pos`. When `skip` is given, entries for which it returns true
-- are stepped over towards the end of the list.
-- Returns the node and its index, or nothing when no entry qualifies.
local function after(t, pos, key, skip)
  if t.is_root and (#t == 0 or pos >= t[#t].start) then
    -- XXX: at the last parsed top-level node, touch the next slot so that it
    -- gets parsed as well; otherwise the search would stop prematurely.
    local _ = t[#t + 1]
  end
  local lo, hi = 1, #t
  while lo <= hi do
    local mid = math.floor(lo + (hi - lo) / 2)
    local mid_val = key(t[mid])
    local at_or_past = le(pos, mid_val)
    if at_or_past and (mid == 1 or gt(pos, key(t[mid - 1]))) then
      if skip then
        while t[mid] and skip(t[mid]) do
          mid = mid + 1
        end
      end
      return t[mid], mid
    end
    if at_or_past then
      hi = mid - 1
    else
      lo = mid + 1
    end
  end
end
-- Binary search of list `t` for the node that contains `range`.
-- A node matches when range.start is not before its start and range.finish
-- is at most one past its finish. If the following node starts at or before
-- range.start (adjoined, as in asdf|(qwer)), the following node is preferred.
-- Returns the node and its index, or nothing.
local function around(t, range)
  local from, to = range.start, range.finish
  if not from or not to then return end
  local lo, hi = 1, #t
  while lo <= hi do
    local mid = math.floor(lo + (hi - lo) / 2)
    local node, follower = t[mid], t[mid + 1]
    local covers = node.start and from >= node.start
      and node.finish and to <= node.finish + 1
    if covers and (not follower or follower.start > from) then
      return node, mid
    end
    if node.start and from < node.start then
      hi = mid - 1
    else
      lo = mid + 1
    end
  end
end
-- Walk forward from the first node at/after selection.start (as located by
-- `t.after`) until `pred(node)` is true.
-- Returns the matching node and its index. Returns nothing when the list is
-- exhausted or when the optional `stop(node)` predicate fires first.
local function find_after(t, selection, pred, stop)
  local _, idx = t.after(selection.start, startof)
  local node = t[idx]
  while node do
    if pred(node) then return node, idx end
    if stop and stop(node) then return end
    idx = idx + 1
    node = t[idx]
  end
end
-- Walk backward from the last node before selection.finish (as located by
-- `t.before`) until `pred(node)` is true.
-- Returns the matching node and its index. Returns nothing when the list is
-- exhausted or when the optional `stop(node)` predicate fires first.
local function find_before(t, selection, pred, stop)
  local _, n = t.before(selection.finish, finishof)
  while t[n] do
    if pred(t[n]) then
      return t[n], n
    elseif stop and stop(t[n]) then
      return
    end
    n = n - 1
  end
end
-- True when node `t` and `range` overlap only partially: either the node
-- begins inside the range and reaches its end or beyond, or the node begins
-- before the range and ends inside it. The node's start is taken after its
-- optional prefix `t.p`.
local function intersects(t, range)
  local start = t.start + (t.p and #t.p or 0)
  return start >= range.start and t.finish >= range.finish and start < range.finish
    or start < range.start and t.finish < range.finish and range.start <= t.finish
end
-- True when `range` sits on an edge of node `t`: range.start falls between
-- the node's start and the end of its optional prefix `t.p`, or range.finish
-- is exactly one past the node's finish.
local function touches(t, range)
  local after_prefix = t.start + (t.p and #t.p or 0)
  return after_prefix >= range.start and t.start <= range.start
    or t.finish == range.finish - 1
end
-- True when `range` lies within node `t`.
-- For a list the range must be strictly inside: after the start plus
-- optional prefix `t.p`, and before the finish. For any other node the
-- range may coincide with the node's bounds.
local function contains(t, range)
  if t.is_list then
    return t.start + (t.p and #t.p or 0) < range.start and t.finish > range.finish - 1
  else
    return t.start <= range.start and t.finish >= range.finish - 1
  end
end
-- Descend from `t` to the innermost s-expression around `range`.
-- Recurses into a child list as long as the range is inside it without
-- touching its edges. Returns child, parent, index-of-child; `child` may be
-- nil when nothing in the parent surrounds the range. Returns nothing when
-- `t` is neither the root nor a list containing the range.
local function sexp_around(t, range)
  if t.is_root or (t.is_list and contains(t, range)) then
    local child, nth = t.around(range)
    if child and child.is_list and not touches(child, range) then
      return sexp_around(child, range)
    else
      return child, t, nth
    end
  end
end
-- Save the current logical position in terms of sexps, not offsets.
-- Appends to `indices` the child index chosen at every nesting level and to
-- `nodes` the matching child; recursion continues while the child is a list
-- that surrounds `range` without touching it. When a level has no child
-- around the range, `false` is appended to both lists as a terminator.
-- Returns `indices, nodes` (the same tables that were passed in).
local function _sexp_path(t, range, indices, nodes)
  if t.is_root or (t.is_list and contains(t, range)) then
    local child, nth = t.around(range)
    table.insert(indices, nth)
    table.insert(nodes, child)
    if child and child.is_list and not touches(child, range) then
      _sexp_path(child, range, indices, nodes)
    elseif not child then
      table.insert(indices, false)
      table.insert(nodes, false)
    end
  end
  return indices, nodes
end
-- Compute the sexp path of `range` below `t`, starting from fresh lists.
-- Returns the index list and the node list built by `_sexp_path`.
local function sexp_path(t, range)
  return _sexp_path(t, range, {}, {})
end
-- Find a sexp by following a previously saved "sexp path".
-- `path` is a list of child indices, one per nesting level.
-- Returns the node reached, its parent, and the last index followed.
-- For an empty path this is `root, nil, nil`.
local function goto_path(root, path)
  local dest = root
  local parent, nth
  for n, i in ipairs(path) do
    dest = dest[i]
    nth = i
    -- the first step hangs off the root; afterwards go down one level
    parent = not parent and root or parent[path[n - 1]]
  end
  return dest, parent, nth
end
-- Make sure `tree` is parsed far enough to cover `range`: when range.finish
-- is at or beyond tree.first_invalid, ask the tree to parse up to one past
-- range.finish. Does nothing when range.finish is missing.
local function catchup(tree, range)
  if range.finish and range.finish >= tree.first_invalid then
    tree.parse_to(range.finish + 1)
  end
end
-- Wrap `func` as a bound method that first brings the tree up to date.
-- Returns a binder: given the tree `t` it yields a function which calls
-- `catchup` for the requested range and then forwards to `func(t, range, ...)`.
local function ensure_reparse(func)
  return function(t)
    return function(range, ...)
      -- TODO: get rid of this type check. It is here because .after and
      -- .before take pos, not range
      catchup(t, type(range) == "number" and {finish = range} or range)
      return func(t, range, ...)
    end
  end
end
-- Functions exposed on the root node; each takes the root as first argument.
local root_methods = {
  around = around,
  before = before,
  after = after,
  find_after = find_after,
  find_before = find_before,
  sexp_at = sexp_around,
  sexp_path = sexp_path,
}
-- Turn `func(t, ...)` into a binder: given a node `t` it returns a function
-- that calls `func` with `t` prepended to its arguments.
local function bind(func)
  return function(t)
    return function(...)
      return func(t, ...)
    end
  end
end
-- Computed properties of list nodes. Every entry is called with the node
-- and returns the property value (a flag, or a function bound to the node).
local list_methods = {
  is_list = function(_) return true end,
  is_empty = function(t) return #t == 0 end,
  around = bind(around),
  before = bind(before),
  after = bind(after),
  find_after = bind(find_after),
  find_before = bind(find_before),
}
local atom_methods = {}

local quasiatom_methods = {} -- "quasi atom" is a word inside a string or comment

-- Source text of node `t`: `read` is asked for finish + 1 - start characters
-- beginning at t.start.
local function text(t)
  local len = t.finish + 1 - t.start
  return read(t.start, len)
end

-- every node kind exposes the same `text` property
atom_methods.text = text
list_methods.text = text
quasiatom_methods.text = text
-- Build an `__index` handler from `vtable`: looking up `key` on a node calls
-- vtable[key](node) and yields its result, so entries behave like computed
-- properties. Keys absent from `vtable` yield nothing (nil).
local function dispatch(vtable)
  return function(self, key)
    if vtable[key] then
      return vtable[key](self)
    end
  end
end
-- Metatable for words found inside strings or comments.
local quasiatom_node = {
  __index = dispatch(quasiatom_methods)
}
-- Word-wise navigation inside a "quasi list" (a string or comment).
-- `word` is the LPeg pattern for one word: a backslash followed by one
-- non-whitespace character, or a run of non-whitespace characters.
-- All functions take the node `t` and work on its inner text `t.itext`;
-- offsets passed in and returned use the same coordinates as t.start.
local quasilist_methods = (function(word) -- "quasi list" is a string or comment
  -- Offset where the inner text begins: node start plus the opening
  -- delimiter `t.d` and the optional prefix `t.p`.
  local function content_base(t)
    return t.start + #t.d + (t.p and #t.p or 0)
  end

  return {
    -- Offset of the last word start that lies before `pos`, or nil.
    start_before = function(t, pos)
      local base = content_base(t)
      local stops = P{Ct(((1 - word)^0 * Cp() * Cmt(Cc(pos - base), past) * word)^0)}:match(t.itext)
      local newpos = stops and stops[#stops]
      return newpos and (base + newpos - 1)
    end,

    -- Offset of the next word start. From inside the text, the search begins
    -- at `pos` and wants a non-word character followed by a word. From
    -- before the text, it looks for the first word after leading non-word
    -- characters and falls back to the start of the inner text.
    start_after = function(t, pos)
      local base = content_base(t)
      if pos >= base then
        local stop = P{(1 - word) * Cp() * word + 1 * V(1)}:match(t.itext, pos - base + 1)
        return stop and (base + stop - 1)
      else
        local stop = P{(1 - word)^1 * Cp() * word}:match(t.itext)
        return stop and (base + stop - 1) or base
      end
    end,

    -- Offset just past the last word whose end is not beyond `pos`, or nil.
    finish_before = function(t, pos)
      local base = content_base(t)
      local stops = P{Ct(((1 - word)^0 * word * Cp() * Cmt(Cc(pos - base), past))^0)}:match(t.itext)
      local newpos = stops and stops[#stops]
      return newpos and (base + newpos - 1)
    end,

    -- Offset just past the first word found from `pos` on (from the start of
    -- the inner text when `pos` is not beyond it), or nil.
    finish_after = function(t, pos)
      local base = content_base(t)
      local stop = P{word * Cp() + 1 * V(1)}:match(t.itext, (pos > base and pos - base + 1 or nil))
      return stop and (base + stop - 1)
    end,

    -- The word covering `range`, as a quasi-atom node with `start`/`finish`
    -- offsets, or nil when no word is found there.
    word_at = function(t, range)
      local r = {}
      local base = content_base(t)
      r.start = range.start - base
      r.finish = range.finish - base
      local start, finish = P{Cmt(Cp() * word * Cp() * Cc(r), at_pos) +
        Cmt(1 * Cc(r.finish + 1), past) * V(1)}:match(t.itext)
      if not start then return end
      local node = {start = base + start, finish = base + finish - 1}
      return setmetatable(node, quasiatom_node)
    end,

    -- Offset just past the whitespace run that begins exactly at `pos`.
    -- Yields nothing when `pos` is before the inner text or no whitespace
    -- starts there.
    spaces_after = function(t, pos)
      local base = content_base(t)
      if pos >= base then
        local stop = P{(hspace + newline)^1 * Cp()}:match(t.itext, pos - base + 1)
        return stop and (base + stop - 1)
      end
    end,

    -- Offset just past the last word ending not beyond `pos`, provided that
    -- whitespace follows the collected words; nil otherwise.
    spaces_before = function(t, pos)
      local base = content_base(t)
      local stops = P{Ct(((1 - word)^0 * word * Cp() * Cmt(Cc(pos - base), past))^0) * (hspace + newline)^1}:match(t.itext)
      local newpos = stops and stops[#stops]
      return newpos and (base + newpos - 1)
    end,
  }
end)(P{"\\" * P(1 - hspace - newline) + P(1 - hspace - newline)^1})
-- A string or comment is empty when no word ends after the start of its
-- inner text (node start plus optional prefix `t.p` and delimiter `t.d`).
local function quasilist_is_empty(t)
  return not t.finish_after(t.start + (t.p and #t.p or 0) + #t.d)
end
-- Build the metatable for string/comment nodes.
-- `opposite` maps an opening delimiter to its closing delimiter; it is used
-- to leave the closing delimiter out of the inner text.
-- Returns a table whose `__index` resolves the quasi-list properties.
local function quasilist_new(opposite)
  local methods = {}
  for k, v in pairs(quasilist_methods) do
    methods[k] = bind(v)
  end
  methods.is_empty = quasilist_is_empty
  methods.text = text

  -- Text between the delimiters; "" when `read` yields nothing.
  function methods.itext(t)
    local start = t.start + (t.p and #t.p or 0) + #t.d
    local len = t.finish + 1 - #opposite[t.d] - start
    return read(start, len) or ""
  end

  local quasilist_node = {
    __index = dispatch(methods)
  }

  return quasilist_node
end
-- Metatable for atom nodes.
local atom_node = {
  __index = dispatch(atom_methods)
}
-- Metatable for list nodes.
local list_node = {
  __index = dispatch(list_methods)
}
-- Root property `rewind`: returns a function that moves t.first_invalid
-- back to `index`. Raises an error when `index` is greater than the current
-- first_invalid, since rewinding only goes backwards.
local function root_rewind(t)
  return function(index)
    -- build the message only on failure instead of on every call
    if not (index <= t.first_invalid) then
      error(("%d > %d - You can only rewind backwards!"):format(index, t.first_invalid), 0)
    end
    t.first_invalid = index
  end
end
-- Root property `is_parsed`: returns a function telling whether `index`
-- lies before t.first_invalid.
local function root_is_parsed(t)
  return function(index)
    return index < t.first_invalid
  end
end
-- Root property `goto_path`: returns a function that follows a saved sexp
-- path from the root. Before descending, the tree is caught up to the end
-- of the top-level node named by the first path step.
-- An empty path, or a first step with no matching node (e.g. the `false`
-- terminator of a saved path), skips the catch-up instead of raising.
local function root_goto_path(t)
  return function(path)
    local first = path[1] and t[path[1]]
    if first then
      local pos = first.finish
      catchup(t, {start = pos, finish = pos})
    end
    return goto_path(t, path)
  end
end
-- Build the property table for a root node.
-- `opposite` maps an opening delimiter to its closing delimiter.
-- Every entry of `root_methods` is wrapped so the tree is reparsed as
-- needed before the call; rewind/is_parsed/goto_path and
-- unbalanced_delimiters are added on top. Returns the table.
local function root_new(opposite)
  local methods = {}
  for k, v in pairs(root_methods) do
    methods[k] = ensure_reparse(v)
  end

  methods.rewind = root_rewind
  methods.is_parsed = root_is_parsed
  methods.goto_path = root_goto_path

  -- Coroutine body: walks delimited nodes that contain or intersect `range`,
  -- innermost first, and yields one slice per delimiter that the range cuts
  -- off from its partner: {start, finish, closing = true} for a closing
  -- delimiter, {start, finish, opening = true} for an opening one.
  local function unbalanced_delimiters(t, range)
    if t.d or t.is_root then
      for i = 1, #t do
        if t[i].d and (contains(t[i], range) or intersects(t[i], range)) then
          unbalanced_delimiters(t[i], range)
        end
      end
      if not t.is_root then
        local start = t.start + (t.p and #t.p or 0) + #t.d
        local finish = t.finish + 1 - #opposite[t.d]
        if start <= range.start and finish < range.finish then
          coroutine.yield({start = finish, finish = math.min(t.finish + 1, range.finish), closing = true})
        elseif start > range.start and t.finish >= range.finish then
          coroutine.yield({start = math.max(range.start, t.start), finish = start, opening = true})
        end
      end
    end
  end

  -- Root property: returns a function that lists the unbalanced-delimiter
  -- slices inside `range`.
  function methods.unbalanced_delimiters(t)
    return function(range)
      local skips = {}
      catchup(t, range)
      -- a fresh coroutine per call, so the returned function can be reused
      local slice_at = coroutine.wrap(unbalanced_delimiters)
      repeat
        local slice_pos = slice_at(t, range)
        if slice_pos then table.insert(skips, slice_pos) end
      until not slice_pos
      return skips
    end
  end

  return methods
end
-- Node metatables and constructors handed back by M.new.
return {
  atom = atom_node,
  list = list_node,
  quasilist = quasilist_new,
  root = root_new,
  _before = before,
  _around = around,
}
end -- M.new
408 return M