summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md1
-rw-r--r--locale/en-us/script.lua1
-rw-r--r--locale/zh-cn/script.lua1
-rw-r--r--script/core/code-action.lua102
-rw-r--r--script/provider/provider.lua14
-rw-r--r--script/utility.lua30
-rw-r--r--test/code_action/init.lua85
7 files changed, 203 insertions, 31 deletions
diff --git a/changelog.md b/changelog.md
index 1c6e674d..41234e7a 100644
--- a/changelog.md
+++ b/changelog.md
@@ -3,6 +3,7 @@
## 1.6.0
* `NEW` auto require local modules
* `NEW` hover function by keyword `function`
+* `NEW` code action: swap params
* `CHG` unbind the relative path between binaries and scripts
## 1.5.0
diff --git a/locale/en-us/script.lua b/locale/en-us/script.lua
index 39dcd8ee..553f5a76 100644
--- a/locale/en-us/script.lua
+++ b/locale/en-us/script.lua
@@ -164,6 +164,7 @@ ACTION_FIX_DO_AS_THEN = 'Modify to `do` .'
ACTION_ADD_END = 'Add `end` (infer the addition location ny indentations).'
ACTION_FIX_COMMENT_PREFIX = 'Modify to `--` .'
ACTION_RUNTIME_UNICODE_NAME = 'Allow Unicode characters.'
+ACTION_SWAP_PARAMS = 'Change to parameter {index} or `{node}`'
COMMAND_DISABLE_DIAG = 'Disable diagnostics'
COMMAND_MARK_GLOBAL = 'Mark defined global'
diff --git a/locale/zh-cn/script.lua b/locale/zh-cn/script.lua
index 976c123f..f708981f 100644
--- a/locale/zh-cn/script.lua
+++ b/locale/zh-cn/script.lua
@@ -163,6 +163,7 @@ ACTION_FIX_DO_AS_THEN = '改为 `do` 。'
ACTION_ADD_END = '添加 `end` (根据缩进推测添加位置)。'
ACTION_FIX_COMMENT_PREFIX = '改为 `--` 。'
ACTION_RUNTIME_UNICODE_NAME = '允许使用 Unicode 字符。'
+ACTION_SWAP_PARAMS = '将其改为 `{node}` 的第 {index} 个参数'
COMMAND_DISABLE_DIAG = '禁用诊断'
COMMAND_MARK_GLOBAL = '标记全局变量'
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
index 13d5220d..d3546bbb 100644
--- a/script/core/code-action.lua
+++ b/script/core/code-action.lua
@@ -4,6 +4,7 @@ local define = require 'proto.define'
local guide = require 'parser.guide'
local util = require 'utility'
local sp = require 'bee.subprocess'
+local vm = require 'vm'
local function disableDiagnostic(uri, code, results)
results[#results+1] = {
@@ -133,7 +134,8 @@ local function solveSyntaxByAddDoEnd(uri, err, results)
changes = {
[uri] = {
{
- range = define.range(lines, text, err.start, err.finish),
+ start = err.start,
+ finish = err.finish,
newText = ('do %s end'):format(text:sub(err.start, err.finish)),
},
}
@@ -148,7 +150,8 @@ local function solveSyntaxByFix(uri, err, results)
local changes = {}
for _, fix in ipairs(err.fix) do
changes[#changes+1] = {
- range = define.range(lines, text, fix.start, fix.finish),
+ start = fix.start,
+ finish = fix.finish,
newText = fix.text,
}
end
@@ -279,28 +282,92 @@ local function solveDiagnostic(uri, diag, results)
disableDiagnostic(uri, diag.code, results)
end
+local function checkQuickFix(results, uri, diagnostics)
+ if not diagnostics then
+ return
+ end
+ for _, diag in ipairs(diagnostics) do
+ solveDiagnostic(uri, diag, results)
+ end
+end
+
local function checkSwapParams(results, uri, start, finish)
- local ast = files.getAst(uri)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
if not ast then
return
end
- local result = guide.eachSourceBetween(ast.ast, start, finish, function (source)
- if source.type == 'callargs' then
- return {
- node = source.parent.node,
- args = source,
- }
- end
- if source.type == 'funcargs' then
- return {
- node = source.parent,
- args = source,
+ local args = {}
+ guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ if source.type == 'callargs'
+ or source.type == 'funcargs' then
+ local targetIndex
+ for index, arg in ipairs(source) do
+ if arg.start - 1 <= finish and arg.finish >= start then
+ -- should select only one param
+ if targetIndex then
+ return
+ end
+ targetIndex = index
+ end
+ end
+ if not targetIndex then
+ return
+ end
+ local node
+ if source.type == 'callargs' then
+ node = text:sub(source.parent.node.start, source.parent.node.finish)
+ elseif source.type == 'funcargs' then
+ local var = source.parent.parent
+ if vm.isSet(var) then
+ node = text:sub(var.start, var.finish)
+ else
+ node = lang.script.SYMBOL_ANONYMOUS
+ end
+ end
+ args[#args+1] = {
+ source = source,
+ index = targetIndex,
+ node = node,
}
end
end)
- if not result then
+ if #args == 0 then
return
end
+ table.sort(args, function (a, b)
+ return a.source.start > b.source.start
+ end)
+ local target = args[1]
+ uri = files.getOriginUri(uri)
+ local myArg = target.source[target.index]
+ for i, targetArg in ipairs(target.source) do
+ if i ~= target.index then
+ results[#results+1] = {
+ title = lang.script('ACTION_SWAP_PARAMS', {
+ node = target.node,
+ index = i,
+ }),
+ kind = 'refactor.rewrite',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ start = myArg.start,
+ finish = myArg.finish,
+ newText = text:sub(targetArg.start, targetArg.finish),
+ },
+ {
+ start = targetArg.start,
+ finish = targetArg.finish,
+ newText = text:sub(myArg.start, myArg.finish),
+ },
+ }
+ }
+ }
+ }
+ end
+ end
end
local function checkExtractAsFunction(results, uri, start, finish)
@@ -315,10 +382,7 @@ return function (uri, start, finish, diagnostics)
local results = {}
- for _, diag in ipairs(diagnostics) do
- solveDiagnostic(uri, diag, results)
- end
-
+ checkQuickFix(results, uri, diagnostics)
checkSwapParams(results, uri, start, finish)
checkExtractAsFunction(results, uri, start, finish)
diff --git a/script/provider/provider.lua b/script/provider/provider.lua
index ab6c1ccb..5d7bc6c2 100644
--- a/script/provider/provider.lua
+++ b/script/provider/provider.lua
@@ -581,6 +581,20 @@ proto.on('textDocument/codeAction', function (params)
return nil
end
+ for _, res in ipairs(results) do
+ if res.edit then
+ for turi, changes in pairs(res.edit.changes) do
+ local ttext = files.getText(turi)
+ local tlines = files.getLines(turi)
+ for _, change in ipairs(changes) do
+ change.range = define.range(tlines, ttext, change.start, change.finish)
+ change.start = nil
+ change.finish = nil
+ end
+ end
+ end
+ end
+
return results
end)
diff --git a/script/utility.lua b/script/utility.lua
index a1ea1804..2b9c7d2f 100644
--- a/script/utility.lua
+++ b/script/utility.lua
@@ -29,6 +29,9 @@ local function formatNumber(n)
or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则
return ('%q'):format(n)
end
+ if mathType(n) == 'integer' then
+ return tostring(n)
+ end
local str = ('%.10f'):format(n)
str = str:gsub('%.?0*$', '')
return str
@@ -556,4 +559,31 @@ function m.eachLine(text)
end
end
+function m.sortByScore(tbl, callbacks)
+ if type(callbacks) ~= 'table' then
+ callbacks = { callbacks }
+ end
+ local size = #callbacks
+ local scoreCache = {}
+ for i = 1, size do
+ scoreCache[i] = {}
+ end
+ tableSort(tbl, function (a, b)
+ for i = 1, size do
+ local callback = callbacks[i]
+ local cache = scoreCache[i]
+ local sa = cache[a] or callback(a)
+ local sb = cache[b] or callback(b)
+ cache[a] = sa
+ cache[b] = sb
+ if sa > sb then
+ return true
+ elseif sa < sb then
+ return false
+ end
+ end
+ return false
+ end)
+end
+
return m
diff --git a/test/code_action/init.lua b/test/code_action/init.lua
index d3b2d768..1dde7fed 100644
--- a/test/code_action/init.lua
+++ b/test/code_action/init.lua
@@ -4,17 +4,47 @@ local lang = require 'language'
rawset(_G, 'TEST', true)
+local EXISTS = {}
+
+local function eq(a, b)
+ if a == EXISTS and b ~= nil then
+ return true
+ end
+ if b == EXISTS and a ~= nil then
+ return true
+ end
+ local tp1, tp2 = type(a), type(b)
+ if tp1 ~= tp2 then
+ return false
+ end
+ if tp1 == 'table' then
+ local mark = {}
+ for k in pairs(a) do
+ if not eq(a[k], b[k]) then
+ return false
+ end
+ mark[k] = true
+ end
+ for k in pairs(b) do
+ if not mark[k] then
+ return false
+ end
+ end
+ return true
+ end
+ return a == b
+end
+
function TEST(script)
return function (expect)
files.removeAll()
local start = script:find('<?', 1, true)
local finish = script:find('?>', 1, true)
- local pos = (start + finish) // 2 + 1
local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ')
files.setText('', new_script)
- local results = core('', pos)
+ local results = core('', start, finish)
assert(results)
- assert(expect == results)
+ assert(eq(expect, results))
end
end
@@ -23,14 +53,45 @@ print(<?a?>, b, c)
]]
{
{
- title = lang.script.ACTION_SWAP_PARAMS,
+ title = '将其改为 `print` 的第 2 个参数',
+ kind = 'refactor.rewrite',
+ edit = EXISTS,
+ },
+ {
+ title = '将其改为 `print` 的第 3 个参数',
+ kind = 'refactor.rewrite',
+ edit = EXISTS,
+ },
+}
+
+TEST [[
+local function f(<?a?>, b, c) end
+]]
+{
+ {
+ title = '将其改为 `f` 的第 2 个参数',
+ kind = 'refactor.rewrite',
+ edit = EXISTS,
+ },
+ {
+ title = '将其改为 `f` 的第 3 个参数',
+ kind = 'refactor.rewrite',
+ edit = EXISTS,
+ },
+}
+
+TEST [[
+return function(<?a?>, b, c) end
+]]
+{
+ {
+ title = '将其改为 `<匿名函数>` 的第 2 个参数',
+ kind = 'refactor.rewrite',
+ edit = EXISTS,
+ },
+ {
+ title = '将其改为 `<匿名函数>` 的第 3 个参数',
kind = 'refactor.rewrite',
- edit = {
- change = {
- ['file:///.lua'] = {
-
- }
- }
- }
- }
+ edit = EXISTS,
+ },
}