summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2020-12-07 11:35:27 +0800
committer最萌小汐 <sumneko@hotmail.com>2020-12-07 11:35:27 +0800
commite8bb3d6a84ab54249347ada01740e1127147fc13 (patch)
tree84a2e621a8fb6c9739162ee7878ad3f2b62d3552 /script
parent3238be9f9828d3728fb81e2284ac2a5c2c9b3208 (diff)
downloadlua-language-server-e8bb3d6a84ab54249347ada01740e1127147fc13.zip
code action: swap params
Diffstat (limited to 'script')
-rw-r--r--script/core/code-action.lua102
-rw-r--r--script/provider/provider.lua14
-rw-r--r--script/utility.lua30
3 files changed, 127 insertions, 19 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
index 13d5220d..d3546bbb 100644
--- a/script/core/code-action.lua
+++ b/script/core/code-action.lua
@@ -4,6 +4,7 @@ local define = require 'proto.define'
local guide = require 'parser.guide'
local util = require 'utility'
local sp = require 'bee.subprocess'
+local vm = require 'vm'
local function disableDiagnostic(uri, code, results)
results[#results+1] = {
@@ -133,7 +134,8 @@ local function solveSyntaxByAddDoEnd(uri, err, results)
changes = {
[uri] = {
{
- range = define.range(lines, text, err.start, err.finish),
+ start = err.start,
+ finish = err.finish,
newText = ('do %s end'):format(text:sub(err.start, err.finish)),
},
}
@@ -148,7 +150,8 @@ local function solveSyntaxByFix(uri, err, results)
local changes = {}
for _, fix in ipairs(err.fix) do
changes[#changes+1] = {
- range = define.range(lines, text, fix.start, fix.finish),
+ start = fix.start,
+ finish = fix.finish,
newText = fix.text,
}
end
@@ -279,28 +282,92 @@ local function solveDiagnostic(uri, diag, results)
disableDiagnostic(uri, diag.code, results)
end
+local function checkQuickFix(results, uri, diagnostics)
+ if not diagnostics then
+ return
+ end
+ for _, diag in ipairs(diagnostics) do
+ solveDiagnostic(uri, diag, results)
+ end
+end
+
local function checkSwapParams(results, uri, start, finish)
- local ast = files.getAst(uri)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
if not ast then
return
end
- local result = guide.eachSourceBetween(ast.ast, start, finish, function (source)
- if source.type == 'callargs' then
- return {
- node = source.parent.node,
- args = source,
- }
- end
- if source.type == 'funcargs' then
- return {
- node = source.parent,
- args = source,
+ local args = {}
+ guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ if source.type == 'callargs'
+ or source.type == 'funcargs' then
+ local targetIndex
+ for index, arg in ipairs(source) do
+ if arg.start - 1 <= finish and arg.finish >= start then
+ -- should select only one param
+ if targetIndex then
+ return
+ end
+ targetIndex = index
+ end
+ end
+ if not targetIndex then
+ return
+ end
+ local node
+ if source.type == 'callargs' then
+ node = text:sub(source.parent.node.start, source.parent.node.finish)
+ elseif source.type == 'funcargs' then
+ local var = source.parent.parent
+ if vm.isSet(var) then
+ node = text:sub(var.start, var.finish)
+ else
+ node = lang.script.SYMBOL_ANONYMOUS
+ end
+ end
+ args[#args+1] = {
+ source = source,
+ index = targetIndex,
+ node = node,
}
end
end)
- if not result then
+ if #args == 0 then
return
end
+ table.sort(args, function (a, b)
+ return a.source.start > b.source.start
+ end)
+ local target = args[1]
+ uri = files.getOriginUri(uri)
+ local myArg = target.source[target.index]
+ for i, targetArg in ipairs(target.source) do
+ if i ~= target.index then
+ results[#results+1] = {
+ title = lang.script('ACTION_SWAP_PARAMS', {
+ node = target.node,
+ index = i,
+ }),
+ kind = 'refactor.rewrite',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ start = myArg.start,
+ finish = myArg.finish,
+ newText = text:sub(targetArg.start, targetArg.finish),
+ },
+ {
+ start = targetArg.start,
+ finish = targetArg.finish,
+ newText = text:sub(myArg.start, myArg.finish),
+ },
+ }
+ }
+ }
+ }
+ end
+ end
end
local function checkExtractAsFunction(results, uri, start, finish)
@@ -315,10 +382,7 @@ return function (uri, start, finish, diagnostics)
local results = {}
- for _, diag in ipairs(diagnostics) do
- solveDiagnostic(uri, diag, results)
- end
-
+ checkQuickFix(results, uri, diagnostics)
checkSwapParams(results, uri, start, finish)
checkExtractAsFunction(results, uri, start, finish)
diff --git a/script/provider/provider.lua b/script/provider/provider.lua
index ab6c1ccb..5d7bc6c2 100644
--- a/script/provider/provider.lua
+++ b/script/provider/provider.lua
@@ -581,6 +581,20 @@ proto.on('textDocument/codeAction', function (params)
return nil
end
+ for _, res in ipairs(results) do
+ if res.edit then
+ for turi, changes in pairs(res.edit.changes) do
+ local ttext = files.getText(turi)
+ local tlines = files.getLines(turi)
+ for _, change in ipairs(changes) do
+ change.range = define.range(tlines, ttext, change.start, change.finish)
+ change.start = nil
+ change.finish = nil
+ end
+ end
+ end
+ end
+
return results
end)
diff --git a/script/utility.lua b/script/utility.lua
index a1ea1804..2b9c7d2f 100644
--- a/script/utility.lua
+++ b/script/utility.lua
@@ -29,6 +29,9 @@ local function formatNumber(n)
or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则
return ('%q'):format(n)
end
+ if mathType(n) == 'integer' then
+ return tostring(n)
+ end
local str = ('%.10f'):format(n)
str = str:gsub('%.?0*$', '')
return str
@@ -556,4 +559,31 @@ function m.eachLine(text)
end
end
+function m.sortByScore(tbl, callbacks)
+ if type(callbacks) ~= 'table' then
+ callbacks = { callbacks }
+ end
+ local size = #callbacks
+ local scoreCache = {}
+ for i = 1, size do
+ scoreCache[i] = {}
+ end
+ tableSort(tbl, function (a, b)
+ for i = 1, size do
+ local callback = callbacks[i]
+ local cache = scoreCache[i]
+ local sa = cache[a] or callback(a)
+ local sb = cache[b] or callback(b)
+ cache[a] = sa
+ cache[b] = sb
+ if sa > sb then
+ return true
+ elseif sa < sb then
+ return false
+ end
+ end
+ return false
+ end)
+end
+
return m