Skip to content

Commit

Permalink
Implement pairwise element dragging
Browse files Browse the repository at this point in the history
This adds two new APIs - `drag_pair_forwards` and its companion
`drag_pair_backwards` which allow dragging 'pairs' of elements within a
form.

An easy example is the key-value pairs within a map which typically are
meant to stay together when reordering.

This can, however, be used for any form. This means if you have some
vector containing logical pairs (because of the actual semantics of your
code) you can use a dedicated keybinding to drag them around together.

This change also introduces a new config option at:

`dragging.auto_drag_pairs = true|false`

which will alter the behaviour of the existing `drag_element_forwards`
and `drag_element_backwards` APIs in order to try infer whether they are
contained within a node that is made up of pairs.

For example if this setting is `true` and a `drag_element_forwards` is
used on the keys of a map then they will be dragged pairwise.

This new config option defaults to `true` under the assumption that this
is generally the desired behaviour.
  • Loading branch information
julienvincent committed Oct 12, 2024
1 parent de1c08f commit 6bf71ba
Show file tree
Hide file tree
Showing 12 changed files with 329 additions and 40 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ paredit.setup({
-- this case it will place the cursor on the moved edge
cursor_behaviour = "auto", -- remain, follow, auto

dragging = {
-- If set to `true` paredit will attempt to infer if an element being
-- dragged is part of a 'paired' form like as a map. If so then the element
-- will be dragged along with it's pair.
auto_drag_pairs = true,
},

indent = {
-- This controls how nvim-paredit handles indentation when performing operations which
-- should change the indentation of the form (such as when slurping or barfing).
Expand All @@ -88,6 +95,9 @@ paredit.setup({
[">e"] = { paredit.api.drag_element_forwards, "Drag element right" },
["<e"] = { paredit.api.drag_element_backwards, "Drag element left" },

[">p"] = { api.drag_pair_forwards, "Drag element pairs right" },
["<p"] = { api.drag_pair_backwards, "Drag element pairs left" },

[">f"] = { paredit.api.drag_form_forwards, "Drag form right" },
["<f"] = { paredit.api.drag_form_backwards, "Drag form left" },

Expand Down Expand Up @@ -332,6 +342,8 @@ paredit.api.slurp_forwards()
- **`barf_backwards`**
- **`drag_element_forwards`**
- **`drag_element_backwards`**
- **`drag_pair_forwards`**
- **`drag_pair_backwards`**
- **`drag_form_forwards`**
- **`drag_form_backwards`**
- **`raise_element`**
Expand Down
114 changes: 106 additions & 8 deletions lua/nvim-paredit/api/dragging.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
local traversal = require("nvim-paredit.utils.traversal")
local common = require("nvim-paredit.utils.common")
local ts_utils = require("nvim-paredit.utils.ts")
local ts = require("nvim-treesitter.ts_utils")
local config = require("nvim-paredit.config")
local langs = require("nvim-paredit.lang")

local M = {}
Expand Down Expand Up @@ -44,24 +47,79 @@ function M.drag_form_backwards()
ts.swap_nodes(root, sibling, buf, true)
end

function M.drag_element_forwards()
local lang = langs.get_language_api()
local current_node = lang.get_node_root(ts.get_node_at_cursor())
local function find_current_pair(pairs, current_node)
for i, pair in ipairs(pairs) do
for _, node in ipairs(pair) do
if node:equal(current_node) then
return i, pair
end
end
end
end

local sibling = current_node:next_named_sibling()
if not sibling then
local function drag_node_in_pair(current_node, nodes, opts)
local direction = 1
if opts.reversed then
direction = -1
end

local pairs = common.chunk_table(nodes, 2)
local chunk_index, pair = find_current_pair(pairs, current_node)

local corresponding_pair = pairs[chunk_index + direction]
if not corresponding_pair then
return
end

local buf = vim.api.nvim_get_current_buf()
ts.swap_nodes(current_node, sibling, buf, true)
if pair[2] and corresponding_pair[2] then
ts.swap_nodes(pair[2], corresponding_pair[2], buf, true)
end
if pair[1] and corresponding_pair[1] then
ts.swap_nodes(pair[1], corresponding_pair[1], buf, true)
end
end

function M.drag_element_backwards()
local function drag_pair(opts)
local lang = langs.get_language_api()
local current_node = lang.get_node_root(ts.get_node_at_cursor())
if not current_node then
return
end

local pairwise_nodes = ts_utils.find_pairwise_nodes(current_node, opts)
if not pairwise_nodes then
local parent = current_node:parent()
if not parent then
return
end

pairwise_nodes = traversal.get_children_ignoring_comments(parent, {
lang = lang,
})
end

drag_node_in_pair(current_node, pairwise_nodes, opts)
end

local function drag_element(opts)
local lang = langs.get_language_api()
local current_node = lang.get_node_root(ts.get_node_at_cursor())

if opts.dragging.auto_drag_pairs then
local pairwise_nodes = ts_utils.find_pairwise_nodes(current_node, { lang = lang })
if pairwise_nodes then
return drag_node_in_pair(current_node, pairwise_nodes, opts)
end
end

local sibling
if opts.reversed then
sibling = current_node:prev_named_sibling()
else
sibling = current_node:next_named_sibling()
end

local sibling = current_node:prev_named_sibling()
if not sibling then
return
end
Expand All @@ -70,4 +128,44 @@ function M.drag_element_backwards()
ts.swap_nodes(current_node, sibling, buf, true)
end

function M.drag_element_forwards(opts)
local drag_opts = vim.tbl_deep_extend(
"force",
{
dragging = config.config.dragging or {},
},
opts or {},
{
reversed = false,
}
)
drag_element(drag_opts)
end

function M.drag_element_backwards(opts)
local drag_opts = vim.tbl_deep_extend(
"force",
{
dragging = config.config.dragging or {},
},
opts or {},
{
reversed = true,
}
)
drag_element(drag_opts)
end

function M.drag_pair_forwards()
drag_pair({
reversed = false,
})
end

function M.drag_pair_backwards()
drag_pair({
reversed = true,
})
end

return M
4 changes: 4 additions & 0 deletions lua/nvim-paredit/api/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ local M = {

drag_element_forwards = dragging.drag_element_forwards,
drag_element_backwards = dragging.drag_element_backwards,

drag_pair_forwards = dragging.drag_pair_forwards,
drag_pair_backwards = dragging.drag_pair_backwards,

drag_form_forwards = dragging.drag_form_forwards,
drag_form_backwards = dragging.drag_form_backwards,

Expand Down
4 changes: 2 additions & 2 deletions lua/nvim-paredit/api/selections.lua
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function M.get_range_around_form()
end

function M.get_range_around_top_level_form()
return get_range_around_form_impl(traversal.get_top_level_node_below_document)
return get_range_around_form_impl(traversal.find_local_root)
end

local function select_around_form_impl(range)
Expand Down Expand Up @@ -93,7 +93,7 @@ function M.get_range_in_form()
end

function M.get_range_in_top_level_form()
return get_range_in_form_impl(traversal.get_top_level_node_below_document)
return get_range_in_form_impl(traversal.find_local_root)
end

local function select_in_form_impl(range)
Expand Down
11 changes: 10 additions & 1 deletion lua/nvim-paredit/defaults.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ local unwrap = require("nvim-paredit.api.unwrap")
local M = {}

M.default_keys = {
["<localleader>@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp", },
["<localleader>@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp" },

[">)"] = { api.slurp_forwards, "Slurp forwards" },
[">("] = { api.barf_backwards, "Barf backwards" },
Expand All @@ -15,6 +15,9 @@ M.default_keys = {
[">e"] = { api.drag_element_forwards, "Drag element right" },
["<e"] = { api.drag_element_backwards, "Drag element left" },

[">p"] = { api.drag_pair_forwards, "Drag element pairs right" },
["<p"] = { api.drag_pair_backwards, "Drag element pairs left" },

[">f"] = { api.drag_form_forwards, "Drag form right" },
["<f"] = { api.drag_form_backwards, "Drag form left" },

Expand Down Expand Up @@ -107,6 +110,12 @@ M.default_keys = {
M.defaults = {
use_default_keys = true,
cursor_behaviour = "auto", -- remain, follow, auto
dragging = {
-- If set to `true` paredit will attempt to infer if an element being
-- dragged is part of a 'paired' form like as a map. If so then the element
-- will be dragged along with it's pair.
auto_drag_pairs = true,
},
indent = {
enabled = false,
indentor = require("nvim-paredit.indentation.native").indentor,
Expand Down
4 changes: 4 additions & 0 deletions lua/nvim-paredit/lang/clojure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,23 @@ function M.get_form_edges(node)
local left_bracket_range = { form:field("open")[1]:range() }
local right_bracket_range = { form:field("close")[1]:range() }

-- stylua: ignore
local left_range = {
outer_range[1], outer_range[2],
left_bracket_range[3], left_bracket_range[4]
}
-- stylua: ignore
local right_range = {
right_bracket_range[1], right_bracket_range[2],
outer_range[3], outer_range[4],
}

-- stylua: ignore
local left_text = vim.api.nvim_buf_get_text(0,
left_range[1], left_range[2],
left_range[3], left_range[4],
{})
-- stylua: ignore
local right_text = vim.api.nvim_buf_get_text(0,
right_range[1], right_range[2],
right_range[3], right_range[4],
Expand Down
5 changes: 2 additions & 3 deletions lua/nvim-paredit/lang/init.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
local common = require("nvim-paredit.utils.common")

local langs = {
clojure = require("nvim-paredit.lang.clojure"),
}
Expand All @@ -14,13 +12,14 @@ local function keys(tbl)
return result
end

--- @return table<string, function>
function M.get_language_api()
for l in string.gmatch(vim.bo.filetype, "[^.]+") do
if langs[l] ~= nil then
return langs[l]
end
end
return nil
error("Could not find language extension for filetype " .. vim.bo.filetype, vim.log.levels.ERROR)
end

function M.add_language_extension(filetype, api)
Expand Down
14 changes: 14 additions & 0 deletions lua/nvim-paredit/utils/common.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ function M.included_in_table(table, item)
return false
end

function M.chunk_table(tbl, chunk_size)
local result = {}
for i = 1, #tbl, chunk_size do
local chunk = {}
for j = 0, chunk_size - 1 do
if tbl[i + j] then
table.insert(chunk, tbl[i + j])
end
end
table.insert(result, chunk)
end
return result
end

-- Compares the two given { col, row } position tuples and returns -1/0/1 depending
-- on whether `a` is less than, equal to or greater than `b`
--
Expand Down
52 changes: 26 additions & 26 deletions lua/nvim-paredit/utils/traversal.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ function M.find_nearest_form(current_node, opts)
end
end

function M.get_children_ignoring_comments(node, opts)
local children = {}

local index = 0
local child = node:named_child(index)
while child do
if not child:extra() and not opts.lang.node_is_comment(child) then
table.insert(children, child)
end
index = index + 1
child = node:named_child(index)
end

return children
end

local function get_child_ignoring_comments(node, index, opts)
if index < 0 or index >= node:named_child_count() then
return
Expand Down Expand Up @@ -139,34 +155,18 @@ function M.find_root_element_relative_to(root, child)
return M.find_root_element_relative_to(root, parent)
end

function M.get_top_level_node_below_document(node)
-- Document
-- - Branch A
-- -- Node X
-- --- Sub-node 1
-- - Branch B
-- -- Node Y
-- --- Sub-node 2
-- --- Sub-node 3

-- If we call this function on "Sub-node 2" we expect "Branch B" to be
-- returned, the top level one below the document itself. We know which
-- node is the document because it lacks a parent, just like Batman.

local parent = node:parent()

-- Does the node have a parent? If so, we might be at the right level.
-- If not, we should just return the node right away, we're already too high.
if parent then
-- If the parent _also_ has a parent then we still need to go higher, recur.
if parent:parent() then
return M.get_top_level_node_below_document(parent)
-- Find the root node of the tree `node` is a member of, excluding the root
-- 'source' document.
function M.find_local_root(node)
local current = node
while true do
local next = current:parent()
if not next or next:type() == "source" then
break
end
current = next
end

-- As soon as we don't have a grandparent or parent, return the node
-- we're on because it means we're one step below the top level document node.
return node
return current
end

return M
Loading

0 comments on commit 6bf71ba

Please sign in to comment.