Skip to content

Commit 733c20f

Browse files
committed
Make document root assertions grammar agnostic
Currently the approach to checking if a node represents the document root is to check if its `type` is `"source"`. While a lot of grammars do this by convension it's not part of any spec and some grammars call it something else - like `"program"`. This replaces all these assertions with a comparison to see if the node is :equal() to the tree :root() which is completely grammar agnostic.
1 parent 33b8032 commit 733c20f

File tree

7 files changed

+23
-18
lines changed

7 files changed

+23
-18
lines changed

lua/nvim-paredit/api/barfing.lua

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
local ts_context = require("nvim-paredit.treesitter.context")
22
local ts_forms = require("nvim-paredit.treesitter.forms")
3+
local ts_utils = require("nvim-paredit.treesitter.utils")
34
local traversal = require("nvim-paredit.utils.traversal")
45
local indentation = require("nvim-paredit.indentation")
56
local common = require("nvim-paredit.utils.common")
@@ -43,7 +44,7 @@ function M.barf_forwards(opts)
4344
end
4445

4546
local form = traversal.find_closest_form_with_children(current_form, context)
46-
if not form or form:type() == "source" then
47+
if not form or ts_utils.is_document_root(form) then
4748
return
4849
end
4950

@@ -117,7 +118,7 @@ function M.barf_backwards(opts)
117118
end
118119

119120
local form = traversal.find_closest_form_with_children(current_form, context)
120-
if not form or form:type() == "source" then
121+
if not form or ts_utils.is_document_root(form) then
121122
return
122123
end
123124

lua/nvim-paredit/api/motions.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ local function move_to_parent_form_edge(direction)
223223
target_form = root:parent()
224224
end
225225

226-
if not target_form or target_form:type() == "source" then
226+
if not target_form or ts_utils.is_document_root(target_form) then
227227
return
228228
end
229229

lua/nvim-paredit/api/raising.lua

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
local ts_context = require("nvim-paredit.treesitter.context")
22
local ts_forms = require("nvim-paredit.treesitter.forms")
3+
local ts_utils = require("nvim-paredit.treesitter.utils")
34

45
local M = {}
56

@@ -16,19 +17,18 @@ function M.raise_form()
1617
end
1718

1819
local parent = current_form:parent()
19-
if not parent or parent:type() == "source" then
20+
if not parent or ts_utils.is_document_root(parent) then
2021
return
2122
end
2223

2324
local replace_text = vim.treesitter.get_node_text(current_form, 0)
2425

2526
local parent_range = { parent:range() }
27+
-- stylua: ignore
2628
vim.api.nvim_buf_set_text(
2729
0,
28-
parent_range[1],
29-
parent_range[2],
30-
parent_range[3],
31-
parent_range[4],
30+
parent_range[1], parent_range[2],
31+
parent_range[3], parent_range[4],
3232
vim.fn.split(replace_text, "\n")
3333
)
3434
vim.api.nvim_win_set_cursor(0, { parent_range[1] + 1, parent_range[2] })
@@ -43,19 +43,18 @@ function M.raise_element()
4343
local current_node = ts_forms.get_node_root(context.node, context)
4444

4545
local parent = current_node:parent()
46-
if not parent or parent:type() == "source" then
46+
if not parent or ts_utils.is_document_root(parent) then
4747
return
4848
end
4949

5050
local replace_text = vim.treesitter.get_node_text(current_node, 0)
5151

5252
local parent_range = { parent:range() }
53+
-- stylua: ignore
5354
vim.api.nvim_buf_set_text(
5455
0,
55-
parent_range[1],
56-
parent_range[2],
57-
parent_range[3],
58-
parent_range[4],
56+
parent_range[1], parent_range[2],
57+
parent_range[3], parent_range[4],
5958
vim.fn.split(replace_text, "\n")
6059
)
6160
vim.api.nvim_win_set_cursor(0, { parent_range[1] + 1, parent_range[2] })

lua/nvim-paredit/api/wrap.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ local function find_parent_form(element, opts)
2121
parent = nearest_form:parent()
2222
end
2323

24-
if parent and parent:type() ~= "source" then
24+
if parent and not ts_utils.is_document_root(parent) then
2525
return ts_forms.find_nearest_form(parent, {
2626
captures = opts.captures,
2727
use_source = false,
@@ -106,7 +106,7 @@ function M.wrap_enclosing_form_under_cursor(prefix, suffix)
106106
return
107107
end
108108

109-
if not use_direct_parent and form:type() ~= "source" then
109+
if not use_direct_parent and not ts_utils.is_document_root(form) then
110110
form = find_parent_form(current_element, context)
111111
end
112112

lua/nvim-paredit/indentation/native.lua

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
local ts_context = require("nvim-paredit.treesitter.context")
22
local ts_forms = require("nvim-paredit.treesitter.forms")
3+
local ts_utils = require("nvim-paredit.treesitter.utils")
34
local traversal = require("nvim-paredit.utils.traversal")
45
local utils = require("nvim-paredit.indentation.utils")
56
local common = require("nvim-paredit.utils.common")
@@ -102,7 +103,7 @@ local function indent_barf(event)
102103
local lines = utils.find_affected_lines(node, utils.get_node_line_range(node_range))
103104

104105
local delta
105-
if parent:type() == "source" then
106+
if ts_utils.is_document_root(parent) then
106107
delta = node_range[2]
107108
else
108109
local row

lua/nvim-paredit/treesitter/utils.lua

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ local common = require("nvim-paredit.utils.common")
22

33
local M = {}
44

5+
function M.is_document_root(node)
6+
return node and node:tree():root():equal(node)
7+
end
8+
59
-- Find the root node of the tree `node` is a member of, excluding the root
610
-- 'source' document.
711
function M.find_local_root(node)
812
local current = node
913
while true do
1014
local next = current:parent()
11-
if not next or next:type() == "source" then
15+
if not next or M.is_document_root(next) then
1216
break
1317
end
1418
current = next

lua/nvim-paredit/utils/traversal.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ end
4949

5050
function M.find_closest_form_with_children(current_node, opts)
5151
local form = ts_forms.get_form_inner(current_node, opts)
52-
if form:named_child_count() > 0 and current_node:type() ~= "source" then
52+
if form:named_child_count() > 0 and not ts_utils.is_document_root(current_node) then
5353
return form
5454
end
5555

0 commit comments

Comments
 (0)