diff --git a/README.md b/README.md index dc1bf8a..80cf9c3 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,25 @@ require("nvim-paredit").setup({ -- defaults to all supported file types including custom lang -- extensions (see next section) filetypes = { "clojure" }, + + -- This controls where the cursor is placed when performing slurp/barf operations + -- + -- - "remain" - It will never change the cursor position, keeping it in the same place + -- - "follow" - It will always place the cursor on the form edge that was moved + -- - "auto" - A combination of remain and follow, it will try keep the cursor in the original position + -- unless doing so would result in the cursor no longer being within the original form. In + -- this case it will place the cursor on the moved edge cursor_behaviour = "auto", -- remain, follow, auto + + -- This controls how nvim-paredit handles indentation when performing operations which + -- should change the indentation of the form (such as when slurping or barfing). + -- + -- When set to true then it will attempt to fix the indentation of nodes operated on. + auto_indent = true, + -- A function that will be called after a slurp/barf if you want to provide a custom indentation + -- implementation. + indent_fn = nil, + -- list of default keybindings keys = { [">)"] = { paredit.api.slurp_forwards, "Slurp forwards" }, diff --git a/lua/nvim-paredit/api/barfing.lua b/lua/nvim-paredit/api/barfing.lua index 329987f..5179c5a 100644 --- a/lua/nvim-paredit/api/barfing.lua +++ b/lua/nvim-paredit/api/barfing.lua @@ -1,4 +1,5 @@ local traversal = require("nvim-paredit.utils.traversal") +local indentation = require("nvim-paredit.indentation") local common = require("nvim-paredit.utils.common") local ts = require("nvim-treesitter.ts_utils") local config = require("nvim-paredit.config") @@ -28,11 +29,11 @@ function M.barf_forwards(opts) local child if opts.reversed then child = traversal.get_first_child_ignoring_comments(form, { - lang = lang + lang = lang, }) else child = traversal.get_last_child_ignoring_comments(form, { - lang = lang + lang = lang, }) end if not child then @@ -42,7 +43,7 @@ function M.barf_forwards(opts) local edges = lang.get_form_edges(form) local sibling = traversal.get_prev_sibling_ignoring_comments(child, { - lang = lang + lang = lang, }) local end_pos @@ -55,6 +56,7 @@ function M.barf_forwards(opts) local buf = vim.api.nvim_get_current_buf() local range = edges.right.range + -- stylua: ignore vim.api.nvim_buf_set_text( buf, range[1], range[2], @@ -63,6 +65,7 @@ function M.barf_forwards(opts) ) local text = edges.right.text + -- stylua: ignore vim.api.nvim_buf_set_text(buf, end_pos[1], end_pos[2], end_pos[1], end_pos[2], @@ -77,6 +80,19 @@ function M.barf_forwards(opts) vim.api.nvim_win_set_cursor(0, { end_pos[1] + 1, end_pos[2] }) end end + + local event = { + type = "barf-forwards", + -- stylua: ignore + parent_range = { + edges.left.range[1], edges.left.range[2], + end_pos[1], end_pos[2], + }, + reversed = false, + indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, + lang = lang, + } + indentation.handle_indentation(event, opts) end function M.barf_backwards(opts) @@ -99,7 +115,7 @@ function M.barf_backwards(opts) end local child = traversal.get_first_child_ignoring_comments(form, { - lang = lang + lang = lang, }) if not child then return @@ -108,7 +124,7 @@ function M.barf_backwards(opts) local edges = lang.get_form_edges(lang.get_node_root(form)) local sibling = traversal.get_next_sibling_ignoring_comments(child, { - lang = lang + lang = lang, }) local end_pos @@ -121,6 +137,7 @@ function M.barf_backwards(opts) local buf = vim.api.nvim_get_current_buf() local text = edges.left.text + -- stylua: ignore vim.api.nvim_buf_set_text(buf, end_pos[1], end_pos[2], end_pos[1], end_pos[2], @@ -128,6 +145,7 @@ function M.barf_backwards(opts) ) local range = edges.left.range + -- stylua: ignore vim.api.nvim_buf_set_text( buf, range[1], range[2], @@ -143,6 +161,16 @@ function M.barf_backwards(opts) vim.api.nvim_win_set_cursor(0, { end_pos[1] + 1, end_pos[2] }) end end + + local event = { + type = "barf-backwards", + -- stylua: ignore + parent_range = { + end_pos[1], end_pos[2], + edges.right.range[1], edges.right.range[2], + }, + } + indentation.handle_indentation(event, opts) end return M diff --git a/lua/nvim-paredit/api/slurping.lua b/lua/nvim-paredit/api/slurping.lua index 47fb369..e16cbf5 100644 --- a/lua/nvim-paredit/api/slurping.lua +++ b/lua/nvim-paredit/api/slurping.lua @@ -1,4 +1,5 @@ local traversal = require("nvim-paredit.utils.traversal") +local indentation = require("nvim-paredit.indentation") local common = require("nvim-paredit.utils.common") local ts = require("nvim-treesitter.ts_utils") local config = require("nvim-paredit.config") @@ -40,11 +41,12 @@ local function slurp(opts) end local buf = vim.api.nvim_get_current_buf() + local form_edges = lang.get_form_edges(form) local left_or_right_edge if opts.reversed then - left_or_right_edge = lang.get_form_edges(form).left + left_or_right_edge = form_edges.left else - left_or_right_edge = lang.get_form_edges(form).right + left_or_right_edge = form_edges.right end local start_or_end @@ -57,6 +59,7 @@ local function slurp(opts) local row = start_or_end[1] local col = start_or_end[2] + -- stylua: ignore vim.api.nvim_buf_set_text(buf, row, col, row, col, @@ -65,9 +68,10 @@ local function slurp(opts) local offset = 0 if opts.reversed and row == left_or_right_edge.range[1] then - offset = string.len(left_or_right_edge.text) + offset = #left_or_right_edge.text end + -- stylua: ignore vim.api.nvim_buf_set_text( buf, left_or_right_edge.range[1], left_or_right_edge.range[2] + offset, @@ -77,7 +81,7 @@ local function slurp(opts) local cursor_behaviour = opts.cursor_behaviour or config.config.cursor_behaviour if cursor_behaviour == "follow" then - local offset = 0 + offset = 0 if not opts.reversed then offset = string.len(left_or_right_edge.text) end @@ -88,6 +92,30 @@ local function slurp(opts) vim.api.nvim_win_set_cursor(0, cursor_pos) end end + + local operation_type + local new_range + if not opts.reversed then + operation_type = "slurp-forwards" + -- stylua: ignore + new_range = { + form_edges.left.range[1], form_edges.left.range[2], + row, col, + } + else + operation_type = "slurp-backwards" + -- stylua: ignore + new_range = { + row, col, + form_edges.right.range[1], form_edges.right.range[2], + } + end + + local event = { + type = operation_type, + parent_range = new_range, + } + indentation.handle_indentation(event, opts) end function M.slurp_forwards(opts) @@ -96,7 +124,7 @@ end function M.slurp_backwards(opts) slurp(common.merge(opts or {}, { - reversed = true + reversed = true, })) end diff --git a/lua/nvim-paredit/defaults.lua b/lua/nvim-paredit/defaults.lua index 2e40bf6..46cc9a6 100644 --- a/lua/nvim-paredit/defaults.lua +++ b/lua/nvim-paredit/defaults.lua @@ -61,6 +61,7 @@ M.default_keys = { M.defaults = { use_default_keys = true, cursor_behaviour = "auto", -- remain, follow, auto + auto_indent = true, keys = {}, } diff --git a/lua/nvim-paredit/indentation/init.lua b/lua/nvim-paredit/indentation/init.lua new file mode 100644 index 0000000..82d1142 --- /dev/null +++ b/lua/nvim-paredit/indentation/init.lua @@ -0,0 +1,31 @@ +local native = require("nvim-paredit.indentation.native") +local config = require("nvim-paredit.config") + +local M = {} + +function M.handle_indentation(event, opts) + local auto_indent = opts.auto_indent or config.config.auto_indent + if not auto_indent then + return + end + + local indent_fn = opts.indent_fn or config.config.indent_fn or native.indent + if not indent_fn then + return + end + + local tree = vim.treesitter.get_parser(0) + + tree:parse() + + local parent = tree:named_node_for_range(event.parent_range) + indent_fn( + vim.tbl_deep_extend("force", event, { + tree = tree, + parent = parent, + }), + opts + ) +end + +return M diff --git a/lua/nvim-paredit/indentation/native.lua b/lua/nvim-paredit/indentation/native.lua new file mode 100644 index 0000000..1f03728 --- /dev/null +++ b/lua/nvim-paredit/indentation/native.lua @@ -0,0 +1,141 @@ +local traversal = require("nvim-paredit.utils.traversal") +local utils = require("nvim-paredit.indentation.utils") +local langs = require("nvim-paredit.lang") + +local M = {} + +local function dedent_lines(opts) + -- stylua: ignore + local line_text = vim.api.nvim_buf_get_lines( + opts.buf or 0, + opts.lines[1], opts.lines[#opts.lines] + 1, + false + ) + + local smallest_distance = opts.delta + for _, line in ipairs(line_text) do + local first_char_index = string.find(line, "[^%s]") + if first_char_index and (first_char_index - 1) < smallest_distance then + smallest_distance = first_char_index - 1 + end + end + + for index, line in ipairs(opts.lines) do + local deletion_range = smallest_distance + local contains_chars = string.find(line_text[index], "[^%s]") + if not contains_chars then + deletion_range = #line_text[index] + end + -- stylua: ignore + vim.api.nvim_buf_set_text( + opts.buf or 0, + line, 0, + line, deletion_range, + {} + ) + end +end + +local function indent_lines(opts) + if opts.delta == 0 then + return + end + + if opts.delta < 0 then + return dedent_lines(vim.tbl_deep_extend("force", opts, { + delta = opts.delta * -1, + })) + end + + local spaces = string.rep(" ", opts.delta) + for _, line in ipairs(opts.lines) do + -- stylua: ignore + vim.api.nvim_buf_set_text( + opts.buf or 0, + line, 0, + line, 0, + {spaces} + ) + end +end + +local function indent_barf(event) + local lang = langs.get_language_api() + + local lhs + local node + if event.type == "barf-forwards" then + node = traversal.get_next_sibling_ignoring_comments(event.parent, { lang = lang }) + lhs = event.parent + else + node = event.parent + lhs = traversal.get_prev_sibling_ignoring_comments(event.parent, { lang = lang }) + end + + if not node or not lhs then + return + end + + local parent = node:parent() + + local lhs_range = { lhs:range() } + local node_range = { node:range() } + + if not utils.node_is_first_on_line(node, { lang = lang }) or lhs_range[1] == node_range[1] then + return + end + + local lines = utils.find_affected_lines(node, utils.get_node_line_range(node_range)) + + local delta + if parent:type() == "source" then + delta = node_range[2] + else + local form_edges = lang.get_form_edges(parent) + delta = node_range[2] - form_edges.left.range[2] - 1 + end + + indent_lines({ + buf = event.buf, + lines = lines, + delta = delta * -1, + }) +end + +local function indent_slurp(event) + local parent = event.parent + local lang = langs.get_language_api() + + local child + if event.type == "slurp-forwards" then + child = parent:named_child(parent:named_child_count() - 1) + else + child = parent:named_child(1) + end + + local parent_range = { parent:range() } + local child_range = { child:range() } + + if not utils.node_is_first_on_line(child, { lang = lang }) or parent_range[1] == child_range[1] then + return + end + + local lines = utils.find_affected_lines(child, utils.get_node_line_range(child_range)) + local form_edges = lang.get_form_edges(parent) + + indent_lines({ + buf = event.buf, + lines = lines, + delta = form_edges.left.range[4] - child_range[2], + }) +end + +function M.indent(event) + if event.type == "slurp-forwards" or event.type == "slurp-backwards" then + indent_slurp(event) + else + indent_barf(event) + end +end + +return M diff --git a/lua/nvim-paredit/indentation/utils.lua b/lua/nvim-paredit/indentation/utils.lua new file mode 100644 index 0000000..3290f11 --- /dev/null +++ b/lua/nvim-paredit/indentation/utils.lua @@ -0,0 +1,77 @@ +local traversal = require("nvim-paredit.utils.traversal") + +local M = {} + +function M.get_node_line_range(range) + local lines = {} + for i = range[1], range[3], 1 do + table.insert(lines, i) + end + return lines +end + +function M.get_node_rhs_siblings(node) + local nodes = {} + local current = node + while current do + table.insert(nodes, current) + current = current:next_named_sibling() + end + return nodes +end + +function M.ordered_set(lines) + local seen = {} + local result = {} + for _, value in ipairs(lines) do + if not seen[value] then + table.insert(result, value) + seen[value] = true + end + end + + table.sort(result) + return result +end + +function M.find_affected_lines(node, lines) + local siblings = M.get_node_rhs_siblings(node) + for _, sibling in ipairs(siblings) do + local range = { sibling:range() } + + local sibling_is_affected = false + for _, line in ipairs(lines) do + if line == range[1] then + sibling_is_affected = true + end + end + + if sibling_is_affected then + local new_lines = M.get_node_line_range(range) + for _, row in ipairs(new_lines) do + table.insert(lines, row) + end + end + end + + local parent = node:parent() + if parent then + return M.find_affected_lines(parent, lines) + end + + return M.ordered_set(lines) +end + +function M.node_is_first_on_line(node, opts) + local node_range = { node:range() } + + local sibling = traversal.get_prev_sibling_ignoring_comments(node, opts) + if not sibling then + return true + end + + local sibling_range = { sibling:range() } + return sibling_range[3] ~= node_range[1] +end + +return M diff --git a/tests/nvim-paredit/indentation_spec.lua b/tests/nvim-paredit/indentation_spec.lua new file mode 100644 index 0000000..97122bd --- /dev/null +++ b/tests/nvim-paredit/indentation_spec.lua @@ -0,0 +1,170 @@ +local paredit = require("nvim-paredit.api") + +local expect_all = require("tests.nvim-paredit.utils").expect_all + +describe("forward slurping indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function slurp_forwards() + paredit.slurp_forwards({ + auto_indent = true, + }) + end + + expect_all(slurp_forwards, { + { + "should indent a nested child", + before_content = { "()", "a" }, + before_cursor = { 1, 1 }, + after_content = { "(", " a)" }, + after_cursor = { 1, 0 }, + }, + { + "should indent a multi-line child", + before_content = { "()", "(a", " b c)" }, + before_cursor = { 1, 1 }, + after_content = { "(", " (a", " b c))" }, + after_cursor = { 1, 0 }, + }, + { + "should indent a multi-line child that pushes other nodes", + before_content = { "()", "(a", " b) (c", "d) (e", "f)" }, + before_cursor = { 1, 1 }, + after_content = { "(", " (a", " b)) (c", " d) (e", " f)" }, + after_cursor = { 1, 0 }, + }, + { + "should not indent if node is not first on line", + before_content = { "(", "a) (a", "b)" }, + before_cursor = { 1, 1 }, + after_content = { "(", "a (a", "b))" }, + after_cursor = { 1, 0 }, + }, + { + "should not indent when on same line", + before_content = "() 1", + before_cursor = { 1, 1 }, + after_content = "( 1)", + after_cursor = { 1, 1 }, + }, + { + "should dedent when node is too far indented", + before_content = { "()", " a" }, + before_cursor = { 1, 1 }, + after_content = { "(", " a)" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent without deleting characters", + before_content = { "()", " (a", " b)" }, + before_cursor = { 1, 1 }, + after_content = { "(", " (a", "b))" }, + after_cursor = { 1, 0 }, + }, + { + "should indent the correct node ignoring comments", + before_content = { "()", ";; comment", "a" }, + before_cursor = { 1, 1 }, + after_content = { "(", ";; comment", " a)" }, + after_cursor = { 1, 0 }, + }, + }) +end) + +describe("backward slurping indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function slurp_backwards() + paredit.slurp_backwards({ + auto_indent = true, + }) + end + + expect_all(slurp_backwards, { + { + "should indent a nested child", + before_content = { "a", "(b)" }, + before_cursor = { 2, 1 }, + after_content = { "(a", " b)" }, + after_cursor = { 2, 2 }, + }, + { + "should not indent when on same line", + before_content = { "a (b)" }, + before_cursor = { 1, 3 }, + after_content = { "(a b)" }, + after_cursor = { 1, 3 }, + }, + }) +end) + +describe("forward barfing indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function barf_forwards() + paredit.barf_forwards({ + auto_indent = true, + }) + end + + expect_all(barf_forwards, { + { + "should dedent the barfed child", + before_content = { "(", " a)" }, + before_cursor = { 1, 0 }, + after_content = { "()", "a" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent a multi-line child and affected siblings", + before_content = { "(", " (a", " b c)) (a", " d)" }, + before_cursor = { 1, 0 }, + after_content = { "()", "(a", " b c) (a", "d)" }, + after_cursor = { 1, 0 }, + }, + { + "should not dedent if node is on the same line", + before_content = { "(a", "b c)" }, + before_cursor = { 1, 1 }, + after_content = { "(a", "b) c" }, + after_cursor = { 1, 1 }, + }, + { + "should not dedent when there is no indentation", + before_content = { "(", "a)" }, + before_cursor = { 1, 0 }, + after_content = { "()", "a" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent the minimum amount without deleting chars", + before_content = { "(", " a) (b", " c)" }, + before_cursor = { 1, 0 }, + after_content = { "()", " a (b", "c)" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent the correct node ignoring comments", + before_content = { "(", ";; comment", " a)" }, + before_cursor = { 1, 1 }, + after_content = { "()", ";; comment", "a" }, + after_cursor = { 1, 0 }, + }, + }) +end) + +describe("backward barfing indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function barf_backwards() + paredit.barf_backwards({ + auto_indent = true, + }) + end + + expect_all(barf_backwards, { + { + "should dedent a nested child", + before_content = { "(a", " b)" }, + before_cursor = { 1, 0 }, + after_content = { "a", "(b)" }, + after_cursor = { 1, 0 }, + }, + }) +end)