Skip to content

Commit

Permalink
Make document root assertions grammar agnostic
Browse files Browse the repository at this point in the history
Currently the approach to checking if a node represents the document
root is to check if its `type` is `"source"`. While a lot of grammars do
this by convension it's not part of any spec and some grammars call it
something else - like `"program"`.

This replaces all these assertions with a comparison to see if the node
is :equal() to the tree :root() which is completely grammar agnostic.
  • Loading branch information
julienvincent committed Oct 13, 2024
1 parent d128eb2 commit 5de6eab
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 18 deletions.
5 changes: 3 additions & 2 deletions lua/nvim-paredit/api/barfing.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local ts_context = require("nvim-paredit.treesitter.context")
local ts_forms = require("nvim-paredit.treesitter.forms")
local ts_utils = require("nvim-paredit.treesitter.utils")
local traversal = require("nvim-paredit.utils.traversal")
local indentation = require("nvim-paredit.indentation")
local common = require("nvim-paredit.utils.common")
Expand Down Expand Up @@ -43,7 +44,7 @@ function M.barf_forwards(opts)
end

local form = traversal.find_closest_form_with_children(current_form, context)
if not form or form:type() == "source" then
if not form or ts_utils.is_document_root(form) then
return
end

Expand Down Expand Up @@ -117,7 +118,7 @@ function M.barf_backwards(opts)
end

local form = traversal.find_closest_form_with_children(current_form, context)
if not form or form:type() == "source" then
if not form or ts_utils.is_document_root(form) then
return
end

Expand Down
2 changes: 1 addition & 1 deletion lua/nvim-paredit/api/motions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ local function move_to_parent_form_edge(direction)
target_form = root:parent()
end

if not target_form or target_form:type() == "source" then
if not target_form or ts_utils.is_document_root(target_form) then
return
end

Expand Down
19 changes: 9 additions & 10 deletions lua/nvim-paredit/api/raising.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local ts_context = require("nvim-paredit.treesitter.context")
local ts_forms = require("nvim-paredit.treesitter.forms")
local ts_utils = require("nvim-paredit.treesitter.utils")

local M = {}

Expand All @@ -16,19 +17,18 @@ function M.raise_form()
end

local parent = current_form:parent()
if not parent or parent:type() == "source" then
if not parent or ts_utils.is_document_root(parent) then
return
end

local replace_text = vim.treesitter.get_node_text(current_form, 0)

local parent_range = { parent:range() }
-- stylua: ignore
vim.api.nvim_buf_set_text(
0,
parent_range[1],
parent_range[2],
parent_range[3],
parent_range[4],
parent_range[1], parent_range[2],
parent_range[3], parent_range[4],
vim.fn.split(replace_text, "\n")
)
vim.api.nvim_win_set_cursor(0, { parent_range[1] + 1, parent_range[2] })
Expand All @@ -43,19 +43,18 @@ function M.raise_element()
local current_node = ts_forms.get_node_root(context.node, context)

local parent = current_node:parent()
if not parent or parent:type() == "source" then
if not parent or ts_utils.is_document_root(parent) then
return
end

local replace_text = vim.treesitter.get_node_text(current_node, 0)

local parent_range = { parent:range() }
-- stylua: ignore
vim.api.nvim_buf_set_text(
0,
parent_range[1],
parent_range[2],
parent_range[3],
parent_range[4],
parent_range[1], parent_range[2],
parent_range[3], parent_range[4],
vim.fn.split(replace_text, "\n")
)
vim.api.nvim_win_set_cursor(0, { parent_range[1] + 1, parent_range[2] })
Expand Down
4 changes: 2 additions & 2 deletions lua/nvim-paredit/api/wrap.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ local function find_parent_form(element, opts)
parent = nearest_form:parent()
end

if parent and parent:type() ~= "source" then
if parent and not ts_utils.is_document_root(parent) then
return ts_forms.find_nearest_form(parent, {
captures = opts.captures,
use_source = false,
Expand Down Expand Up @@ -106,7 +106,7 @@ function M.wrap_enclosing_form_under_cursor(prefix, suffix)
return
end

if not use_direct_parent and form:type() ~= "source" then
if not use_direct_parent and not ts_utils.is_document_root(form) then
form = find_parent_form(current_element, context)
end

Expand Down
3 changes: 2 additions & 1 deletion lua/nvim-paredit/indentation/native.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local ts_context = require("nvim-paredit.treesitter.context")
local ts_forms = require("nvim-paredit.treesitter.forms")
local ts_utils = require("nvim-paredit.treesitter.utils")
local traversal = require("nvim-paredit.utils.traversal")
local utils = require("nvim-paredit.indentation.utils")
local common = require("nvim-paredit.utils.common")
Expand Down Expand Up @@ -102,7 +103,7 @@ local function indent_barf(event)
local lines = utils.find_affected_lines(node, utils.get_node_line_range(node_range))

local delta
if parent:type() == "source" then
if ts_utils.is_document_root(parent) then
delta = node_range[2]
else
local row
Expand Down
6 changes: 5 additions & 1 deletion lua/nvim-paredit/treesitter/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ local common = require("nvim-paredit.utils.common")

local M = {}

function M.is_document_root(node)
return node and node:tree():root():equal(node)
end

-- Find the root node of the tree `node` is a member of, excluding the root
-- 'source' document.
function M.find_local_root(node)
local current = node
while true do
local next = current:parent()
if not next or next:type() == "source" then
if not next or M.is_document_root(next) then
break
end
current = next
Expand Down
2 changes: 1 addition & 1 deletion lua/nvim-paredit/utils/traversal.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end

function M.find_closest_form_with_children(current_node, opts)
local form = ts_forms.get_form_inner(current_node, opts)
if form:named_child_count() > 0 and current_node:type() ~= "source" then
if form:named_child_count() > 0 and not ts_utils.is_document_root(current_node) then
return form
end

Expand Down

0 comments on commit 5de6eab

Please sign in to comment.