diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2020-12-07 11:35:27 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2020-12-07 11:35:27 +0800 |
commit | e8bb3d6a84ab54249347ada01740e1127147fc13 (patch) | |
tree | 84a2e621a8fb6c9739162ee7878ad3f2b62d3552 /script | |
parent | 3238be9f9828d3728fb81e2284ac2a5c2c9b3208 (diff) | |
download | lua-language-server-e8bb3d6a84ab54249347ada01740e1127147fc13.zip |
code action: swap params
Diffstat (limited to 'script')
-rw-r--r-- | script/core/code-action.lua | 102 | ||||
-rw-r--r-- | script/provider/provider.lua | 14 | ||||
-rw-r--r-- | script/utility.lua | 30 |
3 files changed, 127 insertions, 19 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua index 13d5220d..d3546bbb 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -4,6 +4,7 @@ local define = require 'proto.define' local guide = require 'parser.guide' local util = require 'utility' local sp = require 'bee.subprocess' +local vm = require 'vm' local function disableDiagnostic(uri, code, results) results[#results+1] = { @@ -133,7 +134,8 @@ local function solveSyntaxByAddDoEnd(uri, err, results) changes = { [uri] = { { - range = define.range(lines, text, err.start, err.finish), + start = err.start, + finish = err.finish, newText = ('do %s end'):format(text:sub(err.start, err.finish)), }, } @@ -148,7 +150,8 @@ local function solveSyntaxByFix(uri, err, results) local changes = {} for _, fix in ipairs(err.fix) do changes[#changes+1] = { - range = define.range(lines, text, fix.start, fix.finish), + start = fix.start, + finish = fix.finish, newText = fix.text, } end @@ -279,28 +282,92 @@ local function solveDiagnostic(uri, diag, results) disableDiagnostic(uri, diag.code, results) end +local function checkQuickFix(results, uri, diagnostics) + if not diagnostics then + return + end + for _, diag in ipairs(diagnostics) do + solveDiagnostic(uri, diag, results) + end +end + local function checkSwapParams(results, uri, start, finish) - local ast = files.getAst(uri) + local ast = files.getAst(uri) + local text = files.getText(uri) if not ast then return end - local result = guide.eachSourceBetween(ast.ast, start, finish, function (source) - if source.type == 'callargs' then - return { - node = source.parent.node, - args = source, - } - end - if source.type == 'funcargs' then - return { - node = source.parent, - args = source, + local args = {} + guide.eachSourceBetween(ast.ast, start, finish, function (source) + if source.type == 'callargs' + or source.type == 'funcargs' then + local targetIndex + for index, arg in ipairs(source) do + if arg.start - 1 <= finish and arg.finish >= start then + -- should select only one param + if targetIndex then + return + end + targetIndex = index + end + end + if not targetIndex then + return + end + local node + if source.type == 'callargs' then + node = text:sub(source.parent.node.start, source.parent.node.finish) + elseif source.type == 'funcargs' then + local var = source.parent.parent + if vm.isSet(var) then + node = text:sub(var.start, var.finish) + else + node = lang.script.SYMBOL_ANONYMOUS + end + end + args[#args+1] = { + source = source, + index = targetIndex, + node = node, } end end) - if not result then + if #args == 0 then return end + table.sort(args, function (a, b) + return a.source.start > b.source.start + end) + local target = args[1] + uri = files.getOriginUri(uri) + local myArg = target.source[target.index] + for i, targetArg in ipairs(target.source) do + if i ~= target.index then + results[#results+1] = { + title = lang.script('ACTION_SWAP_PARAMS', { + node = target.node, + index = i, + }), + kind = 'refactor.rewrite', + edit = { + changes = { + [uri] = { + { + start = myArg.start, + finish = myArg.finish, + newText = text:sub(targetArg.start, targetArg.finish), + }, + { + start = targetArg.start, + finish = targetArg.finish, + newText = text:sub(myArg.start, myArg.finish), + }, + } + } + } + } + end + end end local function checkExtractAsFunction(results, uri, start, finish) @@ -315,10 +382,7 @@ return function (uri, start, finish, diagnostics) local results = {} - for _, diag in ipairs(diagnostics) do - solveDiagnostic(uri, diag, results) - end - + checkQuickFix(results, uri, diagnostics) checkSwapParams(results, uri, start, finish) checkExtractAsFunction(results, uri, start, finish) diff --git a/script/provider/provider.lua b/script/provider/provider.lua index ab6c1ccb..5d7bc6c2 100644 --- a/script/provider/provider.lua +++ b/script/provider/provider.lua @@ -581,6 +581,20 @@ proto.on('textDocument/codeAction', function (params) return nil end + for _, res in ipairs(results) do + if res.edit then + for turi, changes in pairs(res.edit.changes) do + local ttext = files.getText(turi) + local tlines = files.getLines(turi) + for _, change in ipairs(changes) do + change.range = define.range(tlines, ttext, change.start, change.finish) + change.start = nil + change.finish = nil + end + end + end + end + return results end) diff --git a/script/utility.lua b/script/utility.lua index a1ea1804..2b9c7d2f 100644 --- a/script/utility.lua +++ b/script/utility.lua @@ -29,6 +29,9 @@ local function formatNumber(n) or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则 return ('%q'):format(n) end + if mathType(n) == 'integer' then + return tostring(n) + end local str = ('%.10f'):format(n) str = str:gsub('%.?0*$', '') return str @@ -556,4 +559,31 @@ function m.eachLine(text) end end +function m.sortByScore(tbl, callbacks) + if type(callbacks) ~= 'table' then + callbacks = { callbacks } + end + local size = #callbacks + local scoreCache = {} + for i = 1, size do + scoreCache[i] = {} + end + tableSort(tbl, function (a, b) + for i = 1, size do + local callback = callbacks[i] + local cache = scoreCache[i] + local sa = cache[a] or callback(a) + local sb = cache[b] or callback(b) + cache[a] = sa + cache[b] = sb + if sa > sb then + return true + elseif sa < sb then + return false + end + end + return false + end) +end + return m |