From 0351f12e7ae46fa4be71bdae01be9c2a6080dfaa Mon Sep 17 00:00:00 2001 From: Julien Vincent Date: Sun, 30 Jul 2023 12:43:56 +0100 Subject: [PATCH] Add auto indentation correction on slurp/barf This adds support for automatic indentation correction when slurping and barfing. The goal of this implementation is to: 1) Provide a visual aid to the user that allows them to confirm they are operating on the correct node, and to know when to stop when performing recursive slurp/barf operations. 2) Be simple to maintain and understand while being as correct as possible 3) Be replaceable with other implementations. 4) Be as performant as possible. There should be no lag or visual jitter The goal is _not_ to be 100% correct. If a more correct implementation is needed then one can be provided through `indent_fn`. For example, an implementation using `vim.lsp.buf.format` could be built if the user doesn't mind sacrificing performance for correctness. --- README.md | 20 +++ lua/nvim-paredit/api/barfing.lua | 38 +++++- lua/nvim-paredit/api/slurping.lua | 41 +++++- lua/nvim-paredit/config.lua | 4 +- lua/nvim-paredit/defaults.lua | 4 + lua/nvim-paredit/indentation/init.lua | 27 ++++ lua/nvim-paredit/indentation/native.lua | 136 ++++++++++++++++++++ lua/nvim-paredit/indentation/utils.lua | 64 ++++++++++ lua/nvim-paredit/init.lua | 4 +- lua/nvim-paredit/utils/common.lua | 32 ++--- tests/nvim-paredit/indentation_spec.lua | 163 ++++++++++++++++++++++++ 11 files changed, 500 insertions(+), 33 deletions(-) create mode 100644 lua/nvim-paredit/indentation/init.lua create mode 100644 lua/nvim-paredit/indentation/native.lua create mode 100644 lua/nvim-paredit/indentation/utils.lua create mode 100644 tests/nvim-paredit/indentation_spec.lua diff --git a/README.md b/README.md index dc1bf8a..c7bbbc7 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,27 @@ require("nvim-paredit").setup({ -- defaults to all supported file types including custom lang -- extensions (see next section) filetypes = { "clojure" }, + + -- This controls where the cursor is placed when performing slurp/barf operations + -- + -- - "remain" - It will never change the cursor position, keeping it in the same place + -- - "follow" - It will always place the cursor on the form edge that was moved + -- - "auto" - A combination of remain and follow, it will try keep the cursor in the original position + -- unless doing so would result in the cursor no longer being within the original form. In + -- this case it will place the cursor on the moved edge cursor_behaviour = "auto", -- remain, follow, auto + + indent = { + -- This controls how nvim-paredit handles indentation when performing operations which + -- should change the indentation of the form (such as when slurping or barfing). + -- + -- When set to true then it will attempt to fix the indentation of nodes operated on. + enabled = true, + -- A function that will be called after a slurp/barf if you want to provide a custom indentation + -- implementation. + indentor = require("nvim-paredit.indentation.native").indentor, + }, + -- list of default keybindings keys = { [">)"] = { paredit.api.slurp_forwards, "Slurp forwards" }, diff --git a/lua/nvim-paredit/api/barfing.lua b/lua/nvim-paredit/api/barfing.lua index 329987f..5179c5a 100644 --- a/lua/nvim-paredit/api/barfing.lua +++ b/lua/nvim-paredit/api/barfing.lua @@ -1,4 +1,5 @@ local traversal = require("nvim-paredit.utils.traversal") +local indentation = require("nvim-paredit.indentation") local common = require("nvim-paredit.utils.common") local ts = require("nvim-treesitter.ts_utils") local config = require("nvim-paredit.config") @@ -28,11 +29,11 @@ function M.barf_forwards(opts) local child if opts.reversed then child = traversal.get_first_child_ignoring_comments(form, { - lang = lang + lang = lang, }) else child = traversal.get_last_child_ignoring_comments(form, { - lang = lang + lang = lang, }) end if not child then @@ -42,7 +43,7 @@ function M.barf_forwards(opts) local edges = lang.get_form_edges(form) local sibling = traversal.get_prev_sibling_ignoring_comments(child, { - lang = lang + lang = lang, }) local end_pos @@ -55,6 +56,7 @@ function M.barf_forwards(opts) local buf = vim.api.nvim_get_current_buf() local range = edges.right.range + -- stylua: ignore vim.api.nvim_buf_set_text( buf, range[1], range[2], @@ -63,6 +65,7 @@ function M.barf_forwards(opts) ) local text = edges.right.text + -- stylua: ignore vim.api.nvim_buf_set_text(buf, end_pos[1], end_pos[2], end_pos[1], end_pos[2], @@ -77,6 +80,19 @@ function M.barf_forwards(opts) vim.api.nvim_win_set_cursor(0, { end_pos[1] + 1, end_pos[2] }) end end + + local event = { + type = "barf-forwards", + -- stylua: ignore + parent_range = { + edges.left.range[1], edges.left.range[2], + end_pos[1], end_pos[2], + }, + reversed = false, + indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, + lang = lang, + } + indentation.handle_indentation(event, opts) end function M.barf_backwards(opts) @@ -99,7 +115,7 @@ function M.barf_backwards(opts) end local child = traversal.get_first_child_ignoring_comments(form, { - lang = lang + lang = lang, }) if not child then return @@ -108,7 +124,7 @@ function M.barf_backwards(opts) local edges = lang.get_form_edges(lang.get_node_root(form)) local sibling = traversal.get_next_sibling_ignoring_comments(child, { - lang = lang + lang = lang, }) local end_pos @@ -121,6 +137,7 @@ function M.barf_backwards(opts) local buf = vim.api.nvim_get_current_buf() local text = edges.left.text + -- stylua: ignore vim.api.nvim_buf_set_text(buf, end_pos[1], end_pos[2], end_pos[1], end_pos[2], @@ -128,6 +145,7 @@ function M.barf_backwards(opts) ) local range = edges.left.range + -- stylua: ignore vim.api.nvim_buf_set_text( buf, range[1], range[2], @@ -143,6 +161,16 @@ function M.barf_backwards(opts) vim.api.nvim_win_set_cursor(0, { end_pos[1] + 1, end_pos[2] }) end end + + local event = { + type = "barf-backwards", + -- stylua: ignore + parent_range = { + end_pos[1], end_pos[2], + edges.right.range[1], edges.right.range[2], + }, + } + indentation.handle_indentation(event, opts) end return M diff --git a/lua/nvim-paredit/api/slurping.lua b/lua/nvim-paredit/api/slurping.lua index 47fb369..794072c 100644 --- a/lua/nvim-paredit/api/slurping.lua +++ b/lua/nvim-paredit/api/slurping.lua @@ -1,5 +1,5 @@ local traversal = require("nvim-paredit.utils.traversal") -local common = require("nvim-paredit.utils.common") +local indentation = require("nvim-paredit.indentation") local ts = require("nvim-treesitter.ts_utils") local config = require("nvim-paredit.config") local langs = require("nvim-paredit.lang") @@ -40,11 +40,12 @@ local function slurp(opts) end local buf = vim.api.nvim_get_current_buf() + local form_edges = lang.get_form_edges(form) local left_or_right_edge if opts.reversed then - left_or_right_edge = lang.get_form_edges(form).left + left_or_right_edge = form_edges.left else - left_or_right_edge = lang.get_form_edges(form).right + left_or_right_edge = form_edges.right end local start_or_end @@ -57,6 +58,7 @@ local function slurp(opts) local row = start_or_end[1] local col = start_or_end[2] + -- stylua: ignore vim.api.nvim_buf_set_text(buf, row, col, row, col, @@ -65,9 +67,10 @@ local function slurp(opts) local offset = 0 if opts.reversed and row == left_or_right_edge.range[1] then - offset = string.len(left_or_right_edge.text) + offset = #left_or_right_edge.text end + -- stylua: ignore vim.api.nvim_buf_set_text( buf, left_or_right_edge.range[1], left_or_right_edge.range[2] + offset, @@ -77,7 +80,7 @@ local function slurp(opts) local cursor_behaviour = opts.cursor_behaviour or config.config.cursor_behaviour if cursor_behaviour == "follow" then - local offset = 0 + offset = 0 if not opts.reversed then offset = string.len(left_or_right_edge.text) end @@ -88,6 +91,30 @@ local function slurp(opts) vim.api.nvim_win_set_cursor(0, cursor_pos) end end + + local operation_type + local new_range + if not opts.reversed then + operation_type = "slurp-forwards" + -- stylua: ignore + new_range = { + form_edges.left.range[1], form_edges.left.range[2], + row, col, + } + else + operation_type = "slurp-backwards" + -- stylua: ignore + new_range = { + row, col, + form_edges.right.range[1], form_edges.right.range[2], + } + end + + local event = { + type = operation_type, + parent_range = new_range, + } + indentation.handle_indentation(event, opts) end function M.slurp_forwards(opts) @@ -95,8 +122,8 @@ function M.slurp_forwards(opts) end function M.slurp_backwards(opts) - slurp(common.merge(opts or {}, { - reversed = true + slurp(vim.tbl_deep_extend("force", opts or {}, { + reversed = true, })) end diff --git a/lua/nvim-paredit/config.lua b/lua/nvim-paredit/config.lua index 1f43cb2..c09f05a 100644 --- a/lua/nvim-paredit/config.lua +++ b/lua/nvim-paredit/config.lua @@ -1,11 +1,9 @@ -local common = require("nvim-paredit.utils.common") - local M = {} M.config = {} function M.update_config(config) - M.config = common.merge(M.config, config) + M.config = vim.tbl_deep_extend("force", M.config, config) end return M diff --git a/lua/nvim-paredit/defaults.lua b/lua/nvim-paredit/defaults.lua index 2e40bf6..9796505 100644 --- a/lua/nvim-paredit/defaults.lua +++ b/lua/nvim-paredit/defaults.lua @@ -61,6 +61,10 @@ M.default_keys = { M.defaults = { use_default_keys = true, cursor_behaviour = "auto", -- remain, follow, auto + indent = { + enabled = true, + indentor = require("nvim-paredit.indentation.native").indentor, + }, keys = {}, } diff --git a/lua/nvim-paredit/indentation/init.lua b/lua/nvim-paredit/indentation/init.lua new file mode 100644 index 0000000..881bb79 --- /dev/null +++ b/lua/nvim-paredit/indentation/init.lua @@ -0,0 +1,27 @@ +local config = require("nvim-paredit.config") + +local M = {} + +function M.handle_indentation(event, opts) + local indent = opts.indent or config.config.indent or {} + if not indent.enabled or not indent.indentor then + return + end + + local tree = vim.treesitter.get_parser(0) + + tree:parse() + local parent = tree:named_node_for_range(event.parent_range) + + indent.indentor( + vim.tbl_deep_extend("force", event, { + tree = tree, + parent = parent, + }), + vim.tbl_deep_extend("force", opts, { + indent = indent, + }) + ) +end + +return M diff --git a/lua/nvim-paredit/indentation/native.lua b/lua/nvim-paredit/indentation/native.lua new file mode 100644 index 0000000..f47da35 --- /dev/null +++ b/lua/nvim-paredit/indentation/native.lua @@ -0,0 +1,136 @@ +local traversal = require("nvim-paredit.utils.traversal") +local utils = require("nvim-paredit.indentation.utils") +local langs = require("nvim-paredit.lang") + +local M = {} + +local function dedent_lines(lines, delta, opts) + -- stylua: ignore + local line_text = vim.api.nvim_buf_get_lines( + opts.buf or 0, + lines[1], lines[#lines] + 1, + false + ) + + local smallest_distance = delta + for _, line in ipairs(line_text) do + local first_char_index = string.find(line, "[^%s]") + if first_char_index and (first_char_index - 1) < smallest_distance then + smallest_distance = first_char_index - 1 + end + end + + for index, line in ipairs(lines) do + local deletion_range = smallest_distance + local contains_chars = string.find(line_text[index], "[^%s]") + if not contains_chars then + deletion_range = #line_text[index] + end + -- stylua: ignore + vim.api.nvim_buf_set_text( + opts.buf or 0, + line, 0, + line, deletion_range, + {} + ) + end +end + +local function indent_lines(lines, delta, opts) + if delta == 0 then + return + end + + if delta < 0 then + return dedent_lines(lines, delta * -1, opts) + end + + local chars = string.rep(" ", delta) + for _, line in ipairs(lines) do + -- stylua: ignore + vim.api.nvim_buf_set_text( + opts.buf or 0, + line, 0, + line, 0, + {chars} + ) + end +end + +local function indent_barf(event) + local lang = langs.get_language_api() + + local lhs + local node + if event.type == "barf-forwards" then + node = traversal.get_next_sibling_ignoring_comments(event.parent, { lang = lang }) + lhs = event.parent + else + node = event.parent + lhs = traversal.get_prev_sibling_ignoring_comments(event.parent, { lang = lang }) + end + + if not node or not lhs then + return + end + + local parent = node:parent() + + local lhs_range = { lhs:range() } + local node_range = { node:range() } + + if not utils.node_is_first_on_line(node, { lang = lang }) or lhs_range[1] == node_range[1] then + return + end + + local lines = utils.find_affected_lines(node, utils.get_node_line_range(node_range)) + + local delta + if parent:type() == "source" then + delta = node_range[2] + else + local form_edges = lang.get_form_edges(parent) + delta = node_range[2] - form_edges.left.range[2] - 1 + end + + indent_lines(lines, delta * -1, { + buf = event.buf, + }) +end + +local function indent_slurp(event) + local parent = event.parent + local lang = langs.get_language_api() + + local child + if event.type == "slurp-forwards" then + child = parent:named_child(parent:named_child_count() - 1) + else + child = parent:named_child(1) + end + + local parent_range = { parent:range() } + local child_range = { child:range() } + + if not utils.node_is_first_on_line(child, { lang = lang }) or parent_range[1] == child_range[1] then + return + end + + local lines = utils.find_affected_lines(child, utils.get_node_line_range(child_range)) + local form_edges = lang.get_form_edges(parent) + + local delta = form_edges.left.range[4] - child_range[2] + indent_lines(lines, delta, { + buf = event.buf, + }) +end + +function M.indentor(event, _) + if event.type == "slurp-forwards" or event.type == "slurp-backwards" then + indent_slurp(event) + else + indent_barf(event) + end +end + +return M diff --git a/lua/nvim-paredit/indentation/utils.lua b/lua/nvim-paredit/indentation/utils.lua new file mode 100644 index 0000000..0afb09c --- /dev/null +++ b/lua/nvim-paredit/indentation/utils.lua @@ -0,0 +1,64 @@ +local traversal = require("nvim-paredit.utils.traversal") +local common = require("nvim-paredit.utils.common") + +local M = {} + +function M.get_node_line_range(range) + local lines = {} + for i = range[1], range[3], 1 do + table.insert(lines, i) + end + return lines +end + +function M.get_node_rhs_siblings(node) + local nodes = {} + local current = node + while current do + table.insert(nodes, current) + current = current:next_named_sibling() + end + return nodes +end + +function M.find_affected_lines(node, lines) + local siblings = M.get_node_rhs_siblings(node) + for _, sibling in ipairs(siblings) do + local range = { sibling:range() } + + local sibling_is_affected = false + for _, line in ipairs(lines) do + if line == range[1] then + sibling_is_affected = true + end + end + + if sibling_is_affected then + local new_lines = M.get_node_line_range(range) + for _, row in ipairs(new_lines) do + table.insert(lines, row) + end + end + end + + local parent = node:parent() + if parent then + return M.find_affected_lines(parent, lines) + end + + return common.ordered_set(lines) +end + +function M.node_is_first_on_line(node, opts) + local node_range = { node:range() } + + local sibling = traversal.get_prev_sibling_ignoring_comments(node, opts) + if not sibling then + return true + end + + local sibling_range = { sibling:range() } + return sibling_range[3] ~= node_range[1] +end + +return M diff --git a/lua/nvim-paredit/init.lua b/lua/nvim-paredit/init.lua index 81d6d2a..4bd1932 100644 --- a/lua/nvim-paredit/init.lua +++ b/lua/nvim-paredit/init.lua @@ -41,11 +41,11 @@ function M.setup(opts) local keys = opts.keys or {} if type(opts.use_default_keys) ~= "boolean" or opts.use_default_keys then - keys = common.merge(defaults.default_keys, opts.keys or {}) + keys = vim.tbl_deep_extend("force", defaults.default_keys, opts.keys or {}) end config.update_config(defaults.defaults) - config.update_config(common.merge(opts, { + config.update_config(vim.tbl_deep_extend("force", opts, { filetypes = filetypes, keys = keys, })) diff --git a/lua/nvim-paredit/utils/common.lua b/lua/nvim-paredit/utils/common.lua index ea297e3..c4c7fbd 100644 --- a/lua/nvim-paredit/utils/common.lua +++ b/lua/nvim-paredit/utils/common.lua @@ -9,17 +9,6 @@ function M.included_in_table(table, item) return false end -function M.merge(a, b) - local result = {} - for k, v in pairs(a) do - result[k] = v - end - for k, v in pairs(b) do - result[k] = v - end - return result -end - -- Compares the two given { col, row } position tuples and returns -1/0/1 depending -- on whether `a` is less than, equal to or greater than `b` -- @@ -59,6 +48,20 @@ function M.intersection(tbl, original) return result end +function M.ordered_set(lines) + local seen = {} + local result = {} + for _, value in ipairs(lines) do + if not seen[value] then + table.insert(result, value) + seen[value] = true + end + end + + table.sort(result) + return result +end + function M.ensure_visual_mode() if vim.api.nvim_get_mode().mode ~= "v" then vim.api.nvim_command("normal! v") @@ -72,11 +75,8 @@ function M.is_whitespace_under_cursor(lang) cursor = { cursor[1] - 1, cursor[2] } local char_under_cursor = vim.api.nvim_buf_get_text(0, cursor[1], cursor[2], cursor[1], cursor[2] + 1, {}) - return M.included_in_table( - lang.whitespace_chars or M.default_whitespace_chars, - char_under_cursor[1] - ) or char_under_cursor[1] == "" + return M.included_in_table(lang.whitespace_chars or M.default_whitespace_chars, char_under_cursor[1]) + or char_under_cursor[1] == "" end return M - diff --git a/tests/nvim-paredit/indentation_spec.lua b/tests/nvim-paredit/indentation_spec.lua new file mode 100644 index 0000000..46d17b9 --- /dev/null +++ b/tests/nvim-paredit/indentation_spec.lua @@ -0,0 +1,163 @@ +local defaults = require("nvim-paredit.defaults") +local paredit = require("nvim-paredit.api") + +local expect_all = require("tests.nvim-paredit.utils").expect_all + +describe("forward slurping indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function slurp_forwards() + paredit.slurp_forwards(defaults.defaults) + end + + expect_all(slurp_forwards, { + { + "should indent a nested child", + before_content = { "()", "a" }, + before_cursor = { 1, 1 }, + after_content = { "(", " a)" }, + after_cursor = { 1, 0 }, + }, + { + "should indent a multi-line child", + before_content = { "()", "(a", " b c)" }, + before_cursor = { 1, 1 }, + after_content = { "(", " (a", " b c))" }, + after_cursor = { 1, 0 }, + }, + { + "should indent a multi-line child that pushes other nodes", + before_content = { "()", "(a", " b) (c", "d) (e", "f)" }, + before_cursor = { 1, 1 }, + after_content = { "(", " (a", " b)) (c", " d) (e", " f)" }, + after_cursor = { 1, 0 }, + }, + { + "should not indent if node is not first on line", + before_content = { "(", "a) (a", "b)" }, + before_cursor = { 1, 1 }, + after_content = { "(", "a (a", "b))" }, + after_cursor = { 1, 0 }, + }, + { + "should not indent when on same line", + before_content = "() 1", + before_cursor = { 1, 1 }, + after_content = "( 1)", + after_cursor = { 1, 1 }, + }, + { + "should dedent when node is too far indented", + before_content = { "()", " a" }, + before_cursor = { 1, 1 }, + after_content = { "(", " a)" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent without deleting characters", + before_content = { "()", " (a", " b)" }, + before_cursor = { 1, 1 }, + after_content = { "(", " (a", "b))" }, + after_cursor = { 1, 0 }, + }, + { + "should indent the correct node ignoring comments", + before_content = { "()", ";; comment", "a" }, + before_cursor = { 1, 1 }, + after_content = { "(", ";; comment", " a)" }, + after_cursor = { 1, 0 }, + }, + }) +end) + +describe("backward slurping indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function slurp_backwards() + paredit.slurp_backwards(defaults.defaults) + end + + expect_all(slurp_backwards, { + { + "should indent a nested child", + before_content = { "a", "(b)" }, + before_cursor = { 2, 1 }, + after_content = { "(a", " b)" }, + after_cursor = { 2, 2 }, + }, + { + "should not indent when on same line", + before_content = { "a (b)" }, + before_cursor = { 1, 3 }, + after_content = { "(a b)" }, + after_cursor = { 1, 3 }, + }, + }) +end) + +describe("forward barfing indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function barf_forwards() + paredit.barf_forwards(defaults.defaults) + end + + expect_all(barf_forwards, { + { + "should dedent the barfed child", + before_content = { "(", " a)" }, + before_cursor = { 1, 0 }, + after_content = { "()", "a" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent a multi-line child and affected siblings", + before_content = { "(", " (a", " b c)) (a", " d)" }, + before_cursor = { 1, 0 }, + after_content = { "()", "(a", " b c) (a", "d)" }, + after_cursor = { 1, 0 }, + }, + { + "should not dedent if node is on the same line", + before_content = { "(a", "b c)" }, + before_cursor = { 1, 1 }, + after_content = { "(a", "b) c" }, + after_cursor = { 1, 1 }, + }, + { + "should not dedent when there is no indentation", + before_content = { "(", "a)" }, + before_cursor = { 1, 0 }, + after_content = { "()", "a" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent the minimum amount without deleting chars", + before_content = { "(", " a) (b", " c)" }, + before_cursor = { 1, 0 }, + after_content = { "()", " a (b", "c)" }, + after_cursor = { 1, 0 }, + }, + { + "should dedent the correct node ignoring comments", + before_content = { "(", ";; comment", " a)" }, + before_cursor = { 1, 1 }, + after_content = { "()", ";; comment", "a" }, + after_cursor = { 1, 0 }, + }, + }) +end) + +describe("backward barfing indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function barf_backwards() + paredit.barf_backwards(defaults.defaults) + end + + expect_all(barf_backwards, { + { + "should dedent a nested child", + before_content = { "(a", " b)" }, + before_cursor = { 1, 0 }, + after_content = { "a", "(b)" }, + after_cursor = { 2, 1 }, + }, + }) +end)