diff --git a/lua/nvim-paredit/api/cursor.lua b/lua/nvim-paredit/api/cursor.lua index 8862d5b..7755183 100644 --- a/lua/nvim-paredit/api/cursor.lua +++ b/lua/nvim-paredit/api/cursor.lua @@ -4,26 +4,40 @@ function M.insert_mode() vim.api.nvim_feedkeys("i", "n", true) end -function M.place_cursor(form, opts) - if not form then +function M.get_cursor_pos(range_or_node, opts) + local range + + if type(range_or_node) == "table" then + range = range_or_node + elseif type(range_or_node) == "userdata" then + range = { range_or_node:range() } + range[4] = range[4] - 1 + end + + if not range then return end - local range = { form:range() } local cursor_pos if opts.placement == "left_edge" then cursor_pos = { range[1] + 1, range[2] } elseif opts.placement == "inner_start" then cursor_pos = { range[1] + 1, range[2] + 1 } - elseif opts.placement == "inned_end" then - cursor_pos = { range[3] + 1, range[4] - 2 } + elseif opts.placement == "inner_end" then + cursor_pos = { range[3] + 1, range[4] } else - cursor_pos = { range[3] + 1, range[4] - 1 } + cursor_pos = { range[3] + 1, range[4] + 1 } end - vim.api.nvim_win_set_cursor(0, cursor_pos) + return cursor_pos +end - if opts.mode == "insert" then - M.insert_mode() +function M.place_cursor(range_or_node, opts) + local cursor_pos = M.get_cursor_pos(range_or_node, opts) + if cursor_pos then + vim.api.nvim_win_set_cursor(0, cursor_pos) + if opts.mode == "insert" then + M.insert_mode() + end end end diff --git a/lua/nvim-paredit/api/wrap.lua b/lua/nvim-paredit/api/wrap.lua index 537e490..f3bf156 100644 --- a/lua/nvim-paredit/api/wrap.lua +++ b/lua/nvim-paredit/api/wrap.lua @@ -45,6 +45,18 @@ function M.wrap_element(buf, element, prefix, suffix) local range = { element:range() } vim.api.nvim_buf_set_text(buf, range[3], range[4], range[3], range[4], { suffix }) vim.api.nvim_buf_set_text(buf, range[1], range[2], range[1], range[2], { prefix }) + local end_col + if (range[1] == range[3]) then + end_col = range[4] + prefix:len() + suffix:len() - 1 + else + end_col = range[4] + suffix:len() - 1 + end + return { + range[1], + range[2], + range[3], + end_col, + } end function M.wrap_element_under_cursor(prefix, suffix) @@ -62,12 +74,7 @@ function M.wrap_element_under_cursor(prefix, suffix) return end - M.wrap_element(buf, current_element, prefix, suffix) - - reparse(buf) - - current_element = lang.get_node_root(ts.get_node_at_cursor()) - return M.find_form(current_element, lang) + return M.wrap_element(buf, current_element, prefix, suffix) end function M.wrap_enclosing_form_under_cursor(prefix, suffix) @@ -79,24 +86,15 @@ function M.wrap_enclosing_form_under_cursor(prefix, suffix) return end - local use_direct_parent = common.is_whitespace_under_cursor(lang) or lang.node_is_comment(ts.get_node_at_cursor()) + local use_direct_parent = common.is_whitespace_under_cursor(lang) + or lang.node_is_comment(ts.get_node_at_cursor()) local form = M.find_form(current_element, lang) if not use_direct_parent and form:type() ~= "source" then form = M.find_parend_form(current_element, lang) end - M.wrap_element(buf, form, prefix, suffix) - - reparse(buf) - - current_element = M.find_element_under_cursor(lang) - if use_direct_parent then - form = current_element - else - form = M.find_parend_form(current_element, lang) - end - return M.find_parend_form(form, lang) + return M.wrap_element(buf, form, prefix, suffix) end return M diff --git a/tests/nvim-paredit/cursor_spec.lua b/tests/nvim-paredit/cursor_spec.lua new file mode 100644 index 0000000..c469ef5 --- /dev/null +++ b/tests/nvim-paredit/cursor_spec.lua @@ -0,0 +1,71 @@ +local paredit = require("nvim-paredit") +local ts = require("nvim-treesitter.ts_utils") +local prepare_buffer = require("tests.nvim-paredit.utils").prepare_buffer + +describe("cursor pos api tests", function() + vim.api.nvim_buf_set_option(0, "filetype", "clojure") + + it("should place cursor inside form at the beginning", function() + prepare_buffer({ + content = { "(a (b))" }, + cursor = { 1, 0 }, + }) + + local cursor_pos = paredit.cursor.get_cursor_pos({ 0, 0, 0, 6 }, { placement = "inner_start" }) + + assert.are.same({ 1, 1 }, cursor_pos) + + local node = ts.get_node_at_cursor() + cursor_pos = paredit.cursor.get_cursor_pos(node, { placement = "inner_start" }) + + assert.are.same({ 1, 1 }, cursor_pos) + end) + + it("should place cursor outside form at the beginning", function() + prepare_buffer({ + content = { "(a (b))" }, + cursor = { 1, 0 }, + }) + + local cursor_pos = paredit.cursor.get_cursor_pos({ 0, 0, 0, 6 }, { placement = "left_edge" }) + + assert.are.same({ 1, 0 }, cursor_pos) + + local node = ts.get_node_at_cursor() + cursor_pos = paredit.cursor.get_cursor_pos(node, { placement = "left_edge" }) + + assert.are.same({ 1, 0 }, cursor_pos) + end) + + it("should place cursor inside form at the end", function() + prepare_buffer({ + content = { "(a ", " (b))" }, + cursor = { 1, 0 }, + }) + + local cursor_pos = paredit.cursor.get_cursor_pos({ 0, 0, 1, 4 }, { placement = "inner_end" }) + + assert.are.same({ 2, 4 }, cursor_pos) + + local node = ts.get_node_at_cursor() + cursor_pos = paredit.cursor.get_cursor_pos(node, { placement = "inner_end" }) + + assert.are.same({ 2, 4 }, cursor_pos) + end) + + it("should place cursor outside form at the end", function() + prepare_buffer({ + content = { "(a ", " (b))" }, + cursor = { 1, 0 }, + }) + + local cursor_pos = paredit.cursor.get_cursor_pos({ 0, 0, 1, 4 }, { placement = "right_edge" }) + + assert.are.same({ 2, 5 }, cursor_pos) + + local node = ts.get_node_at_cursor() + cursor_pos = paredit.cursor.get_cursor_pos(node, { placement = "right_edge" }) + + assert.are.same({ 2, 5 }, cursor_pos) + end) +end) diff --git a/tests/nvim-paredit/form_and_element_wrap_spec.lua b/tests/nvim-paredit/form_and_element_wrap_spec.lua index 95a6d05..b9cb10a 100644 --- a/tests/nvim-paredit/form_and_element_wrap_spec.lua +++ b/tests/nvim-paredit/form_and_element_wrap_spec.lua @@ -11,7 +11,8 @@ describe("element and form wrap", function() cursor = { 1, 4 }, }) - paredit.wrap.wrap_element_under_cursor("(", ")") + local range = paredit.wrap.wrap_element_under_cursor("(", ")") + assert.falsy(range) expect({ content = { "(+ 2 :foo/bar)" }, }) @@ -23,7 +24,8 @@ describe("element and form wrap", function() cursor = { 1, 7 }, }) - paredit.wrap.wrap_element_under_cursor("(", ")") + local range = paredit.wrap.wrap_element_under_cursor("(", ")") + assert.are.same({ 0, 5, 0, 14 }, range) expect({ content = { "(+ 2 (:foo/bar))" }, }) @@ -35,7 +37,8 @@ describe("element and form wrap", function() cursor = { 1, 0 }, }) - paredit.wrap.wrap_element_under_cursor("(", ")") + local range = paredit.wrap.wrap_element_under_cursor("(", ")") + assert.are.same({ 0, 0, 0, 15 }, range) expect({ content = { "((+ 2 :foo/bar))" }, }) @@ -43,13 +46,14 @@ describe("element and form wrap", function() it("should wrap namespaced keyword", function() prepare_buffer({ - content = { '(+ 2 "lol")' }, + content = { "(+ 2 :foo/lol)" }, cursor = { 1, 7 }, }) - paredit.wrap.wrap_element_under_cursor("(", ")") + local range = paredit.wrap.wrap_element_under_cursor("(", ")") + assert.are.same({ 0, 5, 0, 14 }, range) expect({ - content = { '(+ 2 ("lol"))' }, + content = { "(+ 2 (:foo/lol))" }, }) end) @@ -62,7 +66,8 @@ describe("element and form wrap", function() cursor = { 2, 4 }, }) - paredit.wrap.wrap_enclosing_form_under_cursor("(", ")") + local range = paredit.wrap.wrap_enclosing_form_under_cursor("(", ")") + assert.are.same({ 0, 0, 1, 10 }, range) expect({ content = { "((+ 2", @@ -77,7 +82,8 @@ describe("element and form wrap", function() cursor = { 1, 0 }, }) - paredit.wrap.wrap_enclosing_form_under_cursor("(", ")") + local range = paredit.wrap.wrap_enclosing_form_under_cursor("(", ")") + assert.are.same({ 0, 0, 0, 15 }, range) expect({ content = { "((+ 2 :foo/bar))" }, }) @@ -93,7 +99,8 @@ describe("element and form wrap", function() cursor = { 2, 4 }, }) - paredit.wrap.wrap_enclosing_form_under_cursor("(", ")") + local range = paredit.wrap.wrap_enclosing_form_under_cursor("(", ")") + assert.are.same({ 0, 0, 2, 10 }, range) expect({ content = { "((+ 2",