diff --git a/README.md b/README.md index c8b79ec..e2e9b8f 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,13 @@ paredit.setup({ -- this case it will place the cursor on the moved edge cursor_behaviour = "auto", -- remain, follow, auto + dragging = { + -- If set to `true` paredit will attempt to infer if an element being + -- dragged is part of a 'paired' form like as a map. If so then the element + -- will be dragged along with it's pair. + auto_drag_pairs = true, + }, + indent = { -- This controls how nvim-paredit handles indentation when performing operations which -- should change the indentation of the form (such as when slurping or barfing). @@ -88,6 +95,9 @@ paredit.setup({ [">e"] = { paredit.api.drag_element_forwards, "Drag element right" }, ["p"] = { api.drag_pair_forwards, "Drag element pairs right" }, + ["f"] = { paredit.api.drag_form_forwards, "Drag form right" }, ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp", }, + ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp" }, [">)"] = { api.slurp_forwards, "Slurp forwards" }, [">("] = { api.barf_backwards, "Barf backwards" }, @@ -15,6 +15,9 @@ M.default_keys = { [">e"] = { api.drag_element_forwards, "Drag element right" }, ["p"] = { api.drag_pair_forwards, "Drag element pairs right" }, + ["f"] = { api.drag_form_forwards, "Drag form right" }, [" function M.get_language_api() for l in string.gmatch(vim.bo.filetype, "[^.]+") do if langs[l] ~= nil then return langs[l] end end - return nil + error("Could not find language extension for filetype " .. vim.bo.filetype, vim.log.levels.ERROR) end function M.add_language_extension(filetype, api) diff --git a/lua/nvim-paredit/utils/common.lua b/lua/nvim-paredit/utils/common.lua index 93b7c22..68c9825 100644 --- a/lua/nvim-paredit/utils/common.lua +++ b/lua/nvim-paredit/utils/common.lua @@ -9,6 +9,20 @@ function M.included_in_table(table, item) return false end +function M.chunk_table(tbl, chunk_size) + local result = {} + for i = 1, #tbl, chunk_size do + local chunk = {} + for j = 0, chunk_size - 1 do + if tbl[i + j] then + table.insert(chunk, tbl[i + j]) + end + end + table.insert(result, chunk) + end + return result +end + -- Compares the two given { col, row } position tuples and returns -1/0/1 depending -- on whether `a` is less than, equal to or greater than `b` -- diff --git a/lua/nvim-paredit/utils/traversal.lua b/lua/nvim-paredit/utils/traversal.lua index b8857ed..72d89d7 100644 --- a/lua/nvim-paredit/utils/traversal.lua +++ b/lua/nvim-paredit/utils/traversal.lua @@ -16,6 +16,22 @@ function M.find_nearest_form(current_node, opts) end end +function M.get_children_ignoring_comments(node, opts) + local children = {} + + local index = 0 + local child = node:named_child(index) + while child do + if not child:extra() and not opts.lang.node_is_comment(child) then + table.insert(children, child) + end + index = index + 1 + child = node:named_child(index) + end + + return children +end + local function get_child_ignoring_comments(node, index, opts) if index < 0 or index >= node:named_child_count() then return @@ -139,34 +155,18 @@ function M.find_root_element_relative_to(root, child) return M.find_root_element_relative_to(root, parent) end -function M.get_top_level_node_below_document(node) - -- Document - -- - Branch A - -- -- Node X - -- --- Sub-node 1 - -- - Branch B - -- -- Node Y - -- --- Sub-node 2 - -- --- Sub-node 3 - - -- If we call this function on "Sub-node 2" we expect "Branch B" to be - -- returned, the top level one below the document itself. We know which - -- node is the document because it lacks a parent, just like Batman. - - local parent = node:parent() - - -- Does the node have a parent? If so, we might be at the right level. - -- If not, we should just return the node right away, we're already too high. - if parent then - -- If the parent _also_ has a parent then we still need to go higher, recur. - if parent:parent() then - return M.get_top_level_node_below_document(parent) +-- Find the root node of the tree `node` is a member of, excluding the root +-- 'source' document. +function M.find_local_root(node) + local current = node + while true do + local next = current:parent() + if not next or next:type() == "source" then + break end + current = next end - - -- As soon as we don't have a grandparent or parent, return the node - -- we're on because it means we're one step below the top level document node. - return node + return current end return M diff --git a/lua/nvim-paredit/utils/ts.lua b/lua/nvim-paredit/utils/ts.lua new file mode 100644 index 0000000..09cba9a --- /dev/null +++ b/lua/nvim-paredit/utils/ts.lua @@ -0,0 +1,37 @@ +local traversal = require("nvim-paredit.utils.traversal") + +local M = {} + +function M.find_pairwise_nodes(target_node, opts) + local root_node = traversal.find_local_root(target_node) + + local bufnr = vim.api.nvim_get_current_buf() + local lang = vim.treesitter.language.get_lang(vim.bo.filetype) + + local query = vim.treesitter.query.get(lang, "paredit/pairwise") + if not query then + return + end + + local captures = query:iter_captures(root_node, bufnr) + local pairwise_nodes = {} + local found = false + for id, node in captures do + if query.captures[id] == "pair" then + if not node:extra() and not opts.lang.node_is_comment(node) then + table.insert(pairwise_nodes, node) + if node:equal(target_node) then + found = true + end + end + end + end + + if not found then + return + end + + return pairwise_nodes +end + +return M diff --git a/queries/clojure/paredit/pairwise.scm b/queries/clojure/paredit/pairwise.scm new file mode 100644 index 0000000..78a71a2 --- /dev/null +++ b/queries/clojure/paredit/pairwise.scm @@ -0,0 +1,26 @@ +(list_lit + (sym_lit) @fn-name + (vec_lit + (_) @pair) + (#any-of? @fn-name "let" "loop" "binding" "with-open" "with-redefs")) + +(map_lit + (_) @pair) + +(list_lit + (sym_lit) @fn-name + (_) + (_) @pair + (#eq? @fn-name "case")) + +(list_lit + (sym_lit) @fn-name + (_) @pair + (#eq? @fn-name "cond")) + +(list_lit + (sym_lit) @fn-name + (_) + (_) + (_) @pair + (#eq? @fn-name "condp")) diff --git a/tests/nvim-paredit/pair_drag_spec.lua b/tests/nvim-paredit/pair_drag_spec.lua new file mode 100644 index 0000000..139f7df --- /dev/null +++ b/tests/nvim-paredit/pair_drag_spec.lua @@ -0,0 +1,86 @@ +local paredit = require("nvim-paredit.api") + +local prepare_buffer = require("tests.nvim-paredit.utils").prepare_buffer +local expect_all = require("tests.nvim-paredit.utils").expect_all +local expect = require("tests.nvim-paredit.utils").expect + +describe("paired-element-auto-dragging", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + it("should drag map pairs forward", function() + prepare_buffer({ + content = "{:a 1 :b 2}", + cursor = { 1, 1 }, + }) + + paredit.drag_element_forwards({ + dragging = { + auto_drag_pairs = true, + }, + }) + expect({ + content = "{:b 2 :a 1}", + cursor = { 1, 6 }, + }) + end) + + it("should drag map pairs backwards", function() + prepare_buffer({ + content = "{:a 1 :b 2}", + cursor = { 1, 9 }, + }) + + paredit.drag_element_backwards({ + dragging = { + auto_drag_pairs = true, + }, + }) + expect({ + content = "{:b 2 :a 1}", + cursor = { 1, 1 }, + }) + end) + + it("should detect various types", function() + expect_all(function() + paredit.drag_element_forwards({ dragging = { auto_drag_pairs = true } }) + end, { + { + "let binding", + before_content = "(let [a b c d])", + before_cursor = { 1, 6 }, + after_content = "(let [c d a b])", + after_cursor = { 1, 10 }, + }, + { + "loop binding", + before_content = "(loop [a b c d])", + before_cursor = { 1, 7 }, + after_content = "(loop [c d a b])", + after_cursor = { 1, 11 }, + }, + { + "case", + before_content = "(case a :a 1 :b 2)", + before_cursor = { 1, 8 }, + after_content = "(case a :b 2 :a 1)", + after_cursor = { 1, 13 }, + }, + }) + end) +end) + +describe("paired-element-dragging", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + it("should drag vector elements forwards", function() + prepare_buffer({ + content = "'[a b c d]", + cursor = { 1, 2 }, + }) + + paredit.drag_pair_forwards() + expect({ + content = "'[c d a b]", + cursor = { 1, 6 }, + }) + end) +end)