Skip to content

Commit

Permalink
Merge pull request #52 from ThePrimeagen/code-generation-2
Browse files Browse the repository at this point in the history
refactor: code gen is now dumber than it was before.
  • Loading branch information
ThePrimeagen authored Aug 13, 2021
2 parents 86c7a66 + 4f5d25b commit 6c61114
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 147 deletions.
37 changes: 18 additions & 19 deletions lua/refactoring/code_generation/go.lua
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
local utils = require("refactoring.code_generation.utils")

local go = {
extract_function = function(opts)
return {
create = string.format(
[[
constant = function(opts)
return string.format("%s := %s\n", opts.name, opts.value)
end,
["return"] = function(code)
return string.format("return %s", utils.stringify_code(code))
end,
["function"] = function(opts)
return string.format(
[[
func %s(%s) {
%s
return %s
}
]],

opts.name,
table.concat(opts.args, ", "),
type(opts.body) == "table"
and table.concat(opts.body, "\n")
or opts.body,
opts.ret
),
call = string.format(
"%s := %s(%s)",
opts.ret,
opts.name,
table.concat(opts.args, ", ")
),
}
opts.name,
table.concat(opts.args, ", "),
utils.stringify_code(opts.body)
)
end,
call_function = function(opts)
return string.format("%s(%s)", opts.name, table.concat(opts.args, ", "))
end,
}
return go
36 changes: 16 additions & 20 deletions lua/refactoring/code_generation/lua.lua
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
local utils = require("refactoring.code_generation.utils")

local lua = {
create_constant = function(opts)
constant = function(opts)
return string.format("local %s = %s\n", opts.name, opts.value)
end,
extract_function = function(opts)
return {
create = string.format(
[[
["function"] = function(opts)
return string.format(
[[
local function %s(%s)
%s
return %s
end
]],
opts.name,
table.concat(opts.args, ", "),
type(opts.body) == "table"
and table.concat(opts.body, "\n")
or opts.body,
opts.ret
),
opts.name,
table.concat(opts.args, ", "),
utils.stringify_code(opts.body)
)
end,
["return"] = function(code)
return string.format("return %s", utils.stringify_code(code))
end,

call = string.format(
"local %s = %s(%s)",
opts.ret,
opts.name,
table.concat(opts.args, ", ")
),
}
call_function = function(opts)
return string.format("%s(%s)", opts.name, table.concat(opts.args, ", "))
end,
}
return lua
38 changes: 19 additions & 19 deletions lua/refactoring/code_generation/python.lua
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
local utils = require("refactoring.code_generation.utils")

local python = {
extract_function = function(opts)
return {
create = string.format(
[[
constant = function(opts)
return string.format("%s = %s\n", opts.name, opts.value)
end,
["return"] = function(code)
return string.format("return %s", utils.stringify_code(code))
end,

["function"] = function(opts)
return string.format(
[[
def %s(%s):
%s
return %s
]],
opts.name,
table.concat(opts.args, ", "),
type(opts.body) == "table"
and table.concat(opts.body, "\n")
or opts.body,
opts.ret
),
call = string.format(
"%s = %s(%s)",
opts.ret,
opts.name,
table.concat(opts.args, ", ")
),
}
opts.name,
table.concat(opts.args, ", "),
utils.stringify_code(opts.body)
)
end,
call_function = function(opts)
return string.format("%s(%s)", opts.name, table.concat(opts.args, ", "))
end,
}
return python
51 changes: 7 additions & 44 deletions lua/refactoring/code_generation/typescript.lua
Original file line number Diff line number Diff line change
@@ -1,42 +1,12 @@
local typescript = {
create_constant = function(opts)
return string.format("const %s = %s;\n", opts.name, opts.value)
end,
extract_function = function(opts)
return {
create = string.format(
[[
function %s(%s) {
%s
return %s
}
local utils = require("refactoring.code_generation.utils")

]],
opts.name,
table.concat(opts.args, ", "),
type(opts.body) == "table"
and table.concat(opts.body, "\n")
or opts.body,
opts.ret
),
-- TODO: OBVI THIS NEEDS TO BE DIFFERENT...
call = string.format(
"const %s = %s(%s)",
opts.ret,
opts.name,
table.concat(opts.args, ", ")
),
}
end,
}
--[[
local typescript = {
constant = function(opts)
return string.format("const %s = %s;\n", opts.name, opts.value)
end,

["return"] = function(opts)
return string.format("return %s", opts.ret)
["return"] = function(code)
return string.format("return %s", utils.stringify_code(code))
end,

["function"] = function(opts)
Expand All @@ -45,24 +15,17 @@ local typescript = {
function %s(%s) {
%s
}
]]
--[[,
]],
opts.name,
table.concat(opts.args, ", "),
type(opts.body) == "table"
and table.concat(opts.body, "\n")
or opts.body
utils.stringify_code(opts.body)
)
end,

call_function = function(opts)
return string.format(
"%s(%s)",
opts.name,
table.concat(opts.args, ", ")
)
return string.format("%s(%s)", opts.name, table.concat(opts.args, ", "))
end,
}

--]]
return typescript
7 changes: 7 additions & 0 deletions lua/refactoring/code_generation/utils.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
local M = {}

function M.stringify_code(code)
return type(code) == "table" and table.concat(code, "\n") or code
end

return M
1 change: 1 addition & 0 deletions lua/refactoring/pipeline/refactor_setup.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ local function refactor_setup(bufnr, options)
local filetype = vim.bo[bufnr].filetype
local root = Query.get_root(bufnr, filetype)
local refactor = {
code = options.get_code_generation_for(filetype),
filetype = filetype,
bufnr = bufnr,
query = Query:new(
Expand Down
72 changes: 36 additions & 36 deletions lua/refactoring/refactor/106.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,6 @@ local Config = require("refactoring.config")

local M = {}

local function get_code(
bufnr,
lang,
region,
selected_local_references,
function_name,
ret
)
return Config.get_code_generation_for(lang).extract_function({
args = vim.fn.sort(vim.tbl_keys(selected_local_references)),
body = region:get_text(bufnr),
name = function_name,
ret = ret,
})
end

local function get_local_definitions(bufnr, local_defs, function_args)
local local_def_map = {}

Expand Down Expand Up @@ -89,24 +73,32 @@ M.extract_to_file = function(bufnr)
refactor
)
local function_name = get_input("106: Extract Function Name > ")
local extract_function = get_code(
refactor.bufnr,
refactor.filetype,
refactor.region,
selected_local_references,
function_name,
"fill_me"
)

local function_body = refactor.region:get_text()
table.insert(function_body, refactor.code["return"]("fill_me"))
local args = vim.fn.sort(vim.tbl_keys(selected_local_references))

local function_code = refactor.code["function"]({
name = function_name,
args = args,
body = function_body,
})

refactor.text_edits = {
{
region = utils.get_top_of_file_region(refactor.scope),
text = extract_function.create,
text = function_code,
bufnr = refactor.buffers[2],
},
{
region = refactor.region,
text = extract_function.call,
text = refactor.code.constant({
name = "fill_me",
value = refactor.code.call_function({
name = function_name,
args = args,
}),
}),
},
}

Expand All @@ -125,23 +117,31 @@ M.extract = function(bufnr)
)

local function_name = get_input("106: Extract Function Name > ")
local extract_function = get_code(
refactor.bufnr,
refactor.filetype,
refactor.region,
selected_local_references,
function_name,
"fill_me"
)
local function_body = refactor.region:get_text()
table.insert(function_body, refactor.code["return"]("fill_me"))
local args = vim.fn.sort(vim.tbl_keys(selected_local_references))

local function_code = refactor.code["function"]({
name = function_name,
args = args,
body = function_body,
})

refactor.text_edits = {
{
region = utils.region_above_node(refactor.scope),
text = extract_function.create,
text = function_code,
bufnr = refactor.buffers[2],
},
{
region = refactor.region,
text = extract_function.call,
text = refactor.code.constant({
name = "fill_me",
value = refactor.code.call_function({
name = function_name,
args = args,
}),
}),
},
}

Expand Down
9 changes: 4 additions & 5 deletions lua/refactoring/refactor/119.lua
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,10 @@ function M.extract_var(bufnr)
"Extract var unable to determine its containing statement within the block scope, please post issue with exact highlight + code! Thanks"
)
end
local code =
Config.get_code_generation_for(refactor.filetype).create_constant({
name = var_name,
value = extract_node_text,
})
local code = refactor.code.constant({
name = var_name,
value = extract_node_text,
})

table.insert(refactor.text_edits, {
add_newline = false,
Expand Down
3 changes: 2 additions & 1 deletion lua/refactoring/tests/extract.simple-function.expected.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ func foo_bar(a, test, test_other) {
for idx := test - 1; idx < test_other; idx++ {
fmt.Println(idx, a)
}
return fill_me
return fill_me
}

func simple_function(a int) {
var test int = 1
test_other := 1

fill_me := foo_bar(a, test, test_other)

}
3 changes: 2 additions & 1 deletion lua/refactoring/tests/extract.simple-function.expected.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ local function foo_bar(a, test, test_other)
for idx = test - 1, test_other do
print(idx, a)
end
return fill_me
return fill_me
end


Expand All @@ -13,4 +13,5 @@ function simple_function(a)
local test_other = 11

local fill_me = foo_bar(a, test, test_other)

end
Loading

0 comments on commit 6c61114

Please sign in to comment.