diff --git a/lua/nvim-paredit/api/barfing.lua b/lua/nvim-paredit/api/barfing.lua index d1409ea..452c388 100644 --- a/lua/nvim-paredit/api/barfing.lua +++ b/lua/nvim-paredit/api/barfing.lua @@ -72,17 +72,6 @@ function M.barf_forwards(opts) { text } ) - indentation.handle_indentation({ - type = "barf", - from = range, - to = { end_pos[1], end_pos[2], end_pos[1], end_pos[2] }, - child = child, - parent = form, - - indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, - lang = lang, - }) - local cursor_behaviour = opts.cursor_behaviour or config.config.cursor_behaviour if cursor_behaviour == "auto" or cursor_behaviour == "follow" then local cursor_pos = vim.api.nvim_win_get_cursor(0) @@ -91,6 +80,18 @@ function M.barf_forwards(opts) vim.api.nvim_win_set_cursor(0, { end_pos[1] + 1, end_pos[2] }) end end + + indentation.handle_indentation({ + type = "barf-forwards", + -- stylua: ignore + parent_range = { + edges.left.range[1], edges.left.range[2], + end_pos[1], end_pos[2], + }, + reversed = false, + indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, + lang = lang, + }) end function M.barf_backwards(opts) @@ -151,17 +152,6 @@ function M.barf_backwards(opts) {} ) - indentation.handle_indentation({ - type = "barf", - from = range, - to = { end_pos[1], end_pos[2], end_pos[1], end_pos[2] }, - child = child, - parent = form, - - indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, - lang = lang, - }) - local cursor_behaviour = opts.cursor_behaviour or config.config.cursor_behaviour if cursor_behaviour == "auto" or cursor_behaviour == "follow" then local cursor_pos = vim.api.nvim_win_get_cursor(0) @@ -170,6 +160,18 @@ function M.barf_backwards(opts) vim.api.nvim_win_set_cursor(0, { end_pos[1] + 1, end_pos[2] }) end end + + indentation.handle_indentation({ + type = "barf-backwards", + -- stylua: ignore + parent_range = { + end_pos[1], end_pos[2], + edges.right.range[1], edges.right.range[2], + }, + reversed = true, + indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, + lang = lang, + }) end return M diff --git a/lua/nvim-paredit/api/slurping.lua b/lua/nvim-paredit/api/slurping.lua index 9482543..6633523 100644 --- a/lua/nvim-paredit/api/slurping.lua +++ b/lua/nvim-paredit/api/slurping.lua @@ -68,7 +68,7 @@ local function slurp(opts) local offset = 0 if opts.reversed and row == left_or_right_edge.range[1] then - offset = string.len(left_or_right_edge.text) + offset = #left_or_right_edge.text end -- stylua: ignore @@ -79,17 +79,6 @@ local function slurp(opts) {} ) - indentation.handle_indentation({ - type = "slurp", - from = left_or_right_edge.range, - to = { row, col, row, col }, - child = sibling, - parent = form, - - indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, - lang = lang, - }) - local cursor_behaviour = opts.cursor_behaviour or config.config.cursor_behaviour if cursor_behaviour == "follow" then offset = 0 @@ -103,6 +92,34 @@ local function slurp(opts) vim.api.nvim_win_set_cursor(0, cursor_pos) end end + + local operation_type + local new_range + if not opts.reversed then + operation_type = "slurp-forwards" + new_range = { + form_edges.left.range[1], + form_edges.left.range[2], + row, + col, + } + else + operation_type = "slurp-backwards" + new_range = { + row, + col, + form_edges.right.range[1], + form_edges.right.range[2], + } + end + + indentation.handle_indentation({ + type = operation_type, + parent_range = new_range, + reversed = opts.reversed, + indent_behaviour = opts.indent_behaviour or config.config.indent_behaviour, + lang = lang, + }) end function M.slurp_forwards(opts) diff --git a/lua/nvim-paredit/indentation/init.lua b/lua/nvim-paredit/indentation/init.lua index f301d47..8bb6b5f 100644 --- a/lua/nvim-paredit/indentation/init.lua +++ b/lua/nvim-paredit/indentation/init.lua @@ -23,7 +23,15 @@ function M.handle_indentation(operation) return end - indent_fn(operation) + local tree = vim.treesitter.get_parser(0) + + tree:parse() + + local parent = tree:named_node_for_range(operation.parent_range) + indent_fn(vim.tbl_deep_extend("force", operation, { + tree = tree, + parent = parent, + })) end return M diff --git a/lua/nvim-paredit/indentation/native.lua b/lua/nvim-paredit/indentation/native.lua index f45c76a..459005f 100644 --- a/lua/nvim-paredit/indentation/native.lua +++ b/lua/nvim-paredit/indentation/native.lua @@ -1,20 +1,104 @@ local M = {} +local function get_node_line_range(range) + local lines = {} + for i = range[1], range[3], 1 do + table.insert(lines, i) + end + return lines +end + +function siblings(node) + local siblings = {} + local current = node + while current do + table.insert(siblings, current) + current = current:next_named_sibling() + end + return siblings +end + +function ordered_set(lines) + local seen = {} + local result = {} + for _, value in ipairs(lines) do + if not seen[value] then + table.insert(result, value) + seen[value] = true + end + end + + table.sort(result) + return result +end + +function find_affected_lines(node, lines) + local siblings = siblings(node) + for _, sibling in ipairs(siblings) do + local range = { sibling:range() } + + local sibling_is_affected = false + for _, line in ipairs(lines) do + if line == range[1] then + sibling_is_affected = true + end + end + + if sibling_is_affected then + local new_lines = get_node_line_range(range) + for _, row in ipairs(new_lines) do + table.insert(lines, row) + end + end + end + + local parent = node:parent() + if parent then + return find_affected_lines(parent, lines) + end + + return ordered_set(lines) +end + +function is_first_node_on_line(node) + local node_range = { node:range() } + local sibling = node:prev_named_sibling() + if not sibling then + return true + end + + local sibling_range = { sibling:range() } + return sibling_range[3] ~= node_range[1] +end + function indent_barf(operation) - local child_range = { operation.child:range() } - if operation.to[1] == child_range[1] then + local lhs + local node + if operation.type == "barf-forwards" then + node = operation.parent:next_named_sibling() + lhs = operation.parent + else + node = operation.parent + lhs = operation.parent:prev_named_sibling() + end + + local parent = node:parent() + + local lhs_range = { lhs:range() } + local node_range = { node:range() } + + if not is_first_node_on_line(node) or lhs_range[1] == node_range[1] then return end - local lang = operation.lang - local parent = lang.get_node_root(operation.parent):parent() + local lines = find_affected_lines(node, get_node_line_range(node_range)) local delta if parent:type() == "source" then - delta = child_range[2] + delta = node_range[2] else local form_edges = operation.lang.get_form_edges(parent) - delta = child_range[2] - form_edges.left.range[2] - 1 + delta = node_range[2] - form_edges.left.range[2] - 1 end if delta == 0 then @@ -24,44 +108,42 @@ function indent_barf(operation) if delta < 0 then local spaces = string.rep(" ", delta * -1) - for i = child_range[1], child_range[3], 1 do + for _, line in ipairs(lines) do -- stylua: ignore vim.api.nvim_buf_set_text( operation.buf or 0, - i, 0, - i, 0, + line, 0, + line, 0, {spaces} ) end else -- stylua: ignore - local lines = vim.api.nvim_buf_get_lines( + local line_text = vim.api.nvim_buf_get_lines( operation.buf or 0, - child_range[1], child_range[3] + 1, + lines[1], lines[#lines] + 1, false ) local smallest_distance = delta - for i = 1, #lines, 1 do - local first_char_index = string.find(lines[i], "[^%s]") + for _, line in ipairs(line_text) do + local first_char_index = string.find(line, "[^%s]") if first_char_index and (first_char_index - 1) < smallest_distance then smallest_distance = first_char_index - 1 end end - local line_index = 0 - for i = child_range[1], child_range[3], 1 do - line_index = line_index + 1 + for index, line in ipairs(lines) do local deletion_range = smallest_distance - local contains_chars = string.find(lines[line_index], "[^%s]") + local contains_chars = string.find(line_text[index], "[^%s]") if not contains_chars then - deletion_range = #lines[line_index] + deletion_range = #line_text[index] end -- stylua: ignore vim.api.nvim_buf_set_text( operation.buf or 0, - i, 0, - i, deletion_range, + line, 0, + line, deletion_range, {} ) end @@ -69,12 +151,25 @@ function indent_barf(operation) end function indent_slurp(operation) - local child_range = { operation.child:range() } - if operation.from[1] == child_range[1] then + local parent = operation.parent + + local child + if operation.type == "slurp-forwards" then + child = parent:named_child(parent:named_child_count() - 1) + else + child = parent:named_child(1) + end + + local parent_range = { parent:range() } + local child_range = { child:range() } + + if not is_first_node_on_line(child) or parent_range[1] == child_range[1] then return end - local form_edges = operation.lang.get_form_edges(operation.parent) + local lines = find_affected_lines(child, get_node_line_range(child_range)) + + local form_edges = operation.lang.get_form_edges(parent) local delta = form_edges.left.range[4] - child_range[2] if delta == 0 then @@ -84,44 +179,42 @@ function indent_slurp(operation) if delta > 0 then local spaces = string.rep(" ", delta) - for i = child_range[1], child_range[3], 1 do + for _, line in pairs(lines) do -- stylua: ignore vim.api.nvim_buf_set_text( operation.buf or 0, - i, 0, - i, 0, + line, 0, + line, 0, {spaces} ) end else -- stylua: ignore - local lines = vim.api.nvim_buf_get_lines( + local line_text = vim.api.nvim_buf_get_lines( operation.buf or 0, - child_range[1], child_range[3] + 1, + lines[1], lines[#lines] + 1, false ) local smallest_distance = delta * -1 - for i = 1, #lines, 1 do - local first_char_index = string.find(lines[i], "[^%s]") + for _, line in ipairs(line_text) do + local first_char_index = string.find(line, "[^%s]") if first_char_index and (first_char_index - 1) < smallest_distance then smallest_distance = first_char_index - 1 end end - local line_index = 0 - for i = child_range[1], child_range[3], 1 do - line_index = line_index + 1 + for index, line in ipairs(lines) do local deletion_range = smallest_distance - local contains_chars = string.find(lines[line_index], "[^%s]") + local contains_chars = string.find(line_text[index], "[^%s]") if not contains_chars then - deletion_range = #lines[line_index] + deletion_range = #line_text[index] end -- stylua: ignore vim.api.nvim_buf_set_text( operation.buf or 0, - i, 0, - i, deletion_range, + line, 0, + line, deletion_range, {} ) end @@ -129,7 +222,7 @@ function indent_slurp(operation) end function M.indent(operation) - if operation.type == "slurp" then + if operation.type == "slurp-forwards" or operation.type == "slurp-backwards" then indent_slurp(operation) else indent_barf(operation) diff --git a/tests/nvim-paredit/indentation_spec.lua b/tests/nvim-paredit/indentation_spec.lua index 96c55ad..0ae70c0 100644 --- a/tests/nvim-paredit/indentation_spec.lua +++ b/tests/nvim-paredit/indentation_spec.lua @@ -36,6 +36,30 @@ describe("forward slurping indentation", function() }) end) + it("should indent a sibling multi-line child", function() + prepare_buffer({ + content = { "()", "(a", " b) (c", "d)" }, + cursor = { 1, 1 }, + }) + slurp_forwards() + expect({ + content = { "(", " (a", " b)) (c", " d)" }, + cursor = { 1, 0 }, + }) + end) + + it("should not indent if node is not first on line", function() + prepare_buffer({ + content = { "(", "a) (a", "b)" }, + cursor = { 1, 1 }, + }) + slurp_forwards() + expect({ + content = { "(", "a (a", "b))" }, + cursor = { 1, 0 }, + }) + end) + it("should not indent when on same line", function() prepare_buffer({ content = "() 1", @@ -50,6 +74,28 @@ describe("forward slurping indentation", function() end) end) + +describe("backward slurping indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function slurp_backwards() + paredit.slurp_backwards({ + indent_behaviour = "native", + }) + end + + it("should indent a nested child", function() + prepare_buffer({ + content = { "a", "(b)" }, + cursor = { 2, 1 }, + }) + slurp_backwards() + expect({ + content = { "(a", " b)" }, + cursor = { 2, 2 }, + }) + end) +end) + describe("forward barfing indentation", function() vim.api.nvim_buf_set_option(0, "filetype", "clojure") local function barf_forwards() @@ -82,6 +128,18 @@ describe("forward barfing indentation", function() }) end) + it("should not dedent if node is not first on line", function() + prepare_buffer({ + content = { "(a", "b c)" }, + cursor = { 1, 1 }, + }) + barf_forwards() + expect({ + content = { "(a", "b) c" }, + cursor = { 1, 1 }, + }) + end) + it("should not deindent when on same line", function() prepare_buffer({ content = "( 1)", @@ -106,3 +164,24 @@ describe("forward barfing indentation", function() }) end) end) + +describe("backward barfing indentation", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + local function barf_backwards() + paredit.barf_backwards({ + indent_behaviour = "native", + }) + end + + it("should deindent a nested child", function() + prepare_buffer({ + content = { "(a", " b)" }, + cursor = { 1, 0 }, + }) + barf_backwards() + expect({ + content = { "a", "(b)" }, + cursor = { 1, 0 }, + }) + end) +end)