summaryrefslogtreecommitdiff
path: root/script-beta/src/core
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2019-11-22 23:26:32 +0800
committer最萌小汐 <sumneko@hotmail.com>2019-11-22 23:26:32 +0800
commitd0ff66c9abe9d6abbca12fd811e0c3cb69c1033a (patch)
treebb34518d70b85de7656dbdbe958dfa221a3ff3b3 /script-beta/src/core
parent0a2c2ad15e1ec359171fb0dd4c72e57c5b66e9ba (diff)
downloadlua-language-server-d0ff66c9abe9d6abbca12fd811e0c3cb69c1033a.zip
整理一下目录结构
Diffstat (limited to 'script-beta/src/core')
-rw-r--r--script-beta/src/core/definition.lua105
-rw-r--r--script-beta/src/core/diagnostics/ambiguity-1.lua69
-rw-r--r--script-beta/src/core/diagnostics/duplicate-index.lua62
-rw-r--r--script-beta/src/core/diagnostics/emmy-lua.lua3
-rw-r--r--script-beta/src/core/diagnostics/empty-block.lua49
-rw-r--r--script-beta/src/core/diagnostics/global-in-nil-env.lua66
-rw-r--r--script-beta/src/core/diagnostics/init.lua41
-rw-r--r--script-beta/src/core/diagnostics/lowercase-global.lua39
-rw-r--r--script-beta/src/core/diagnostics/newfield-call.lua37
-rw-r--r--script-beta/src/core/diagnostics/newline-call.lua38
-rw-r--r--script-beta/src/core/diagnostics/redefined-local.lua32
-rw-r--r--script-beta/src/core/diagnostics/redundant-parameter.lua102
-rw-r--r--script-beta/src/core/diagnostics/redundant-value.lua24
-rw-r--r--script-beta/src/core/diagnostics/trailing-space.lua55
-rw-r--r--script-beta/src/core/diagnostics/undefined-env-child.lua32
-rw-r--r--script-beta/src/core/diagnostics/undefined-global.lua63
-rw-r--r--script-beta/src/core/diagnostics/unused-function.lua45
-rw-r--r--script-beta/src/core/diagnostics/unused-label.lua22
-rw-r--r--script-beta/src/core/diagnostics/unused-local.lua46
-rw-r--r--script-beta/src/core/diagnostics/unused-vararg.lua31
-rw-r--r--script-beta/src/core/highlight.lua230
-rw-r--r--script-beta/src/core/hover/arg.lua20
-rw-r--r--script-beta/src/core/hover/init.lua56
-rw-r--r--script-beta/src/core/hover/label.lua103
-rw-r--r--script-beta/src/core/hover/name.lua64
-rw-r--r--script-beta/src/core/hover/return.lua34
-rw-r--r--script-beta/src/core/hover/table.lua35
-rw-r--r--script-beta/src/core/reference.lua84
-rw-r--r--script-beta/src/core/rename.lua374
29 files changed, 1961 insertions, 0 deletions
diff --git a/script-beta/src/core/definition.lua b/script-beta/src/core/definition.lua
new file mode 100644
index 00000000..865fc7cb
--- /dev/null
+++ b/script-beta/src/core/definition.lua
@@ -0,0 +1,105 @@
+local guide = require 'parser.guide'
+local workspace = require 'workspace'
+local files = require 'files'
+local vm = require 'vm'
+
+local function findDef(source, callback)
+ if source.type ~= 'local'
+ and source.type ~= 'getlocal'
+ and source.type ~= 'setlocal'
+ and source.type ~= 'setglobal'
+ and source.type ~= 'getglobal'
+ and source.type ~= 'field'
+ and source.type ~= 'method'
+ and source.type ~= 'string'
+ and source.type ~= 'number'
+ and source.type ~= 'boolean'
+ and source.type ~= 'goto' then
+ return
+ end
+ vm.eachDef(source, function (info)
+ if info.mode == 'declare'
+ or info.mode == 'set'
+ or info.mode == 'return' then
+ local src = info.source
+ local root = guide.getRoot(src)
+ local uri = root.uri
+ if src.type == 'setfield'
+ or src.type == 'getfield'
+ or src.type == 'tablefield' then
+ callback(src.field, uri)
+ elseif src.type == 'setindex'
+ or src.type == 'getindex'
+ or src.type == 'tableindex' then
+ callback(src.index, uri)
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ callback(src.method, uri)
+ else
+ callback(src, uri)
+ end
+ end
+ end)
+end
+
+local function checkRequire(source, offset, callback)
+ if source.type ~= 'call' then
+ return
+ end
+ local func = source.node
+ local pathSource = source.args and source.args[1]
+ if not pathSource then
+ return
+ end
+ if not guide.isContain(pathSource, offset) then
+ return
+ end
+ local literal = guide.getLiteral(pathSource)
+ if type(literal) ~= 'string' then
+ return
+ end
+ local name = func.special
+ if name == 'require' then
+ local result = workspace.findUrisByRequirePath(literal, true)
+ for _, uri in ipairs(result) do
+ callback(uri)
+ end
+ elseif name == 'dofile'
+ or name == 'loadfile' then
+ local result = workspace.findUrisByFilePath(literal, true)
+ for _, uri in ipairs(result) do
+ callback(uri)
+ end
+ end
+end
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local results = {}
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ checkRequire(source, offset, function (uri)
+ results[#results+1] = {
+ uri = files.getOriginUri(uri),
+ source = source,
+ target = {
+ start = 0,
+ finish = 0,
+ }
+ }
+ end)
+ findDef(source, function (target, uri)
+ results[#results+1] = {
+ target = target,
+ uri = files.getOriginUri(uri),
+ source = source,
+ }
+ end)
+ end)
+ if #results == 0 then
+ return nil
+ end
+ return results
+end
diff --git a/script-beta/src/core/diagnostics/ambiguity-1.lua b/script-beta/src/core/diagnostics/ambiguity-1.lua
new file mode 100644
index 00000000..37815fb5
--- /dev/null
+++ b/script-beta/src/core/diagnostics/ambiguity-1.lua
@@ -0,0 +1,69 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+local opMap = {
+ ['+'] = true,
+ ['-'] = true,
+ ['*'] = true,
+ ['/'] = true,
+ ['//'] = true,
+ ['^'] = true,
+ ['<<'] = true,
+ ['>>'] = true,
+ ['&'] = true,
+ ['|'] = true,
+ ['~'] = true,
+ ['..'] = true,
+}
+
+local literalMap = {
+ ['number'] = true,
+ ['boolean'] = true,
+ ['string'] = true,
+ ['table'] = true,
+}
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local text = files.getText(uri)
+ guide.eachSourceType(ast.ast, 'binary', function (source)
+ if source.op.type ~= 'or' then
+ return
+ end
+ local first = source[1]
+ local second = source[2]
+ -- a + (b or 0) --> (a + b) or 0
+ do
+ if opMap[first.op and first.op.type]
+ and first.type ~= 'unary'
+ and not second.op
+ and literalMap[second.type]
+ and not literalMap[first[2].type]
+ then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_AMBIGUITY_1', text:sub(first.start, first.finish))
+ }
+ end
+ end
+ -- (a or 0) + c --> a or (0 + c)
+ do
+ if opMap[second.op and second.op.type]
+ and second.type ~= 'unary'
+ and not first.op
+ and literalMap[second[1].type]
+ then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_AMBIGUITY_1', text:sub(second.start, second.finish))
+ }
+ end
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/duplicate-index.lua b/script-beta/src/core/diagnostics/duplicate-index.lua
new file mode 100644
index 00000000..76b1c958
--- /dev/null
+++ b/script-beta/src/core/diagnostics/duplicate-index.lua
@@ -0,0 +1,62 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'table', function (source)
+ local mark = {}
+ for _, obj in ipairs(source) do
+ if obj.type == 'tablefield'
+ or obj.type == 'tableindex' then
+ local name = guide.getKeyName(obj)
+ if name then
+ if not mark[name] then
+ mark[name] = {}
+ end
+ mark[name][#mark[name]+1] = obj.field or obj.index
+ end
+ end
+ end
+
+ for name, defs in pairs(mark) do
+ local sname = name:match '^.|(.+)$'
+ if #defs > 1 and sname then
+ local related = {}
+ for i = 1, #defs do
+ local def = defs[i]
+ related[i] = {
+ start = def.start,
+ finish = def.finish,
+ uri = uri,
+ }
+ end
+ for i = 1, #defs - 1 do
+ local def = defs[i]
+ callback {
+ start = def.start,
+ finish = def.finish,
+ related = related,
+ message = lang.script('DIAG_DUPLICATE_INDEX', sname),
+ level = define.DiagnosticSeverity.Hint,
+ tags = { define.DiagnosticTag.Unnecessary },
+ }
+ end
+ for i = #defs, #defs do
+ local def = defs[i]
+ callback {
+ start = def.start,
+ finish = def.finish,
+ related = related,
+ message = lang.script('DIAG_DUPLICATE_INDEX', sname),
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/emmy-lua.lua b/script-beta/src/core/diagnostics/emmy-lua.lua
new file mode 100644
index 00000000..b3d19c21
--- /dev/null
+++ b/script-beta/src/core/diagnostics/emmy-lua.lua
@@ -0,0 +1,3 @@
+return function ()
+
+end
diff --git a/script-beta/src/core/diagnostics/empty-block.lua b/script-beta/src/core/diagnostics/empty-block.lua
new file mode 100644
index 00000000..2024f4e3
--- /dev/null
+++ b/script-beta/src/core/diagnostics/empty-block.lua
@@ -0,0 +1,49 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+
+-- 检查空代码块
+-- 但是排除忙等待(repeat/while)
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'if', function (source)
+ for _, block in ipairs(source) do
+ if #block > 0 then
+ return
+ end
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+ guide.eachSourceType(ast.ast, 'loop', function (source)
+ if #source > 0 then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+ guide.eachSourceType(ast.ast, 'in', function (source)
+ if #source > 0 then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/global-in-nil-env.lua b/script-beta/src/core/diagnostics/global-in-nil-env.lua
new file mode 100644
index 00000000..9a0d4f35
--- /dev/null
+++ b/script-beta/src/core/diagnostics/global-in-nil-env.lua
@@ -0,0 +1,66 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+-- TODO: 检查路径是否可达
+local function mayRun(path)
+ return true
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local root = guide.getRoot(ast.ast)
+ local env = guide.getENV(root)
+
+ local nilDefs = {}
+ if not env.ref then
+ return
+ end
+ for _, ref in ipairs(env.ref) do
+ if ref.type == 'setlocal' then
+ if ref.value and ref.value.type == 'nil' then
+ nilDefs[#nilDefs+1] = ref
+ end
+ end
+ end
+
+ if #nilDefs == 0 then
+ return
+ end
+
+ local function check(source)
+ local node = source.node
+ if node.tag == '_ENV' then
+ local ok
+ for _, nilDef in ipairs(nilDefs) do
+ local mode, pathA = guide.getPath(nilDef, source)
+ if mode == 'before'
+ and mayRun(pathA) then
+ ok = nilDef
+ break
+ end
+ end
+ if ok then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ uri = uri,
+ message = lang.script.DIAG_GLOBAL_IN_NIL_ENV,
+ related = {
+ {
+ start = ok.start,
+ finish = ok.finish,
+ uri = uri,
+ }
+ }
+ }
+ end
+ end
+ end
+
+ guide.eachSourceType(ast.ast, 'getglobal', check)
+ guide.eachSourceType(ast.ast, 'setglobal', check)
+end
diff --git a/script-beta/src/core/diagnostics/init.lua b/script-beta/src/core/diagnostics/init.lua
new file mode 100644
index 00000000..0d523f26
--- /dev/null
+++ b/script-beta/src/core/diagnostics/init.lua
@@ -0,0 +1,41 @@
+local files = require 'files'
+local define = require 'proto.define'
+local config = require 'config'
+local await = require 'await'
+
+local function check(uri, name, level, results)
+ if config.config.diagnostics.disable[name] then
+ return
+ end
+ level = config.config.diagnostics.severity[name] or level
+ local severity = define.DiagnosticSeverity[level]
+ local clock = os.clock()
+ require('core.diagnostics.' .. name)(uri, function (result)
+ result.level = severity or result.level
+ result.code = name
+ results[#results+1] = result
+ end, name)
+ local passed = os.clock() - clock
+ if passed >= 0.5 then
+ log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed))
+ await.delay()
+ end
+end
+
+return function (uri)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local results = {}
+
+ for name, level in pairs(define.DiagnosticDefaultSeverity) do
+ check(uri, name, level, results)
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ return results
+end
diff --git a/script-beta/src/core/diagnostics/lowercase-global.lua b/script-beta/src/core/diagnostics/lowercase-global.lua
new file mode 100644
index 00000000..bc48e1e6
--- /dev/null
+++ b/script-beta/src/core/diagnostics/lowercase-global.lua
@@ -0,0 +1,39 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local config = require 'config'
+local library = require 'library'
+
+-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删)
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local definedGlobal = {}
+ for name in pairs(config.config.diagnostics.globals) do
+ definedGlobal[name] = true
+ end
+ for name in pairs(library.global) do
+ definedGlobal[name] = true
+ end
+
+ guide.eachSourceType(ast.ast, 'setglobal', function (source)
+ local name = guide.getName(source)
+ if definedGlobal[name] then
+ return
+ end
+ local first = name:match '%w'
+ if not first then
+ return
+ end
+ if first:match '%l' then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script.DIAG_LOWERCASE_GLOBAL,
+ }
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/newfield-call.lua b/script-beta/src/core/diagnostics/newfield-call.lua
new file mode 100644
index 00000000..75681cbc
--- /dev/null
+++ b/script-beta/src/core/diagnostics/newfield-call.lua
@@ -0,0 +1,37 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+
+ guide.eachSourceType(ast.ast, 'table', function (source)
+ for i = 1, #source do
+ local field = source[i]
+ if field.type == 'call' then
+ local func = field.node
+ local args = field.args
+ if args then
+ local funcLine = guide.positionOf(lines, func.finish)
+ local argsLine = guide.positionOf(lines, args.start)
+ if argsLine > funcLine then
+ callback {
+ start = field.start,
+ finish = field.finish,
+ message = lang.script('DIAG_PREFIELD_CALL'
+ , text:sub(func.start, func.finish)
+ , text:sub(args.start, args.finish)
+ )
+ }
+ end
+ end
+ end
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/newline-call.lua b/script-beta/src/core/diagnostics/newline-call.lua
new file mode 100644
index 00000000..cb318380
--- /dev/null
+++ b/script-beta/src/core/diagnostics/newline-call.lua
@@ -0,0 +1,38 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local lines = files.getLines(uri)
+
+ guide.eachSourceType(ast.ast, 'call', function (source)
+ local node = source.node
+ local args = source.args
+ if not args then
+ return
+ end
+
+ -- 必须有其他人在继续使用当前对象
+ if not source.next then
+ return
+ end
+
+ local nodeRow = guide.positionOf(lines, node.finish)
+ local argRow = guide.positionOf(lines, args.start)
+ if nodeRow == argRow then
+ return
+ end
+
+ if #args == 1 then
+ callback {
+ start = args.start,
+ finish = args.finish,
+ message = lang.script.DIAG_PREVIOUS_CALL,
+ }
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/redefined-local.lua b/script-beta/src/core/diagnostics/redefined-local.lua
new file mode 100644
index 00000000..f6176794
--- /dev/null
+++ b/script-beta/src/core/diagnostics/redefined-local.lua
@@ -0,0 +1,32 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'local', function (source)
+ local name = source[1]
+ if name == '_'
+ or name == '_ENV' then
+ return
+ end
+ local exist = guide.getLocal(source, name, source.start-1)
+ if exist then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_REDEFINED_LOCAL', name),
+ related = {
+ {
+ start = exist.start,
+ finish = exist.finish,
+ uri = uri,
+ }
+ },
+ }
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/redundant-parameter.lua b/script-beta/src/core/diagnostics/redundant-parameter.lua
new file mode 100644
index 00000000..ec14188e
--- /dev/null
+++ b/script-beta/src/core/diagnostics/redundant-parameter.lua
@@ -0,0 +1,102 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+local define = require 'proto.define'
+local await = require 'await'
+
+local function countLibraryArgs(source)
+ local func = vm.getLibrary(source)
+ if not func then
+ return nil
+ end
+ local result = 0
+ if not func.args then
+ return result
+ end
+ if func.args[#func.args].type == '...' then
+ return math.maxinteger
+ end
+ result = result + #func.args
+ return result
+end
+
+local function countCallArgs(source)
+ local result = 0
+ if not source.args then
+ return 0
+ end
+ if source.node and source.node.type == 'getmethod' then
+ result = result + 1
+ end
+ result = result + #source.args
+ return result
+end
+
+local function countFuncArgs(source)
+ local result = 0
+ if not source.args then
+ return result
+ end
+ if source.args[#source.args].type == '...' then
+ return math.maxinteger
+ end
+ if source.parent and source.parent.type == 'setmethod' then
+ result = result + 1
+ end
+ result = result + #source.args
+ return result
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'call', function (source)
+ local callArgs = countCallArgs(source)
+ if callArgs == 0 then
+ return
+ end
+
+ await.delay(function ()
+ return files.globalVersion
+ end)
+
+ local func = source.node
+ local funcArgs
+ vm.eachDef(func, function (info)
+ if info.mode == 'value' then
+ local src = info.source
+ if src.type == 'function' then
+ local args = countFuncArgs(src)
+ if not funcArgs or args > funcArgs then
+ funcArgs = args
+ end
+ end
+ end
+ end)
+
+ funcArgs = funcArgs or countLibraryArgs(func)
+ if not funcArgs then
+ return
+ end
+
+ local delta = callArgs - funcArgs
+ if delta <= 0 then
+ return
+ end
+ for i = #source.args - delta + 1, #source.args do
+ local arg = source.args[i]
+ if arg then
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
+ }
+ end
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/redundant-value.lua b/script-beta/src/core/diagnostics/redundant-value.lua
new file mode 100644
index 00000000..be483448
--- /dev/null
+++ b/script-beta/src/core/diagnostics/redundant-value.lua
@@ -0,0 +1,24 @@
+local files = require 'files'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback, code)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local diags = ast.diags[code]
+ if not diags then
+ return
+ end
+
+ for _, info in ipairs(diags) do
+ callback {
+ start = info.start,
+ finish = info.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_OVER_MAX_VALUES', info.max, info.passed)
+ }
+ end
+end
diff --git a/script-beta/src/core/diagnostics/trailing-space.lua b/script-beta/src/core/diagnostics/trailing-space.lua
new file mode 100644
index 00000000..e54a6e60
--- /dev/null
+++ b/script-beta/src/core/diagnostics/trailing-space.lua
@@ -0,0 +1,55 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+
+local function isInString(ast, offset)
+ local result = false
+ guide.eachSourceType(ast, 'string', function (source)
+ if offset >= source.start and offset <= source.finish then
+ result = true
+ end
+ end)
+ return result
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ for i = 1, #lines do
+ local start = lines[i].start
+ local range = lines[i].range
+ local lastChar = text:sub(range, range)
+ if lastChar ~= ' ' and lastChar ~= '\t' then
+ goto NEXT_LINE
+ end
+ if isInString(ast.ast, range) then
+ goto NEXT_LINE
+ end
+ local first = start
+ for n = range - 1, start, -1 do
+ local char = text:sub(n, n)
+ if char ~= ' ' and char ~= '\t' then
+ first = n + 1
+ break
+ end
+ end
+ if first == start then
+ callback {
+ start = first,
+ finish = range,
+ message = lang.script.DIAG_LINE_ONLY_SPACE,
+ }
+ else
+ callback {
+ start = first,
+ finish = range,
+ message = lang.script.DIAG_LINE_POST_SPACE,
+ }
+ end
+ ::NEXT_LINE::
+ end
+end
diff --git a/script-beta/src/core/diagnostics/undefined-env-child.lua b/script-beta/src/core/diagnostics/undefined-env-child.lua
new file mode 100644
index 00000000..df096cb8
--- /dev/null
+++ b/script-beta/src/core/diagnostics/undefined-env-child.lua
@@ -0,0 +1,32 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ -- 再遍历一次 getglobal ,找出 _ENV 被重载的情况
+ guide.eachSourceType(ast.ast, 'getglobal', function (source)
+ -- 单独验证自己是否在重载过的 _ENV 中有定义
+ if source.node.tag == '_ENV' then
+ return
+ end
+ local setInENV = vm.eachRef(source, function (info)
+ if info.mode == 'set' then
+ return true
+ end
+ end)
+ if setInENV then
+ return
+ end
+ local key = source[1]
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_UNDEF_ENV_CHILD', key),
+ }
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/undefined-global.lua b/script-beta/src/core/diagnostics/undefined-global.lua
new file mode 100644
index 00000000..ed81ced3
--- /dev/null
+++ b/script-beta/src/core/diagnostics/undefined-global.lua
@@ -0,0 +1,63 @@
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local library = require 'library'
+local config = require 'config'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local globalCache = {}
+
+ -- 遍历全局变量,检查所有没有 mode['set'] 的全局变量
+ local globals = vm.getGlobals(ast.ast)
+ for key, infos in pairs(globals) do
+ if infos.mode['set'] == true then
+ goto CONTINUE
+ end
+ if globalCache[key] then
+ goto CONTINUE
+ end
+ local skey = key and key:match '^s|(.+)$'
+ if not skey then
+ goto CONTINUE
+ end
+ if library.global[skey] then
+ goto CONTINUE
+ end
+ if config.config.diagnostics.globals[skey] then
+ goto CONTINUE
+ end
+ if globalCache[key] == nil then
+ local uris = files.findGlobals(key)
+ for i = 1, #uris do
+ local destAst = files.getAst(uris[i])
+ local destGlobals = vm.getGlobals(destAst.ast)
+ if destGlobals[key] and destGlobals[key].mode['set'] then
+ globalCache[key] = true
+ goto CONTINUE
+ end
+ end
+ end
+ globalCache[key] = false
+ local message = lang.script('DIAG_UNDEF_GLOBAL', skey)
+ local otherVersion = library.other[skey]
+ local customVersion = library.custom[skey]
+ if otherVersion then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version))
+ elseif customVersion then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customVersion, '/')))
+ end
+ for _, info in ipairs(infos) do
+ callback {
+ start = info.source.start,
+ finish = info.source.finish,
+ message = message,
+ }
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script-beta/src/core/diagnostics/unused-function.lua b/script-beta/src/core/diagnostics/unused-function.lua
new file mode 100644
index 00000000..6c53cdf7
--- /dev/null
+++ b/script-beta/src/core/diagnostics/unused-function.lua
@@ -0,0 +1,45 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local define = require 'proto.define'
+local lang = require 'language'
+local await = require 'await'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ -- 只检查局部函数与全局函数
+ guide.eachSourceType(ast.ast, 'function', function (source)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type ~= 'local'
+ and parent.type ~= 'setlocal'
+ and parent.type ~= 'setglobal' then
+ return
+ end
+ local hasSet
+ local hasGet = vm.eachRef(source, function (info)
+ if info.mode == 'get' then
+ return true
+ elseif info.mode == 'set'
+ or info.mode == 'declare' then
+ hasSet = true
+ end
+ end)
+ if not hasGet and hasSet then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_FUNCTION,
+ }
+ end
+ await.delay(function ()
+ return files.globalVersion
+ end)
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/unused-label.lua b/script-beta/src/core/diagnostics/unused-label.lua
new file mode 100644
index 00000000..e6d998ba
--- /dev/null
+++ b/script-beta/src/core/diagnostics/unused-label.lua
@@ -0,0 +1,22 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'label', function (source)
+ if not source.ref then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LABEL', source[1]),
+ }
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/unused-local.lua b/script-beta/src/core/diagnostics/unused-local.lua
new file mode 100644
index 00000000..22b2e16b
--- /dev/null
+++ b/script-beta/src/core/diagnostics/unused-local.lua
@@ -0,0 +1,46 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+local function hasGet(loc)
+ if not loc.ref then
+ return false
+ end
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'getlocal' then
+ if not ref.next then
+ return true
+ end
+ local nextType = ref.next.type
+ if nextType ~= 'setmethod'
+ and nextType ~= 'setfield'
+ and nextType ~= 'setindex' then
+ return true
+ end
+ end
+ end
+ return false
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'local', function (source)
+ local name = source[1]
+ if name == '_'
+ or name == '_ENV' then
+ return
+ end
+ if not hasGet(source) then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LOCAL', name),
+ }
+ end
+ end)
+end
diff --git a/script-beta/src/core/diagnostics/unused-vararg.lua b/script-beta/src/core/diagnostics/unused-vararg.lua
new file mode 100644
index 00000000..74cc08e7
--- /dev/null
+++ b/script-beta/src/core/diagnostics/unused-vararg.lua
@@ -0,0 +1,31 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'function', function (source)
+ local args = source.args
+ if not args then
+ return
+ end
+
+ for _, arg in ipairs(args) do
+ if arg.type == '...' then
+ if not arg.ref then
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_VARARG,
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script-beta/src/core/highlight.lua b/script-beta/src/core/highlight.lua
new file mode 100644
index 00000000..61e3f91a
--- /dev/null
+++ b/script-beta/src/core/highlight.lua
@@ -0,0 +1,230 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm'
+local define = require 'proto.define'
+
+local function ofLocal(source, callback)
+ callback(source)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback(ref)
+ end
+ end
+end
+
+local function ofField(source, uri, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ local myKey = guide.getKeyName(source)
+ if parent.type == 'tableindex'
+ or parent.type == 'tablefield' then
+ local tbl = parent.parent
+ vm.eachField(tbl, function (info)
+ if info.key ~= myKey then
+ return
+ end
+ local destUri = guide.getRoot(info.source).uri
+ if destUri ~= uri then
+ return
+ end
+ callback(info.source)
+ end)
+ else
+ vm.eachField(parent.node, function (info)
+ if info.key ~= myKey then
+ return
+ end
+ local destUri = guide.getRoot(info.source).uri
+ if destUri ~= uri then
+ return
+ end
+ callback(info.source)
+ end)
+ end
+end
+
+local function ofIndex(source, uri, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ ofField(source, uri, callback)
+ end
+end
+
+local function ofLabel(source, callback)
+ vm.eachRef(source, function (info)
+ callback(info.source)
+ end)
+end
+
+local function find(source, uri, callback)
+ if source.type == 'local' then
+ ofLocal(source, callback)
+ elseif source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ ofLocal(source.node, callback)
+ elseif source.type == 'field'
+ or source.type == 'method' then
+ ofField(source, uri, callback)
+ elseif source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number' then
+ ofIndex(source, uri, callback)
+ callback(source)
+ elseif source.type == 'nil' then
+ callback(source)
+ elseif source.type == 'goto'
+ or source.type == 'label' then
+ ofLabel(source, callback)
+ end
+end
+
+local function checkInIf(source, text, offset)
+ -- 检查 end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if offset >= endA
+ and offset <= endB
+ and text:sub(endA, endB) == 'end' then
+ return true
+ end
+ -- 检查每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ if offset >= start and offset <= finish then
+ return true
+ end
+ end
+ end
+ return false
+end
+
+local function makeIf(source, text, callback)
+ -- end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if text:sub(endA, endB) == 'end' then
+ callback(endA, endB)
+ end
+ -- 每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ return false
+end
+
+local function findKeyword(source, text, offset, callback)
+ if source.type == 'do'
+ or source.type == 'function'
+ or source.type == 'loop'
+ or source.type == 'in'
+ or source.type == 'while'
+ or source.type == 'repeat' then
+ local ok
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ if offset >= start and offset <= finish then
+ ok = true
+ break
+ end
+ end
+ if ok then
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ elseif source.type == 'if' then
+ local ok = checkInIf(source, text, offset)
+ if ok then
+ makeIf(source, text, callback)
+ end
+ end
+end
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local text = files.getText(uri)
+ local results = {}
+ local mark = {}
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ find(source, uri, function (target)
+ local kind
+ if target.type == 'getfield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setfield'
+ or target.type == 'tablefield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setindex'
+ or target.type == 'tableindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getlocal'
+ or target.type == 'getglobal'
+ or target.type == 'goto' then
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setlocal'
+ or target.type == 'local'
+ or target.type == 'setglobal'
+ or target.type == 'label' then
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'string'
+ or target.type == 'boolean'
+ or target.type == 'number'
+ or target.type == 'nil' then
+ kind = define.DocumentHighlightKind.Text
+ else
+ log.warn('Unknow target.type:', target.type)
+ return
+ end
+ if mark[target] then
+ return
+ end
+ mark[target] = true
+ results[#results+1] = {
+ start = target.start,
+ finish = target.finish,
+ kind = kind,
+ }
+ end)
+ findKeyword(source, text, offset, function (start, finish)
+ results[#results+1] = {
+ start = start,
+ finish = finish,
+ kind = define.DocumentHighlightKind.Write
+ }
+ end)
+ end)
+ if #results == 0 then
+ return nil
+ end
+ return results
+end
diff --git a/script-beta/src/core/hover/arg.lua b/script-beta/src/core/hover/arg.lua
new file mode 100644
index 00000000..be344488
--- /dev/null
+++ b/script-beta/src/core/hover/arg.lua
@@ -0,0 +1,20 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local function asFunction(source)
+ if not source.args then
+ return ''
+ end
+ local args = {}
+ for i = 1, #source.args do
+ local arg = source.args[i]
+ args[i] = ('%s: %s'):format(guide.getName(arg), vm.getType(arg))
+ end
+ return table.concat(args, ', ')
+end
+
+return function (source)
+ if source.type == 'function' then
+ return asFunction(source)
+ end
+end
diff --git a/script-beta/src/core/hover/init.lua b/script-beta/src/core/hover/init.lua
new file mode 100644
index 00000000..b99c14b2
--- /dev/null
+++ b/script-beta/src/core/hover/init.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local getLabel = require 'core.hover.label'
+
+local function getHoverAsFunction(source)
+ local values = vm.getValue(source)
+ local labels = {}
+ for _, value in ipairs(values) do
+ if value.type == 'function' then
+ labels[#labels+1] = getLabel(value.source)
+ end
+ end
+
+ local label = table.concat(labels, '\n')
+ return {
+ label = label,
+ source = source,
+ }
+end
+
+local function getHoverAsValue(source)
+ local label = getLabel(source)
+ return {
+ label = label,
+ source = source,
+ }
+end
+
+local function getHover(source)
+ local isFunction = vm.hasType(source, 'function')
+ if isFunction then
+ return getHoverAsFunction(source)
+ else
+ return getHoverAsValue(source)
+ end
+end
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local hover = guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal'
+ or source.type == 'field'
+ or source.type == 'method' then
+ return getHover(source)
+ end
+ end)
+ return hover
+end
diff --git a/script-beta/src/core/hover/label.lua b/script-beta/src/core/hover/label.lua
new file mode 100644
index 00000000..72ce60f4
--- /dev/null
+++ b/script-beta/src/core/hover/label.lua
@@ -0,0 +1,103 @@
+local buildName = require 'core.hover.name'
+local buildArg = require 'core.hover.arg'
+local buildReturn = require 'core.hover.return'
+local buildTable = require 'core.hover.table'
+local vm = require 'vm'
+local util = require 'utility'
+
+local function asFunction(source)
+ local name = buildName(source)
+ local arg = buildArg(source)
+ local rtn = buildReturn(source)
+ local lines = {}
+ lines[1] = ('function %s(%s)'):format(name, arg)
+ lines[2] = rtn
+ return table.concat(lines, '\n')
+end
+
+local function asLocal(source)
+ local name = buildName(source)
+ local type = vm.getType(source)
+ local literal = vm.getLiteral(source)
+ if type == 'table' then
+ type = buildTable(source)
+ end
+ if literal == nil then
+ return ('local %s: %s'):format(name, type)
+ else
+ return ('local %s: %s = %s'):format(name, type, util.viewLiteral(literal))
+ end
+end
+
+local function asGlobal(source)
+ local name = buildName(source)
+ local type = vm.getType(source)
+ local literal = vm.getLiteral(source)
+ if type == 'table' then
+ type = buildTable(source)
+ end
+ if literal == nil then
+ return ('global %s: %s'):format(name, type)
+ else
+ return ('global %s: %s = %s'):format(name, type, util.viewLiteral(literal))
+ end
+end
+
+local function isGlobalField(source)
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ if source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod'
+ or source.type == 'tablefield' then
+ local node = source.node
+ if node.type == 'setglobal'
+ or node.type == 'getglobal' then
+ return true
+ end
+ return isGlobalField(node)
+ else
+ return false
+ end
+end
+
+local function asField(source)
+ if isGlobalField(source) then
+ return asGlobal(source)
+ end
+ local name = buildName(source)
+ local type = vm.getType(source)
+ local literal = vm.getLiteral(source)
+ if type == 'table' then
+ type = buildTable(source)
+ end
+ if literal == nil then
+ return ('field %s: %s'):format(name, type)
+ else
+ return ('field %s: %s = %s'):format(name, type, util.viewLiteral(literal))
+ end
+end
+
+return function (source)
+ if source.type == 'function' then
+ return asFunction(source)
+ elseif source.type == 'local'
+ or source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ return asLocal(source)
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return asGlobal(source)
+ elseif source.type == 'getfield'
+ or source.type == 'setfield'
+ or source.type == 'getmethod'
+ or source.type == 'setmethod'
+ or source.type == 'tablefield'
+ or source.type == 'field'
+ or source.type == 'method' then
+ return asField(source)
+ end
+end
diff --git a/script-beta/src/core/hover/name.lua b/script-beta/src/core/hover/name.lua
new file mode 100644
index 00000000..a22a8b5a
--- /dev/null
+++ b/script-beta/src/core/hover/name.lua
@@ -0,0 +1,64 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local function asLocal(source)
+ return guide.getName(source)
+end
+
+local function asMethod(source)
+ local class = vm.eachField(source.node, function (info)
+ if info.key == 's|type' or info.key == 's|__name' or info.key == 's|name' then
+ if info.value and info.value.type == 'string' then
+ return info.value[1]
+ end
+ end
+ end)
+ local node = class or guide.getName(source.node) or '?'
+ local method = guide.getName(source)
+ return ('%s:%s'):format(node, method)
+end
+
+local function asField(source)
+ local class = vm.eachField(source.node, function (info)
+ if info.key == 's|type' or info.key == 's|__name' or info.key == 's|name' then
+ if info.value and info.value.type == 'string' then
+ return info.value[1]
+ end
+ end
+ end)
+ local node = class or guide.getName(source.node) or '?'
+ local method = guide.getName(source)
+ return ('%s.%s'):format(node, method)
+end
+
+local function asGlobal(source)
+ return guide.getName(source)
+end
+
+local function buildName(source)
+ if source.type == 'local'
+ or source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ return asLocal(source) or ''
+ end
+ if source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return asGlobal(source) or ''
+ end
+ if source.type == 'setmethod'
+ or source.type == 'getmethod' then
+ return asMethod(source) or ''
+ end
+ if source.type == 'setfield'
+ or source.tyoe == 'getfield'
+ or source.type == 'tablefield' then
+ return asField(source) or ''
+ end
+ local parent = source.parent
+ if parent then
+ return buildName(parent)
+ end
+ return ''
+end
+
+return buildName
diff --git a/script-beta/src/core/hover/return.lua b/script-beta/src/core/hover/return.lua
new file mode 100644
index 00000000..c22626a6
--- /dev/null
+++ b/script-beta/src/core/hover/return.lua
@@ -0,0 +1,34 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local function asFunction(source)
+ if not source.returns then
+ return nil
+ end
+ local returns = {}
+ for _, rtn in ipairs(source.returns) do
+ for i = 1, #rtn do
+ local values = vm.getValue(rtn[i])
+ returns[#returns+1] = values
+ end
+ break
+ end
+ if #returns == 0 then
+ return nil
+ end
+ local lines = {}
+ for i = 1, #returns do
+ if i == 1 then
+ lines[i] = (' -> %s'):format(vm.viewType(returns[i]))
+ else
+ lines[i] = ('% 3d. %s'):format(i, returns[i])
+ end
+ end
+ return table.concat(lines, '\n')
+end
+
+return function (source)
+ if source.type == 'function' then
+ return asFunction(source)
+ end
+end
diff --git a/script-beta/src/core/hover/table.lua b/script-beta/src/core/hover/table.lua
new file mode 100644
index 00000000..9ed86692
--- /dev/null
+++ b/script-beta/src/core/hover/table.lua
@@ -0,0 +1,35 @@
+local vm = require 'vm'
+
+local function checkClass(source)
+end
+
+return function (source)
+ local fields = {}
+ local class
+ vm.eachField(source, function (info)
+ if info.key == 's|type' or info.key == 's|__name' or info.key == 's|name' then
+ if info.value and info.value.type == 'string' then
+ class = info.value[1]
+ end
+ end
+ local type = vm.getType(info.source)
+ fields[#fields+1] = ('%s'):format(type)
+ end)
+ local fieldsBuf
+ if #fields == 0 then
+ fieldsBuf = '{}'
+ else
+ local lines = {}
+ lines[#lines+1] = '{'
+ for _, field in ipairs(fields) do
+ lines[#lines+1] = ' ' .. field
+ end
+ lines[#lines+1] = '}'
+ fieldsBuf = table.concat(lines, '\n')
+ end
+ if class then
+ return ('%s %s'):format(class, fieldsBuf)
+ else
+ return fieldsBuf
+ end
+end
diff --git a/script-beta/src/core/reference.lua b/script-beta/src/core/reference.lua
new file mode 100644
index 00000000..7e265e97
--- /dev/null
+++ b/script-beta/src/core/reference.lua
@@ -0,0 +1,84 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm'
+
+local function isFunction(source, offset)
+ if source.type ~= 'function' then
+ return false
+ end
+ -- 必须点在 `function` 这个单词上才能查找函数引用
+ return offset >= source.start and offset < source.start + #'function'
+end
+
+local function findRef(source, offset, callback)
+ if source.type ~= 'local'
+ and source.type ~= 'getlocal'
+ and source.type ~= 'setlocal'
+ and source.type ~= 'setglobal'
+ and source.type ~= 'getglobal'
+ and source.type ~= 'field'
+ and source.type ~= 'tablefield'
+ and source.type ~= 'method'
+ and source.type ~= 'string'
+ and source.type ~= 'number'
+ and source.type ~= 'boolean'
+ and source.type ~= 'goto'
+ and source.type ~= 'label'
+ and not isFunction(source, offset) then
+ return
+ end
+ vm.eachRef(source, function (info)
+ if info.mode == 'declare'
+ or info.mode == 'set'
+ or info.mode == 'get'
+ or info.mode == 'return' then
+ local src = info.source
+ local root = guide.getRoot(src)
+ local uri = root.uri
+ if src.type == 'setfield'
+ or src.type == 'getfield'
+ or src.type == 'tablefield' then
+ callback(src.field, uri)
+ elseif src.type == 'setindex'
+ or src.type == 'getindex'
+ or src.type == 'tableindex' then
+ callback(src.index, uri)
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ callback(src.method, uri)
+ else
+ callback(src, uri)
+ end
+ end
+ if info.mode == 'value' then
+ local src = info.source
+ local root = guide.getRoot(src)
+ local uri = root.uri
+ if src.type == 'function' then
+ if src.parent.type == 'return' then
+ callback(src, uri)
+ end
+ end
+ end
+ end)
+end
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local results = {}
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ findRef(source, offset, function (target, uri)
+ results[#results+1] = {
+ target = target,
+ uri = files.getOriginUri(uri),
+ }
+ end)
+ end)
+ if #results == 0 then
+ return nil
+ end
+ return results
+end
diff --git a/script-beta/src/core/rename.lua b/script-beta/src/core/rename.lua
new file mode 100644
index 00000000..3e4512da
--- /dev/null
+++ b/script-beta/src/core/rename.lua
@@ -0,0 +1,374 @@
+local files = require 'files'
+local vm = require 'vm'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local define = require 'proto.define'
+local util = require 'utility'
+
+local Forcing
+
+local function askForcing(str)
+ if TEST then
+ return true
+ end
+ if Forcing == false then
+ return false
+ end
+ local version = files.globalVersion
+ -- TODO
+ local item = proto.awaitRequest('window/showMessageRequest', {
+ type = define.MessageType.Warning,
+ message = ('[%s]不是有效的标识符,是否强制替换?'):format(str),
+ actions = {
+ {
+ title = '强制替换',
+ },
+ {
+ title = '取消',
+ },
+ }
+ })
+ if version ~= files.globalVersion then
+ Forcing = false
+ proto.notify('window/showMessage', {
+ type = define.MessageType.Warning,
+ message = '文件发生了变化,替换取消。'
+ })
+ return false
+ end
+ if not item then
+ Forcing = false
+ return false
+ end
+ if item.title == '强制替换' then
+ Forcing = true
+ return true
+ else
+ Forcing = false
+ return false
+ end
+end
+
+local function askForMultiChange(results, newname)
+ if TEST then
+ return true
+ end
+ local uris = {}
+ for _, result in ipairs(results) do
+ local uri = result.uri
+ if not uris[uri] then
+ uris[uri] = 0
+ uris[#uris+1] = uri
+ end
+ uris[uri] = uris[uri] + 1
+ end
+ if #uris <= 1 then
+ return true
+ end
+
+ local version = files.globalVersion
+ -- TODO
+ local item = proto.awaitRequest('window/showMessageRequest', {
+ type = define.MessageType.Warning,
+ message = ('将修改 %d 个文件,共 %d 处。'):format(
+ #uris,
+ #results
+ ),
+ actions = {
+ {
+ title = '继续',
+ },
+ {
+ title = '放弃',
+ },
+ }
+ })
+ if version ~= files.globalVersion then
+ proto.notify('window/showMessage', {
+ type = define.MessageType.Warning,
+ message = '文件发生了变化,替换取消。'
+ })
+ return false
+ end
+ if item and item.title == '继续' then
+ local fileList = {}
+ for _, uri in ipairs(uris) do
+ fileList[#fileList+1] = ('%s (%d)'):format(uri, uris[uri])
+ end
+
+ log.debug(('Renamed [%s]\r\n%s'):format(newname, table.concat(fileList, '\r\n')))
+ return true
+ end
+ return false
+end
+
+local function trim(str)
+ return str:match '^%s*(%S+)%s*$'
+end
+
+local function isValidName(str)
+ return str:match '^[%a_][%w_]*$'
+end
+
+local function isValidGlobal(str)
+ for s in str:gmatch '[^%.]*' do
+ if not isValidName(trim(s)) then
+ return false
+ end
+ end
+ return true
+end
+
+local function isValidFunctionName(str)
+ if isValidGlobal(str) then
+ return true
+ end
+ local pos = str:find(':', 1, true)
+ if not pos then
+ return false
+ end
+ return isValidGlobal(trim(str:sub(1, pos-1)))
+ and isValidName(trim(str:sub(pos+1)))
+end
+
+local function isFunctionGlobalName(source)
+ local parent = source.parent
+ if parent.type ~= 'setglobal' then
+ return false
+ end
+ local value = parent.value
+ if not value.type ~= 'function' then
+ return false
+ end
+ return value.start <= parent.start
+end
+
+local function renameLocal(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return
+ end
+ if askForcing(newname) then
+ callback(source, source.start, source.finish, newname)
+ end
+end
+
+local function renameField(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ local parent = source.parent
+ if parent.type == 'setfield'
+ or parent.type == 'getfield' then
+ local dot = parent.dot
+ local newstr = '[' .. util.viewString(newname) .. ']'
+ callback(source, dot.start, source.finish, newstr)
+ elseif parent.type == 'tablefield' then
+ local newstr = '[' .. util.viewString(newname) .. ']'
+ callback(source, source.start, source.finish, newstr)
+ elseif parent.type == 'getmethod' then
+ if not askForcing(newname) then
+ return false
+ end
+ callback(source, source.start, source.finish, newname)
+ elseif parent.type == 'setmethod' then
+ local uri = guide.getRoot(source).uri
+ local text = files.getText(uri)
+ local func = parent.value
+ -- function mt:name () end --> mt['newname'] = function (self) end
+ local newstr = string.format('%s[%s] = function '
+ , text:sub(parent.start, parent.node.finish)
+ , util.viewString(newname)
+ )
+ callback(source, func.start, parent.finish, newstr)
+ local pl = text:find('(', parent.finish, true)
+ if pl then
+ if func.args then
+ callback(source, pl + 1, pl, 'self, ')
+ else
+ callback(source, pl + 1, pl, 'self')
+ end
+ end
+ end
+ return true
+end
+
+local function renameGlobal(source, newname, callback)
+ if isValidGlobal(newname) then
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ if isValidFunctionName(newname) then
+ if not isFunctionGlobalName(source) then
+ askForcing(newname)
+ end
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ local newstr = '_ENV[' .. util.viewString(newname) .. ']'
+ -- function name () end --> _ENV['newname'] = function () end
+ if source.value and source.value.type == 'function'
+ and source.value.start < source.start then
+ callback(source, source.value.start, source.finish, newstr .. ' = function ')
+ return true
+ end
+ callback(source, source.start, source.finish, newstr)
+ return true
+end
+
+local function ofLocal(source, newname, callback)
+ renameLocal(source, newname, callback)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ renameLocal(ref, newname, callback)
+ end
+ end
+end
+
+local function ofField(source, newname, callback)
+ return vm.eachRef(source, function (info)
+ local src = info.source
+ if src.type == 'tablefield'
+ or src.type == 'getfield'
+ or src.type == 'setfield' then
+ src = src.field
+ elseif src.type == 'tableindex'
+ or src.type == 'getindex'
+ or src.type == 'setindex' then
+ src = src.index
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ src = src.method
+ end
+ if src.type == 'string' then
+ local quo = src[2]
+ local text = util.viewString(newname, quo)
+ callback(src, src.start, src.finish, text)
+ return
+ elseif src.type == 'field'
+ or src.type == 'method' then
+ local suc = renameField(src, newname, callback)
+ if not suc then
+ return false
+ end
+ elseif src.type == 'setglobal'
+ or src.type == 'getglobal' then
+ local suc = renameGlobal(src, newname, callback)
+ if not suc then
+ return false
+ end
+ end
+ end)
+end
+
+local function rename(source, newname, callback)
+ if source.type == 'label'
+ or source.type == 'goto' then
+ if not isValidName(newname) and not askForcing(newname)then
+ return false
+ end
+ vm.eachRef(source, function (info)
+ callback(info.source, info.source.start, info.source.finish, newname)
+ end)
+ elseif source.type == 'local' then
+ return ofLocal(source, newname, callback)
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ return ofLocal(source.node, newname, callback)
+ elseif source.type == 'field'
+ or source.type == 'method'
+ or source.type == 'tablefield'
+ or source.type == 'string'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return ofField(source, newname, callback)
+ end
+ return true
+end
+
+local function prepareRename(source)
+ if source.type == 'label'
+ or source.type == 'goto'
+ or source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'field'
+ or source.type == 'method'
+ or source.type == 'tablefield'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return source, source[1]
+ elseif source.type == 'string' then
+ local parent = source.parent
+ if not parent then
+ return nil
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return source, source[1]
+ end
+ return nil
+ end
+ return nil
+end
+
+local m = {}
+
+function m.rename(uri, pos, newname)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local results = {}
+
+ guide.eachSourceContain(ast.ast, pos, function(source)
+ rename(source, newname, function (target, start, finish, text)
+ results[#results+1] = {
+ start = start,
+ finish = finish,
+ text = text,
+ uri = guide.getRoot(target).uri,
+ }
+ end)
+ end)
+
+ if Forcing == false then
+ Forcing = nil
+ return nil
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ if not askForMultiChange(results, newname) then
+ return nil
+ end
+
+ return results
+end
+
+function m.prepareRename(uri, pos)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local result
+ guide.eachSourceContain(ast.ast, pos, function(source)
+ local res, text = prepareRename(source)
+ if res then
+ result = {
+ start = source.start,
+ finish = source.finish,
+ text = text,
+ }
+ end
+ end)
+
+ return result
+end
+
+return m