diff --git a/README.md b/README.md index c8b79ec..3d4a2af 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,13 @@ paredit.setup({ -- this case it will place the cursor on the moved edge cursor_behaviour = "auto", -- remain, follow, auto + dragging = { + -- If set to `true` paredit will attempt to infer if an element being + -- dragged is part of a 'paired' form like as a map. If so then the element + -- will be dragged along with it's pair. + auto_drag_pairs = true, + }, + indent = { -- This controls how nvim-paredit handles indentation when performing operations which -- should change the indentation of the form (such as when slurping or barfing). @@ -88,6 +95,9 @@ paredit.setup({ [">e"] = { paredit.api.drag_element_forwards, "Drag element right" }, ["p"] = { api.drag_pair_forwards, "Drag element pairs right" }, + ["f"] = { paredit.api.drag_form_forwards, "Drag form right" }, [" +## Pairwise Dragging + +Nvim-paredit has support for dragging elements pairwise. If an element being dragged is within a form that contains +pairs of elements (such as a clojure `map`) then the element will be dragged along with it's pair. + +For example: + +```clojure +{:a 1 + |:b 2} +;; Drag backwards +{|:b 2 + :a 1} +``` + +This is enabled by default and can be disabled by setting `dragging.auto_drag_pairs = false`. + +Pairwise dragging works using treesitter queries to identify element pairs within some localized node. This means you +can very easily extend the paredit pairwise implementation by simply adding new treesitter queries to your nvim +configuration. + +You might want to extend if: + +1) You are a language extension author and want to add pairwise dragging support to your extension. +2) You want to add support for some syntax not supported by nvim-paredit. + +This is especially useful if you have your own clojure macros that you want to enable pairwise dragging on. + +All you need to do to extend is to add a new file called `queries//paredit/pairwise.scm` in your nvim config +directory. Make sure to include the `;; extends` directive to the file or you will overwrite any pre-existing queries +defined by nvim-paredit or other language extensions. + +As an example if you want to add support for the following clojure macro: + +```clojure +(defmacro my-custom-bindings [bindings & body] + ...) + +(my-custom-bindings [a 1 + b 2] + (println a b)) +``` + +You can add the following TS query + +```scm +;; extends + +(list_lit + (sym_lit) @fn-name + (vec_lit + (_) @pair) + (#eq? @fn-name "my-custom-bindings")) +``` + ## Language Support As this is built using Treesitter it requires that you have the relevant Treesitter grammar installed for your language @@ -332,6 +397,8 @@ paredit.api.slurp_forwards() - **`barf_backwards`** - **`drag_element_forwards`** - **`drag_element_backwards`** +- **`drag_pair_forwards`** +- **`drag_pair_backwards`** - **`drag_form_forwards`** - **`drag_form_backwards`** - **`raise_element`** diff --git a/lua/nvim-paredit/api/dragging.lua b/lua/nvim-paredit/api/dragging.lua index 39d1ea3..6409c8c 100644 --- a/lua/nvim-paredit/api/dragging.lua +++ b/lua/nvim-paredit/api/dragging.lua @@ -1,5 +1,8 @@ local traversal = require("nvim-paredit.utils.traversal") +local common = require("nvim-paredit.utils.common") +local ts_utils = require("nvim-paredit.utils.ts") local ts = require("nvim-treesitter.ts_utils") +local config = require("nvim-paredit.config") local langs = require("nvim-paredit.lang") local M = {} @@ -44,24 +47,79 @@ function M.drag_form_backwards() ts.swap_nodes(root, sibling, buf, true) end -function M.drag_element_forwards() - local lang = langs.get_language_api() - local current_node = lang.get_node_root(ts.get_node_at_cursor()) +local function find_current_pair(pairs, current_node) + for i, pair in ipairs(pairs) do + for _, node in ipairs(pair) do + if node:equal(current_node) then + return i, pair + end + end + end +end - local sibling = current_node:next_named_sibling() - if not sibling then +local function drag_node_in_pair(current_node, nodes, opts) + local direction = 1 + if opts.reversed then + direction = -1 + end + + local pairs = common.chunk_table(nodes, 2) + local chunk_index, pair = find_current_pair(pairs, current_node) + + local corresponding_pair = pairs[chunk_index + direction] + if not corresponding_pair then return end local buf = vim.api.nvim_get_current_buf() - ts.swap_nodes(current_node, sibling, buf, true) + if pair[2] and corresponding_pair[2] then + ts.swap_nodes(pair[2], corresponding_pair[2], buf, true) + end + if pair[1] and corresponding_pair[1] then + ts.swap_nodes(pair[1], corresponding_pair[1], buf, true) + end end -function M.drag_element_backwards() +local function drag_pair(opts) local lang = langs.get_language_api() local current_node = lang.get_node_root(ts.get_node_at_cursor()) + if not current_node then + return + end + + local pairwise_nodes = ts_utils.find_pairwise_nodes(current_node, opts) + if not pairwise_nodes then + local parent = current_node:parent() + if not parent then + return + end + + pairwise_nodes = traversal.get_children_ignoring_comments(parent, { + lang = lang, + }) + end + + drag_node_in_pair(current_node, pairwise_nodes, opts) +end + +local function drag_element(opts) + local lang = langs.get_language_api() + local current_node = lang.get_node_root(ts.get_node_at_cursor()) + + if opts.dragging.auto_drag_pairs then + local pairwise_nodes = ts_utils.find_pairwise_nodes(current_node, { lang = lang }) + if pairwise_nodes then + return drag_node_in_pair(current_node, pairwise_nodes, opts) + end + end + + local sibling + if opts.reversed then + sibling = current_node:prev_named_sibling() + else + sibling = current_node:next_named_sibling() + end - local sibling = current_node:prev_named_sibling() if not sibling then return end @@ -70,4 +128,44 @@ function M.drag_element_backwards() ts.swap_nodes(current_node, sibling, buf, true) end +function M.drag_element_forwards(opts) + local drag_opts = vim.tbl_deep_extend( + "force", + { + dragging = config.config.dragging or {}, + }, + opts or {}, + { + reversed = false, + } + ) + drag_element(drag_opts) +end + +function M.drag_element_backwards(opts) + local drag_opts = vim.tbl_deep_extend( + "force", + { + dragging = config.config.dragging or {}, + }, + opts or {}, + { + reversed = true, + } + ) + drag_element(drag_opts) +end + +function M.drag_pair_forwards() + drag_pair({ + reversed = false, + }) +end + +function M.drag_pair_backwards() + drag_pair({ + reversed = true, + }) +end + return M diff --git a/lua/nvim-paredit/api/init.lua b/lua/nvim-paredit/api/init.lua index df573fa..efac43e 100644 --- a/lua/nvim-paredit/api/init.lua +++ b/lua/nvim-paredit/api/init.lua @@ -15,6 +15,10 @@ local M = { drag_element_forwards = dragging.drag_element_forwards, drag_element_backwards = dragging.drag_element_backwards, + + drag_pair_forwards = dragging.drag_pair_forwards, + drag_pair_backwards = dragging.drag_pair_backwards, + drag_form_forwards = dragging.drag_form_forwards, drag_form_backwards = dragging.drag_form_backwards, diff --git a/lua/nvim-paredit/api/selections.lua b/lua/nvim-paredit/api/selections.lua index 1531801..3f9f058 100644 --- a/lua/nvim-paredit/api/selections.lua +++ b/lua/nvim-paredit/api/selections.lua @@ -41,7 +41,7 @@ function M.get_range_around_form() end function M.get_range_around_top_level_form() - return get_range_around_form_impl(traversal.get_top_level_node_below_document) + return get_range_around_form_impl(traversal.find_local_root) end local function select_around_form_impl(range) @@ -93,7 +93,7 @@ function M.get_range_in_form() end function M.get_range_in_top_level_form() - return get_range_in_form_impl(traversal.get_top_level_node_below_document) + return get_range_in_form_impl(traversal.find_local_root) end local function select_in_form_impl(range) diff --git a/lua/nvim-paredit/defaults.lua b/lua/nvim-paredit/defaults.lua index 03eb727..cf8caa1 100644 --- a/lua/nvim-paredit/defaults.lua +++ b/lua/nvim-paredit/defaults.lua @@ -4,7 +4,7 @@ local unwrap = require("nvim-paredit.api.unwrap") local M = {} M.default_keys = { - ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp", }, + ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp" }, [">)"] = { api.slurp_forwards, "Slurp forwards" }, [">("] = { api.barf_backwards, "Barf backwards" }, @@ -15,6 +15,9 @@ M.default_keys = { [">e"] = { api.drag_element_forwards, "Drag element right" }, ["p"] = { api.drag_pair_forwards, "Drag element pairs right" }, + ["f"] = { api.drag_form_forwards, "Drag form right" }, [" function M.get_language_api() for l in string.gmatch(vim.bo.filetype, "[^.]+") do if langs[l] ~= nil then return langs[l] end end - return nil + error("Could not find language extension for filetype " .. vim.bo.filetype, vim.log.levels.ERROR) end function M.add_language_extension(filetype, api) diff --git a/lua/nvim-paredit/utils/common.lua b/lua/nvim-paredit/utils/common.lua index 93b7c22..68c9825 100644 --- a/lua/nvim-paredit/utils/common.lua +++ b/lua/nvim-paredit/utils/common.lua @@ -9,6 +9,20 @@ function M.included_in_table(table, item) return false end +function M.chunk_table(tbl, chunk_size) + local result = {} + for i = 1, #tbl, chunk_size do + local chunk = {} + for j = 0, chunk_size - 1 do + if tbl[i + j] then + table.insert(chunk, tbl[i + j]) + end + end + table.insert(result, chunk) + end + return result +end + -- Compares the two given { col, row } position tuples and returns -1/0/1 depending -- on whether `a` is less than, equal to or greater than `b` -- diff --git a/lua/nvim-paredit/utils/traversal.lua b/lua/nvim-paredit/utils/traversal.lua index b8857ed..72d89d7 100644 --- a/lua/nvim-paredit/utils/traversal.lua +++ b/lua/nvim-paredit/utils/traversal.lua @@ -16,6 +16,22 @@ function M.find_nearest_form(current_node, opts) end end +function M.get_children_ignoring_comments(node, opts) + local children = {} + + local index = 0 + local child = node:named_child(index) + while child do + if not child:extra() and not opts.lang.node_is_comment(child) then + table.insert(children, child) + end + index = index + 1 + child = node:named_child(index) + end + + return children +end + local function get_child_ignoring_comments(node, index, opts) if index < 0 or index >= node:named_child_count() then return @@ -139,34 +155,18 @@ function M.find_root_element_relative_to(root, child) return M.find_root_element_relative_to(root, parent) end -function M.get_top_level_node_below_document(node) - -- Document - -- - Branch A - -- -- Node X - -- --- Sub-node 1 - -- - Branch B - -- -- Node Y - -- --- Sub-node 2 - -- --- Sub-node 3 - - -- If we call this function on "Sub-node 2" we expect "Branch B" to be - -- returned, the top level one below the document itself. We know which - -- node is the document because it lacks a parent, just like Batman. - - local parent = node:parent() - - -- Does the node have a parent? If so, we might be at the right level. - -- If not, we should just return the node right away, we're already too high. - if parent then - -- If the parent _also_ has a parent then we still need to go higher, recur. - if parent:parent() then - return M.get_top_level_node_below_document(parent) +-- Find the root node of the tree `node` is a member of, excluding the root +-- 'source' document. +function M.find_local_root(node) + local current = node + while true do + local next = current:parent() + if not next or next:type() == "source" then + break end + current = next end - - -- As soon as we don't have a grandparent or parent, return the node - -- we're on because it means we're one step below the top level document node. - return node + return current end return M diff --git a/lua/nvim-paredit/utils/ts.lua b/lua/nvim-paredit/utils/ts.lua new file mode 100644 index 0000000..606e4ad --- /dev/null +++ b/lua/nvim-paredit/utils/ts.lua @@ -0,0 +1,42 @@ +local traversal = require("nvim-paredit.utils.traversal") + +local M = {} + +-- Use a 'paredit/pairwise' treesitter query to find all nodes within a local +-- branch that are labeled as @pair. +-- +-- If any of these labeled nodes match the given target node then return all +-- matched nodes. +function M.find_pairwise_nodes(target_node, opts) + local root_node = traversal.find_local_root(target_node) + + local bufnr = vim.api.nvim_get_current_buf() + local lang = vim.treesitter.language.get_lang(vim.bo.filetype) + + local query = vim.treesitter.query.get(lang, "paredit/pairwise") + if not query then + return + end + + local captures = query:iter_captures(root_node, bufnr) + local pairwise_nodes = {} + local found = false + for id, node in captures do + if query.captures[id] == "pair" then + if not node:extra() and not opts.lang.node_is_comment(node) then + table.insert(pairwise_nodes, node) + if node:equal(target_node) then + found = true + end + end + end + end + + if not found then + return + end + + return pairwise_nodes +end + +return M diff --git a/queries/clojure/paredit/pairwise.scm b/queries/clojure/paredit/pairwise.scm new file mode 100644 index 0000000..78a71a2 --- /dev/null +++ b/queries/clojure/paredit/pairwise.scm @@ -0,0 +1,26 @@ +(list_lit + (sym_lit) @fn-name + (vec_lit + (_) @pair) + (#any-of? @fn-name "let" "loop" "binding" "with-open" "with-redefs")) + +(map_lit + (_) @pair) + +(list_lit + (sym_lit) @fn-name + (_) + (_) @pair + (#eq? @fn-name "case")) + +(list_lit + (sym_lit) @fn-name + (_) @pair + (#eq? @fn-name "cond")) + +(list_lit + (sym_lit) @fn-name + (_) + (_) + (_) @pair + (#eq? @fn-name "condp")) diff --git a/tests/nvim-paredit/pair_drag_spec.lua b/tests/nvim-paredit/pair_drag_spec.lua new file mode 100644 index 0000000..139f7df --- /dev/null +++ b/tests/nvim-paredit/pair_drag_spec.lua @@ -0,0 +1,86 @@ +local paredit = require("nvim-paredit.api") + +local prepare_buffer = require("tests.nvim-paredit.utils").prepare_buffer +local expect_all = require("tests.nvim-paredit.utils").expect_all +local expect = require("tests.nvim-paredit.utils").expect + +describe("paired-element-auto-dragging", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + it("should drag map pairs forward", function() + prepare_buffer({ + content = "{:a 1 :b 2}", + cursor = { 1, 1 }, + }) + + paredit.drag_element_forwards({ + dragging = { + auto_drag_pairs = true, + }, + }) + expect({ + content = "{:b 2 :a 1}", + cursor = { 1, 6 }, + }) + end) + + it("should drag map pairs backwards", function() + prepare_buffer({ + content = "{:a 1 :b 2}", + cursor = { 1, 9 }, + }) + + paredit.drag_element_backwards({ + dragging = { + auto_drag_pairs = true, + }, + }) + expect({ + content = "{:b 2 :a 1}", + cursor = { 1, 1 }, + }) + end) + + it("should detect various types", function() + expect_all(function() + paredit.drag_element_forwards({ dragging = { auto_drag_pairs = true } }) + end, { + { + "let binding", + before_content = "(let [a b c d])", + before_cursor = { 1, 6 }, + after_content = "(let [c d a b])", + after_cursor = { 1, 10 }, + }, + { + "loop binding", + before_content = "(loop [a b c d])", + before_cursor = { 1, 7 }, + after_content = "(loop [c d a b])", + after_cursor = { 1, 11 }, + }, + { + "case", + before_content = "(case a :a 1 :b 2)", + before_cursor = { 1, 8 }, + after_content = "(case a :b 2 :a 1)", + after_cursor = { 1, 13 }, + }, + }) + end) +end) + +describe("paired-element-dragging", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + it("should drag vector elements forwards", function() + prepare_buffer({ + content = "'[a b c d]", + cursor = { 1, 2 }, + }) + + paredit.drag_pair_forwards() + expect({ + content = "'[c d a b]", + cursor = { 1, 6 }, + }) + end) +end)