From cf1b77fc8f3bfd9928f5117ab504e6189c6ce8cb Mon Sep 17 00:00:00 2001 From: Julien Vincent Date: Fri, 11 Oct 2024 00:48:07 +0100 Subject: [PATCH] Implement pairwise element dragging This adds two new APIs - `drag_pair_forwards` and its companion `drag_pair_backwards` which allow dragging 'pairs' of elements within a form. An easy example is the key-value pairs within a map which typically are meant to stay together when reordering. This can, however, be used for any form. This means if you have some vector containing logical pairs (because of the actual semantics of your code) you can use a dedicated keybinding to drag them around together. This change also introduces a new config option at: `dragging.auto_drag_pairs = true|false` which will alter the behaviour of the existing `drag_element_forwards` and `drag_element_backwards` APIs in order to try infer whether they are contained within a node that is made up of pairs. For example if this setting is `true` and a `drag_element_forwards` is used on the keys of a map then they will be dragged pairwise. This new config option defaults to `true` under the assumption that this is generally the desired behaviour. --- README.md | 75 +++++++++++++++- lua/nvim-paredit/api/dragging.lua | 119 ++++++++++++++++++++++++-- lua/nvim-paredit/api/init.lua | 4 + lua/nvim-paredit/api/selections.lua | 4 +- lua/nvim-paredit/defaults.lua | 11 ++- lua/nvim-paredit/lang/clojure.lua | 4 + lua/nvim-paredit/lang/init.lua | 5 +- lua/nvim-paredit/utils/common.lua | 14 +++ lua/nvim-paredit/utils/traversal.lua | 52 +++++------ lua/nvim-paredit/utils/ts.lua | 42 +++++++++ queries/clojure/paredit/pairwise.scm | 26 ++++++ tests/nvim-paredit/pair_drag_spec.lua | 86 +++++++++++++++++++ 12 files changed, 398 insertions(+), 44 deletions(-) create mode 100644 lua/nvim-paredit/utils/ts.lua create mode 100644 queries/clojure/paredit/pairwise.scm create mode 100644 tests/nvim-paredit/pair_drag_spec.lua diff --git a/README.md b/README.md index c8b79ec..3d4a2af 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,13 @@ paredit.setup({ -- this case it will place the cursor on the moved edge cursor_behaviour = "auto", -- remain, follow, auto + dragging = { + -- If set to `true` paredit will attempt to infer if an element being + -- dragged is part of a 'paired' form like as a map. If so then the element + -- will be dragged along with it's pair. + auto_drag_pairs = true, + }, + indent = { -- This controls how nvim-paredit handles indentation when performing operations which -- should change the indentation of the form (such as when slurping or barfing). @@ -88,6 +95,9 @@ paredit.setup({ [">e"] = { paredit.api.drag_element_forwards, "Drag element right" }, ["p"] = { api.drag_pair_forwards, "Drag element pairs right" }, + ["f"] = { paredit.api.drag_form_forwards, "Drag form right" }, [" +## Pairwise Dragging + +Nvim-paredit has support for dragging elements pairwise. If an element being dragged is within a form that contains +pairs of elements (such as a clojure `map`) then the element will be dragged along with it's pair. + +For example: + +```clojure +{:a 1 + |:b 2} +;; Drag backwards +{|:b 2 + :a 1} +``` + +This is enabled by default and can be disabled by setting `dragging.auto_drag_pairs = false`. + +Pairwise dragging works using treesitter queries to identify element pairs within some localized node. This means you +can very easily extend the paredit pairwise implementation by simply adding new treesitter queries to your nvim +configuration. + +You might want to extend if: + +1) You are a language extension author and want to add pairwise dragging support to your extension. +2) You want to add support for some syntax not supported by nvim-paredit. + +This is especially useful if you have your own clojure macros that you want to enable pairwise dragging on. + +All you need to do to extend is to add a new file called `queries//paredit/pairwise.scm` in your nvim config +directory. Make sure to include the `;; extends` directive to the file or you will overwrite any pre-existing queries +defined by nvim-paredit or other language extensions. + +As an example if you want to add support for the following clojure macro: + +```clojure +(defmacro my-custom-bindings [bindings & body] + ...) + +(my-custom-bindings [a 1 + b 2] + (println a b)) +``` + +You can add the following TS query + +```scm +;; extends + +(list_lit + (sym_lit) @fn-name + (vec_lit + (_) @pair) + (#eq? @fn-name "my-custom-bindings")) +``` + ## Language Support As this is built using Treesitter it requires that you have the relevant Treesitter grammar installed for your language @@ -332,6 +397,8 @@ paredit.api.slurp_forwards() - **`barf_backwards`** - **`drag_element_forwards`** - **`drag_element_backwards`** +- **`drag_pair_forwards`** +- **`drag_pair_backwards`** - **`drag_form_forwards`** - **`drag_form_backwards`** - **`raise_element`** diff --git a/lua/nvim-paredit/api/dragging.lua b/lua/nvim-paredit/api/dragging.lua index 39d1ea3..1685edb 100644 --- a/lua/nvim-paredit/api/dragging.lua +++ b/lua/nvim-paredit/api/dragging.lua @@ -1,5 +1,8 @@ local traversal = require("nvim-paredit.utils.traversal") +local common = require("nvim-paredit.utils.common") +local ts_utils = require("nvim-paredit.utils.ts") local ts = require("nvim-treesitter.ts_utils") +local config = require("nvim-paredit.config") local langs = require("nvim-paredit.lang") local M = {} @@ -44,24 +47,84 @@ function M.drag_form_backwards() ts.swap_nodes(root, sibling, buf, true) end -function M.drag_element_forwards() - local lang = langs.get_language_api() - local current_node = lang.get_node_root(ts.get_node_at_cursor()) +local function find_current_pair(pairs, current_node) + for i, pair in ipairs(pairs) do + for _, node in ipairs(pair) do + if node:equal(current_node) then + return i, pair + end + end + end +end - local sibling = current_node:next_named_sibling() - if not sibling then +local function drag_node_in_pair(current_node, nodes, opts) + local direction = 1 + if opts.reversed then + direction = -1 + end + + local pairs = common.chunk_table(nodes, 2) + local chunk_index, pair = find_current_pair(pairs, current_node) + + local corresponding_pair = pairs[chunk_index + direction] + if not corresponding_pair then return end local buf = vim.api.nvim_get_current_buf() - ts.swap_nodes(current_node, sibling, buf, true) + if pair[2] and corresponding_pair[2] then + ts.swap_nodes(pair[2], corresponding_pair[2], buf, true) + end + if pair[1] and corresponding_pair[1] then + ts.swap_nodes(pair[1], corresponding_pair[1], buf, true) + end +end + +local function drag_pair(opts) + local lang = langs.get_language_api() + local current_node = lang.get_node_root(ts.get_node_at_cursor()) + if not current_node then + return + end + + local pairwise_nodes = ts_utils.find_pairwise_nodes( + current_node, + vim.tbl_deep_extend("force", opts, { + lang = lang, + }) + ) + if not pairwise_nodes then + local parent = current_node:parent() + if not parent then + return + end + + pairwise_nodes = traversal.get_children_ignoring_comments(parent, { + lang = lang, + }) + end + + drag_node_in_pair(current_node, pairwise_nodes, opts) end -function M.drag_element_backwards() +local function drag_element(opts) local lang = langs.get_language_api() local current_node = lang.get_node_root(ts.get_node_at_cursor()) - local sibling = current_node:prev_named_sibling() + if opts.dragging.auto_drag_pairs then + local pairwise_nodes = ts_utils.find_pairwise_nodes(current_node, { lang = lang }) + if pairwise_nodes then + return drag_node_in_pair(current_node, pairwise_nodes, opts) + end + end + + local sibling + if opts.reversed then + sibling = current_node:prev_named_sibling() + else + sibling = current_node:next_named_sibling() + end + if not sibling then return end @@ -70,4 +133,44 @@ function M.drag_element_backwards() ts.swap_nodes(current_node, sibling, buf, true) end +function M.drag_element_forwards(opts) + local drag_opts = vim.tbl_deep_extend( + "force", + { + dragging = config.config.dragging or {}, + }, + opts or {}, + { + reversed = false, + } + ) + drag_element(drag_opts) +end + +function M.drag_element_backwards(opts) + local drag_opts = vim.tbl_deep_extend( + "force", + { + dragging = config.config.dragging or {}, + }, + opts or {}, + { + reversed = true, + } + ) + drag_element(drag_opts) +end + +function M.drag_pair_forwards() + drag_pair({ + reversed = false, + }) +end + +function M.drag_pair_backwards() + drag_pair({ + reversed = true, + }) +end + return M diff --git a/lua/nvim-paredit/api/init.lua b/lua/nvim-paredit/api/init.lua index df573fa..efac43e 100644 --- a/lua/nvim-paredit/api/init.lua +++ b/lua/nvim-paredit/api/init.lua @@ -15,6 +15,10 @@ local M = { drag_element_forwards = dragging.drag_element_forwards, drag_element_backwards = dragging.drag_element_backwards, + + drag_pair_forwards = dragging.drag_pair_forwards, + drag_pair_backwards = dragging.drag_pair_backwards, + drag_form_forwards = dragging.drag_form_forwards, drag_form_backwards = dragging.drag_form_backwards, diff --git a/lua/nvim-paredit/api/selections.lua b/lua/nvim-paredit/api/selections.lua index 1531801..3f9f058 100644 --- a/lua/nvim-paredit/api/selections.lua +++ b/lua/nvim-paredit/api/selections.lua @@ -41,7 +41,7 @@ function M.get_range_around_form() end function M.get_range_around_top_level_form() - return get_range_around_form_impl(traversal.get_top_level_node_below_document) + return get_range_around_form_impl(traversal.find_local_root) end local function select_around_form_impl(range) @@ -93,7 +93,7 @@ function M.get_range_in_form() end function M.get_range_in_top_level_form() - return get_range_in_form_impl(traversal.get_top_level_node_below_document) + return get_range_in_form_impl(traversal.find_local_root) end local function select_in_form_impl(range) diff --git a/lua/nvim-paredit/defaults.lua b/lua/nvim-paredit/defaults.lua index 03eb727..cf8caa1 100644 --- a/lua/nvim-paredit/defaults.lua +++ b/lua/nvim-paredit/defaults.lua @@ -4,7 +4,7 @@ local unwrap = require("nvim-paredit.api.unwrap") local M = {} M.default_keys = { - ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp", }, + ["@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp" }, [">)"] = { api.slurp_forwards, "Slurp forwards" }, [">("] = { api.barf_backwards, "Barf backwards" }, @@ -15,6 +15,9 @@ M.default_keys = { [">e"] = { api.drag_element_forwards, "Drag element right" }, ["p"] = { api.drag_pair_forwards, "Drag element pairs right" }, + ["f"] = { api.drag_form_forwards, "Drag form right" }, [" function M.get_language_api() for l in string.gmatch(vim.bo.filetype, "[^.]+") do if langs[l] ~= nil then return langs[l] end end - return nil + error("Could not find language extension for filetype " .. vim.bo.filetype, vim.log.levels.ERROR) end function M.add_language_extension(filetype, api) diff --git a/lua/nvim-paredit/utils/common.lua b/lua/nvim-paredit/utils/common.lua index 93b7c22..68c9825 100644 --- a/lua/nvim-paredit/utils/common.lua +++ b/lua/nvim-paredit/utils/common.lua @@ -9,6 +9,20 @@ function M.included_in_table(table, item) return false end +function M.chunk_table(tbl, chunk_size) + local result = {} + for i = 1, #tbl, chunk_size do + local chunk = {} + for j = 0, chunk_size - 1 do + if tbl[i + j] then + table.insert(chunk, tbl[i + j]) + end + end + table.insert(result, chunk) + end + return result +end + -- Compares the two given { col, row } position tuples and returns -1/0/1 depending -- on whether `a` is less than, equal to or greater than `b` -- diff --git a/lua/nvim-paredit/utils/traversal.lua b/lua/nvim-paredit/utils/traversal.lua index b8857ed..72d89d7 100644 --- a/lua/nvim-paredit/utils/traversal.lua +++ b/lua/nvim-paredit/utils/traversal.lua @@ -16,6 +16,22 @@ function M.find_nearest_form(current_node, opts) end end +function M.get_children_ignoring_comments(node, opts) + local children = {} + + local index = 0 + local child = node:named_child(index) + while child do + if not child:extra() and not opts.lang.node_is_comment(child) then + table.insert(children, child) + end + index = index + 1 + child = node:named_child(index) + end + + return children +end + local function get_child_ignoring_comments(node, index, opts) if index < 0 or index >= node:named_child_count() then return @@ -139,34 +155,18 @@ function M.find_root_element_relative_to(root, child) return M.find_root_element_relative_to(root, parent) end -function M.get_top_level_node_below_document(node) - -- Document - -- - Branch A - -- -- Node X - -- --- Sub-node 1 - -- - Branch B - -- -- Node Y - -- --- Sub-node 2 - -- --- Sub-node 3 - - -- If we call this function on "Sub-node 2" we expect "Branch B" to be - -- returned, the top level one below the document itself. We know which - -- node is the document because it lacks a parent, just like Batman. - - local parent = node:parent() - - -- Does the node have a parent? If so, we might be at the right level. - -- If not, we should just return the node right away, we're already too high. - if parent then - -- If the parent _also_ has a parent then we still need to go higher, recur. - if parent:parent() then - return M.get_top_level_node_below_document(parent) +-- Find the root node of the tree `node` is a member of, excluding the root +-- 'source' document. +function M.find_local_root(node) + local current = node + while true do + local next = current:parent() + if not next or next:type() == "source" then + break end + current = next end - - -- As soon as we don't have a grandparent or parent, return the node - -- we're on because it means we're one step below the top level document node. - return node + return current end return M diff --git a/lua/nvim-paredit/utils/ts.lua b/lua/nvim-paredit/utils/ts.lua new file mode 100644 index 0000000..606e4ad --- /dev/null +++ b/lua/nvim-paredit/utils/ts.lua @@ -0,0 +1,42 @@ +local traversal = require("nvim-paredit.utils.traversal") + +local M = {} + +-- Use a 'paredit/pairwise' treesitter query to find all nodes within a local +-- branch that are labeled as @pair. +-- +-- If any of these labeled nodes match the given target node then return all +-- matched nodes. +function M.find_pairwise_nodes(target_node, opts) + local root_node = traversal.find_local_root(target_node) + + local bufnr = vim.api.nvim_get_current_buf() + local lang = vim.treesitter.language.get_lang(vim.bo.filetype) + + local query = vim.treesitter.query.get(lang, "paredit/pairwise") + if not query then + return + end + + local captures = query:iter_captures(root_node, bufnr) + local pairwise_nodes = {} + local found = false + for id, node in captures do + if query.captures[id] == "pair" then + if not node:extra() and not opts.lang.node_is_comment(node) then + table.insert(pairwise_nodes, node) + if node:equal(target_node) then + found = true + end + end + end + end + + if not found then + return + end + + return pairwise_nodes +end + +return M diff --git a/queries/clojure/paredit/pairwise.scm b/queries/clojure/paredit/pairwise.scm new file mode 100644 index 0000000..78a71a2 --- /dev/null +++ b/queries/clojure/paredit/pairwise.scm @@ -0,0 +1,26 @@ +(list_lit + (sym_lit) @fn-name + (vec_lit + (_) @pair) + (#any-of? @fn-name "let" "loop" "binding" "with-open" "with-redefs")) + +(map_lit + (_) @pair) + +(list_lit + (sym_lit) @fn-name + (_) + (_) @pair + (#eq? @fn-name "case")) + +(list_lit + (sym_lit) @fn-name + (_) @pair + (#eq? @fn-name "cond")) + +(list_lit + (sym_lit) @fn-name + (_) + (_) + (_) @pair + (#eq? @fn-name "condp")) diff --git a/tests/nvim-paredit/pair_drag_spec.lua b/tests/nvim-paredit/pair_drag_spec.lua new file mode 100644 index 0000000..139f7df --- /dev/null +++ b/tests/nvim-paredit/pair_drag_spec.lua @@ -0,0 +1,86 @@ +local paredit = require("nvim-paredit.api") + +local prepare_buffer = require("tests.nvim-paredit.utils").prepare_buffer +local expect_all = require("tests.nvim-paredit.utils").expect_all +local expect = require("tests.nvim-paredit.utils").expect + +describe("paired-element-auto-dragging", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + it("should drag map pairs forward", function() + prepare_buffer({ + content = "{:a 1 :b 2}", + cursor = { 1, 1 }, + }) + + paredit.drag_element_forwards({ + dragging = { + auto_drag_pairs = true, + }, + }) + expect({ + content = "{:b 2 :a 1}", + cursor = { 1, 6 }, + }) + end) + + it("should drag map pairs backwards", function() + prepare_buffer({ + content = "{:a 1 :b 2}", + cursor = { 1, 9 }, + }) + + paredit.drag_element_backwards({ + dragging = { + auto_drag_pairs = true, + }, + }) + expect({ + content = "{:b 2 :a 1}", + cursor = { 1, 1 }, + }) + end) + + it("should detect various types", function() + expect_all(function() + paredit.drag_element_forwards({ dragging = { auto_drag_pairs = true } }) + end, { + { + "let binding", + before_content = "(let [a b c d])", + before_cursor = { 1, 6 }, + after_content = "(let [c d a b])", + after_cursor = { 1, 10 }, + }, + { + "loop binding", + before_content = "(loop [a b c d])", + before_cursor = { 1, 7 }, + after_content = "(loop [c d a b])", + after_cursor = { 1, 11 }, + }, + { + "case", + before_content = "(case a :a 1 :b 2)", + before_cursor = { 1, 8 }, + after_content = "(case a :b 2 :a 1)", + after_cursor = { 1, 13 }, + }, + }) + end) +end) + +describe("paired-element-dragging", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + it("should drag vector elements forwards", function() + prepare_buffer({ + content = "'[a b c d]", + cursor = { 1, 2 }, + }) + + paredit.drag_pair_forwards() + expect({ + content = "'[c d a b]", + cursor = { 1, 6 }, + }) + end) +end)