From 4f5d25b69d54692c04e9569f327957702a60e59b Mon Sep 17 00:00:00 2001 From: ThePrimeagen Date: Fri, 13 Aug 2021 11:12:35 -0600 Subject: [PATCH] refactor: code gen is now dumber than it was before. This should make adding new features / indenting / language specific things easier. --- lua/refactoring/code_generation/go.lua | 37 +++++----- lua/refactoring/code_generation/lua.lua | 36 +++++----- lua/refactoring/code_generation/python.lua | 38 +++++----- .../code_generation/typescript.lua | 51 ++----------- lua/refactoring/code_generation/utils.lua | 7 ++ lua/refactoring/pipeline/refactor_setup.lua | 1 + lua/refactoring/refactor/106.lua | 72 +++++++++---------- lua/refactoring/refactor/119.lua | 9 ++- .../tests/extract.simple-function.expected.go | 3 +- .../extract.simple-function.expected.lua | 3 +- .../tests/extract.simple-function.expected.py | 3 +- .../tests/extract.simple-function.expected.ts | 3 +- 12 files changed, 116 insertions(+), 147 deletions(-) create mode 100644 lua/refactoring/code_generation/utils.lua diff --git a/lua/refactoring/code_generation/go.lua b/lua/refactoring/code_generation/go.lua index 1626a70a..b410a76f 100644 --- a/lua/refactoring/code_generation/go.lua +++ b/lua/refactoring/code_generation/go.lua @@ -1,28 +1,27 @@ +local utils = require("refactoring.code_generation.utils") + local go = { - extract_function = function(opts) - return { - create = string.format( - [[ + constant = function(opts) + return string.format("%s := %s\n", opts.name, opts.value) + end, + ["return"] = function(code) + return string.format("return %s", utils.stringify_code(code)) + end, + ["function"] = function(opts) + return string.format( + [[ func %s(%s) { %s - return %s } ]], - opts.name, - table.concat(opts.args, ", "), - type(opts.body) == "table" - and table.concat(opts.body, "\n") - or opts.body, - opts.ret - ), - call = string.format( - "%s := %s(%s)", - opts.ret, - opts.name, - table.concat(opts.args, ", ") - ), - } + opts.name, + table.concat(opts.args, ", "), + utils.stringify_code(opts.body) + ) + end, + call_function = function(opts) + return string.format("%s(%s)", opts.name, table.concat(opts.args, ", ")) end, } return go diff --git a/lua/refactoring/code_generation/lua.lua b/lua/refactoring/code_generation/lua.lua index 11dcb1a9..1ed4e8ca 100644 --- a/lua/refactoring/code_generation/lua.lua +++ b/lua/refactoring/code_generation/lua.lua @@ -1,32 +1,28 @@ +local utils = require("refactoring.code_generation.utils") + local lua = { - create_constant = function(opts) + constant = function(opts) return string.format("local %s = %s\n", opts.name, opts.value) end, - extract_function = function(opts) - return { - create = string.format( - [[ + ["function"] = function(opts) + return string.format( + [[ local function %s(%s) %s - return %s end ]], - opts.name, - table.concat(opts.args, ", "), - type(opts.body) == "table" - and table.concat(opts.body, "\n") - or opts.body, - opts.ret - ), + opts.name, + table.concat(opts.args, ", "), + utils.stringify_code(opts.body) + ) + end, + ["return"] = function(code) + return string.format("return %s", utils.stringify_code(code)) + end, - call = string.format( - "local %s = %s(%s)", - opts.ret, - opts.name, - table.concat(opts.args, ", ") - ), - } + call_function = function(opts) + return string.format("%s(%s)", opts.name, table.concat(opts.args, ", ")) end, } return lua diff --git a/lua/refactoring/code_generation/python.lua b/lua/refactoring/code_generation/python.lua index b79ed899..641c72d6 100644 --- a/lua/refactoring/code_generation/python.lua +++ b/lua/refactoring/code_generation/python.lua @@ -1,28 +1,28 @@ +local utils = require("refactoring.code_generation.utils") + local python = { - extract_function = function(opts) - return { - create = string.format( - [[ + constant = function(opts) + return string.format("%s = %s\n", opts.name, opts.value) + end, + ["return"] = function(code) + return string.format("return %s", utils.stringify_code(code)) + end, + + ["function"] = function(opts) + return string.format( + [[ def %s(%s): %s - return %s ]], - opts.name, - table.concat(opts.args, ", "), - type(opts.body) == "table" - and table.concat(opts.body, "\n") - or opts.body, - opts.ret - ), - call = string.format( - "%s = %s(%s)", - opts.ret, - opts.name, - table.concat(opts.args, ", ") - ), - } + opts.name, + table.concat(opts.args, ", "), + utils.stringify_code(opts.body) + ) + end, + call_function = function(opts) + return string.format("%s(%s)", opts.name, table.concat(opts.args, ", ")) end, } return python diff --git a/lua/refactoring/code_generation/typescript.lua b/lua/refactoring/code_generation/typescript.lua index edeb3b38..0ff40543 100644 --- a/lua/refactoring/code_generation/typescript.lua +++ b/lua/refactoring/code_generation/typescript.lua @@ -1,42 +1,12 @@ -local typescript = { - create_constant = function(opts) - return string.format("const %s = %s;\n", opts.name, opts.value) - end, - extract_function = function(opts) - return { - create = string.format( - [[ -function %s(%s) { - %s - return %s -} +local utils = require("refactoring.code_generation.utils") -]], - opts.name, - table.concat(opts.args, ", "), - type(opts.body) == "table" - and table.concat(opts.body, "\n") - or opts.body, - opts.ret - ), - -- TODO: OBVI THIS NEEDS TO BE DIFFERENT... - call = string.format( - "const %s = %s(%s)", - opts.ret, - opts.name, - table.concat(opts.args, ", ") - ), - } - end, -} ---[[ local typescript = { constant = function(opts) return string.format("const %s = %s;\n", opts.name, opts.value) end, - ["return"] = function(opts) - return string.format("return %s", opts.ret) + ["return"] = function(code) + return string.format("return %s", utils.stringify_code(code)) end, ["function"] = function(opts) @@ -45,24 +15,17 @@ local typescript = { function %s(%s) { %s } - ]] ---[[, + + ]], opts.name, table.concat(opts.args, ", "), - type(opts.body) == "table" - and table.concat(opts.body, "\n") - or opts.body + utils.stringify_code(opts.body) ) end, call_function = function(opts) - return string.format( - "%s(%s)", - opts.name, - table.concat(opts.args, ", ") - ) + return string.format("%s(%s)", opts.name, table.concat(opts.args, ", ")) end, } ---]] return typescript diff --git a/lua/refactoring/code_generation/utils.lua b/lua/refactoring/code_generation/utils.lua new file mode 100644 index 00000000..7beb8dd6 --- /dev/null +++ b/lua/refactoring/code_generation/utils.lua @@ -0,0 +1,7 @@ +local M = {} + +function M.stringify_code(code) + return type(code) == "table" and table.concat(code, "\n") or code +end + +return M diff --git a/lua/refactoring/pipeline/refactor_setup.lua b/lua/refactoring/pipeline/refactor_setup.lua index 8e5e5c71..5ff04cf2 100644 --- a/lua/refactoring/pipeline/refactor_setup.lua +++ b/lua/refactoring/pipeline/refactor_setup.lua @@ -7,6 +7,7 @@ local function refactor_setup(bufnr, options) local filetype = vim.bo[bufnr].filetype local root = Query.get_root(bufnr, filetype) local refactor = { + code = options.get_code_generation_for(filetype), filetype = filetype, bufnr = bufnr, query = Query:new( diff --git a/lua/refactoring/refactor/106.lua b/lua/refactoring/refactor/106.lua index 09dd1db8..518244e8 100644 --- a/lua/refactoring/refactor/106.lua +++ b/lua/refactoring/refactor/106.lua @@ -14,22 +14,6 @@ local Config = require("refactoring.config") local M = {} -local function get_code( - bufnr, - lang, - region, - selected_local_references, - function_name, - ret -) - return Config.get_code_generation_for(lang).extract_function({ - args = vim.fn.sort(vim.tbl_keys(selected_local_references)), - body = region:get_text(bufnr), - name = function_name, - ret = ret, - }) -end - local function get_local_definitions(bufnr, local_defs, function_args) local local_def_map = {} @@ -89,24 +73,32 @@ M.extract_to_file = function(bufnr) refactor ) local function_name = get_input("106: Extract Function Name > ") - local extract_function = get_code( - refactor.bufnr, - refactor.filetype, - refactor.region, - selected_local_references, - function_name, - "fill_me" - ) + + local function_body = refactor.region:get_text() + table.insert(function_body, refactor.code["return"]("fill_me")) + local args = vim.fn.sort(vim.tbl_keys(selected_local_references)) + + local function_code = refactor.code["function"]({ + name = function_name, + args = args, + body = function_body, + }) refactor.text_edits = { { region = utils.get_top_of_file_region(refactor.scope), - text = extract_function.create, + text = function_code, bufnr = refactor.buffers[2], }, { region = refactor.region, - text = extract_function.call, + text = refactor.code.constant({ + name = "fill_me", + value = refactor.code.call_function({ + name = function_name, + args = args, + }), + }), }, } @@ -125,23 +117,31 @@ M.extract = function(bufnr) ) local function_name = get_input("106: Extract Function Name > ") - local extract_function = get_code( - refactor.bufnr, - refactor.filetype, - refactor.region, - selected_local_references, - function_name, - "fill_me" - ) + local function_body = refactor.region:get_text() + table.insert(function_body, refactor.code["return"]("fill_me")) + local args = vim.fn.sort(vim.tbl_keys(selected_local_references)) + + local function_code = refactor.code["function"]({ + name = function_name, + args = args, + body = function_body, + }) refactor.text_edits = { { region = utils.region_above_node(refactor.scope), - text = extract_function.create, + text = function_code, + bufnr = refactor.buffers[2], }, { region = refactor.region, - text = extract_function.call, + text = refactor.code.constant({ + name = "fill_me", + value = refactor.code.call_function({ + name = function_name, + args = args, + }), + }), }, } diff --git a/lua/refactoring/refactor/119.lua b/lua/refactoring/refactor/119.lua index 5492c0a4..7c5900b8 100644 --- a/lua/refactoring/refactor/119.lua +++ b/lua/refactoring/refactor/119.lua @@ -80,11 +80,10 @@ function M.extract_var(bufnr) "Extract var unable to determine its containing statement within the block scope, please post issue with exact highlight + code! Thanks" ) end - local code = - Config.get_code_generation_for(refactor.filetype).create_constant({ - name = var_name, - value = extract_node_text, - }) + local code = refactor.code.constant({ + name = var_name, + value = extract_node_text, + }) table.insert(refactor.text_edits, { add_newline = false, diff --git a/lua/refactoring/tests/extract.simple-function.expected.go b/lua/refactoring/tests/extract.simple-function.expected.go index baffa0ad..47eeaa03 100644 --- a/lua/refactoring/tests/extract.simple-function.expected.go +++ b/lua/refactoring/tests/extract.simple-function.expected.go @@ -6,7 +6,7 @@ func foo_bar(a, test, test_other) { for idx := test - 1; idx < test_other; idx++ { fmt.Println(idx, a) } - return fill_me +return fill_me } func simple_function(a int) { @@ -14,4 +14,5 @@ func simple_function(a int) { test_other := 1 fill_me := foo_bar(a, test, test_other) + } diff --git a/lua/refactoring/tests/extract.simple-function.expected.lua b/lua/refactoring/tests/extract.simple-function.expected.lua index 5713ad87..1f6e0211 100644 --- a/lua/refactoring/tests/extract.simple-function.expected.lua +++ b/lua/refactoring/tests/extract.simple-function.expected.lua @@ -4,7 +4,7 @@ local function foo_bar(a, test, test_other) for idx = test - 1, test_other do print(idx, a) end - return fill_me +return fill_me end @@ -13,4 +13,5 @@ function simple_function(a) local test_other = 11 local fill_me = foo_bar(a, test, test_other) + end diff --git a/lua/refactoring/tests/extract.simple-function.expected.py b/lua/refactoring/tests/extract.simple-function.expected.py index 843af2b1..91d93f12 100644 --- a/lua/refactoring/tests/extract.simple-function.expected.py +++ b/lua/refactoring/tests/extract.simple-function.expected.py @@ -2,7 +2,7 @@ def foo_bar(a, test, test_other): for x in range(test_other + test): print(x, a) - return fill_me +return fill_me def simple_function(a): @@ -10,3 +10,4 @@ def simple_function(a): test_other = 11 fill_me = foo_bar(a, test, test_other) + diff --git a/lua/refactoring/tests/extract.simple-function.expected.ts b/lua/refactoring/tests/extract.simple-function.expected.ts index 1635e209..9e74c074 100644 --- a/lua/refactoring/tests/extract.simple-function.expected.ts +++ b/lua/refactoring/tests/extract.simple-function.expected.ts @@ -10,5 +10,6 @@ function simple_function(a: number) { let test = 1; let test_other = 11 - const fill_me = foo_bar(a, test, test_other) + const fill_me = foo_bar(a, test, test_other); + }