diff options
-rw-r--r-- | changelog.md | 1 | ||||
-rw-r--r-- | locale/en-us/script.lua | 1 | ||||
-rw-r--r-- | locale/zh-cn/script.lua | 1 | ||||
-rw-r--r-- | script/core/code-action.lua | 102 | ||||
-rw-r--r-- | script/provider/provider.lua | 14 | ||||
-rw-r--r-- | script/utility.lua | 30 | ||||
-rw-r--r-- | test/code_action/init.lua | 85 |
7 files changed, 203 insertions, 31 deletions
diff --git a/changelog.md b/changelog.md index 1c6e674d..41234e7a 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,7 @@ ## 1.6.0 * `NEW` auto require local modules * `NEW` hover function by keyword `function` +* `NEW` code action: swap params * `CHG` unbind the relative path between binaries and scripts ## 1.5.0 diff --git a/locale/en-us/script.lua b/locale/en-us/script.lua index 39dcd8ee..553f5a76 100644 --- a/locale/en-us/script.lua +++ b/locale/en-us/script.lua @@ -164,6 +164,7 @@ ACTION_FIX_DO_AS_THEN = 'Modify to `do` .' ACTION_ADD_END = 'Add `end` (infer the addition location ny indentations).' ACTION_FIX_COMMENT_PREFIX = 'Modify to `--` .' ACTION_RUNTIME_UNICODE_NAME = 'Allow Unicode characters.' +ACTION_SWAP_PARAMS = 'Change to parameter {index} or `{node}`' COMMAND_DISABLE_DIAG = 'Disable diagnostics' COMMAND_MARK_GLOBAL = 'Mark defined global' diff --git a/locale/zh-cn/script.lua b/locale/zh-cn/script.lua index 976c123f..f708981f 100644 --- a/locale/zh-cn/script.lua +++ b/locale/zh-cn/script.lua @@ -163,6 +163,7 @@ ACTION_FIX_DO_AS_THEN = '改为 `do` 。' ACTION_ADD_END = '添加 `end` (根据缩进推测添加位置)。' ACTION_FIX_COMMENT_PREFIX = '改为 `--` 。' ACTION_RUNTIME_UNICODE_NAME = '允许使用 Unicode 字符。' +ACTION_SWAP_PARAMS = '将其改为 `{node}` 的第 {index} 个参数' COMMAND_DISABLE_DIAG = '禁用诊断' COMMAND_MARK_GLOBAL = '标记全局变量' diff --git a/script/core/code-action.lua b/script/core/code-action.lua index 13d5220d..d3546bbb 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -4,6 +4,7 @@ local define = require 'proto.define' local guide = require 'parser.guide' local util = require 'utility' local sp = require 'bee.subprocess' +local vm = require 'vm' local function disableDiagnostic(uri, code, results) results[#results+1] = { @@ -133,7 +134,8 @@ local function solveSyntaxByAddDoEnd(uri, err, results) changes = { [uri] = { { - range = define.range(lines, text, err.start, err.finish), + start = err.start, + finish = err.finish, newText = ('do %s end'):format(text:sub(err.start, err.finish)), }, } @@ -148,7 +150,8 @@ local function solveSyntaxByFix(uri, err, results) local changes = {} for _, fix in ipairs(err.fix) do changes[#changes+1] = { - range = define.range(lines, text, fix.start, fix.finish), + start = fix.start, + finish = fix.finish, newText = fix.text, } end @@ -279,28 +282,92 @@ local function solveDiagnostic(uri, diag, results) disableDiagnostic(uri, diag.code, results) end +local function checkQuickFix(results, uri, diagnostics) + if not diagnostics then + return + end + for _, diag in ipairs(diagnostics) do + solveDiagnostic(uri, diag, results) + end +end + local function checkSwapParams(results, uri, start, finish) - local ast = files.getAst(uri) + local ast = files.getAst(uri) + local text = files.getText(uri) if not ast then return end - local result = guide.eachSourceBetween(ast.ast, start, finish, function (source) - if source.type == 'callargs' then - return { - node = source.parent.node, - args = source, - } - end - if source.type == 'funcargs' then - return { - node = source.parent, - args = source, + local args = {} + guide.eachSourceBetween(ast.ast, start, finish, function (source) + if source.type == 'callargs' + or source.type == 'funcargs' then + local targetIndex + for index, arg in ipairs(source) do + if arg.start - 1 <= finish and arg.finish >= start then + -- should select only one param + if targetIndex then + return + end + targetIndex = index + end + end + if not targetIndex then + return + end + local node + if source.type == 'callargs' then + node = text:sub(source.parent.node.start, source.parent.node.finish) + elseif source.type == 'funcargs' then + local var = source.parent.parent + if vm.isSet(var) then + node = text:sub(var.start, var.finish) + else + node = lang.script.SYMBOL_ANONYMOUS + end + end + args[#args+1] = { + source = source, + index = targetIndex, + node = node, } end end) - if not result then + if #args == 0 then return end + table.sort(args, function (a, b) + return a.source.start > b.source.start + end) + local target = args[1] + uri = files.getOriginUri(uri) + local myArg = target.source[target.index] + for i, targetArg in ipairs(target.source) do + if i ~= target.index then + results[#results+1] = { + title = lang.script('ACTION_SWAP_PARAMS', { + node = target.node, + index = i, + }), + kind = 'refactor.rewrite', + edit = { + changes = { + [uri] = { + { + start = myArg.start, + finish = myArg.finish, + newText = text:sub(targetArg.start, targetArg.finish), + }, + { + start = targetArg.start, + finish = targetArg.finish, + newText = text:sub(myArg.start, myArg.finish), + }, + } + } + } + } + end + end end local function checkExtractAsFunction(results, uri, start, finish) @@ -315,10 +382,7 @@ return function (uri, start, finish, diagnostics) local results = {} - for _, diag in ipairs(diagnostics) do - solveDiagnostic(uri, diag, results) - end - + checkQuickFix(results, uri, diagnostics) checkSwapParams(results, uri, start, finish) checkExtractAsFunction(results, uri, start, finish) diff --git a/script/provider/provider.lua b/script/provider/provider.lua index ab6c1ccb..5d7bc6c2 100644 --- a/script/provider/provider.lua +++ b/script/provider/provider.lua @@ -581,6 +581,20 @@ proto.on('textDocument/codeAction', function (params) return nil end + for _, res in ipairs(results) do + if res.edit then + for turi, changes in pairs(res.edit.changes) do + local ttext = files.getText(turi) + local tlines = files.getLines(turi) + for _, change in ipairs(changes) do + change.range = define.range(tlines, ttext, change.start, change.finish) + change.start = nil + change.finish = nil + end + end + end + end + return results end) diff --git a/script/utility.lua b/script/utility.lua index a1ea1804..2b9c7d2f 100644 --- a/script/utility.lua +++ b/script/utility.lua @@ -29,6 +29,9 @@ local function formatNumber(n) or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则 return ('%q'):format(n) end + if mathType(n) == 'integer' then + return tostring(n) + end local str = ('%.10f'):format(n) str = str:gsub('%.?0*$', '') return str @@ -556,4 +559,31 @@ function m.eachLine(text) end end +function m.sortByScore(tbl, callbacks) + if type(callbacks) ~= 'table' then + callbacks = { callbacks } + end + local size = #callbacks + local scoreCache = {} + for i = 1, size do + scoreCache[i] = {} + end + tableSort(tbl, function (a, b) + for i = 1, size do + local callback = callbacks[i] + local cache = scoreCache[i] + local sa = cache[a] or callback(a) + local sb = cache[b] or callback(b) + cache[a] = sa + cache[b] = sb + if sa > sb then + return true + elseif sa < sb then + return false + end + end + return false + end) +end + return m diff --git a/test/code_action/init.lua b/test/code_action/init.lua index d3b2d768..1dde7fed 100644 --- a/test/code_action/init.lua +++ b/test/code_action/init.lua @@ -4,17 +4,47 @@ local lang = require 'language' rawset(_G, 'TEST', true) +local EXISTS = {} + +local function eq(a, b) + if a == EXISTS and b ~= nil then + return true + end + if b == EXISTS and a ~= nil then + return true + end + local tp1, tp2 = type(a), type(b) + if tp1 ~= tp2 then + return false + end + if tp1 == 'table' then + local mark = {} + for k in pairs(a) do + if not eq(a[k], b[k]) then + return false + end + mark[k] = true + end + for k in pairs(b) do + if not mark[k] then + return false + end + end + return true + end + return a == b +end + function TEST(script) return function (expect) files.removeAll() local start = script:find('<?', 1, true) local finish = script:find('?>', 1, true) - local pos = (start + finish) // 2 + 1 local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ') files.setText('', new_script) - local results = core('', pos) + local results = core('', start, finish) assert(results) - assert(expect == results) + assert(eq(expect, results)) end end @@ -23,14 +53,45 @@ print(<?a?>, b, c) ]] { { - title = lang.script.ACTION_SWAP_PARAMS, + title = '将其改为 `print` 的第 2 个参数', + kind = 'refactor.rewrite', + edit = EXISTS, + }, + { + title = '将其改为 `print` 的第 3 个参数', + kind = 'refactor.rewrite', + edit = EXISTS, + }, +} + +TEST [[ +local function f(<?a?>, b, c) end +]] +{ + { + title = '将其改为 `f` 的第 2 个参数', + kind = 'refactor.rewrite', + edit = EXISTS, + }, + { + title = '将其改为 `f` 的第 3 个参数', + kind = 'refactor.rewrite', + edit = EXISTS, + }, +} + +TEST [[ +return function(<?a?>, b, c) end +]] +{ + { + title = '将其改为 `<匿名函数>` 的第 2 个参数', + kind = 'refactor.rewrite', + edit = EXISTS, + }, + { + title = '将其改为 `<匿名函数>` 的第 3 个参数', kind = 'refactor.rewrite', - edit = { - change = { - ['file:///.lua'] = { - - } - } - } - } + edit = EXISTS, + }, } |