summaryrefslogtreecommitdiff
path: root/script/core
diff options
context:
space:
mode:
Diffstat (limited to 'script/core')
-rw-r--r--script/core/code-action.lua269
-rw-r--r--script/core/command/removeSpace.lua56
-rw-r--r--script/core/command/solve.lua96
-rw-r--r--script/core/completion.lua1284
-rw-r--r--script/core/definition.lua156
-rw-r--r--script/core/diagnostics/ambiguity-1.lua69
-rw-r--r--script/core/diagnostics/circle-doc-class.lua54
-rw-r--r--script/core/diagnostics/code-after-break.lua34
-rw-r--r--script/core/diagnostics/doc-field-no-class.lua41
-rw-r--r--script/core/diagnostics/duplicate-doc-class.lua46
-rw-r--r--script/core/diagnostics/duplicate-doc-field.lua34
-rw-r--r--script/core/diagnostics/duplicate-doc-param.lua37
-rw-r--r--script/core/diagnostics/duplicate-index.lua63
-rw-r--r--script/core/diagnostics/empty-block.lua49
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua66
-rw-r--r--script/core/diagnostics/init.lua56
-rw-r--r--script/core/diagnostics/lowercase-global.lua56
-rw-r--r--script/core/diagnostics/newfield-call.lua37
-rw-r--r--script/core/diagnostics/newline-call.lua38
-rw-r--r--script/core/diagnostics/redefined-local.lua32
-rw-r--r--script/core/diagnostics/redundant-parameter.lua82
-rw-r--r--script/core/diagnostics/redundant-value.lua24
-rw-r--r--script/core/diagnostics/trailing-space.lua55
-rw-r--r--script/core/diagnostics/undefined-doc-class.lua46
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua60
-rw-r--r--script/core/diagnostics/undefined-doc-param.lua52
-rw-r--r--script/core/diagnostics/undefined-env-child.lua27
-rw-r--r--script/core/diagnostics/undefined-global.lua40
-rw-r--r--script/core/diagnostics/unused-function.lua40
-rw-r--r--script/core/diagnostics/unused-label.lua22
-rw-r--r--script/core/diagnostics/unused-local.lua93
-rw-r--r--script/core/diagnostics/unused-vararg.lua31
-rw-r--r--script/core/document-symbol.lua307
-rw-r--r--script/core/find-source.lua14
-rw-r--r--script/core/highlight.lua252
-rw-r--r--script/core/hover/arg.lua71
-rw-r--r--script/core/hover/description.lua204
-rw-r--r--script/core/hover/init.lua164
-rw-r--r--script/core/hover/label.lua211
-rw-r--r--script/core/hover/name.lua101
-rw-r--r--script/core/hover/return.lua125
-rw-r--r--script/core/hover/table.lua257
-rw-r--r--script/core/keyword.lua264
-rw-r--r--script/core/matchkey.lua33
-rw-r--r--script/core/reference.lua116
-rw-r--r--script/core/rename.lua448
-rw-r--r--script/core/semantic-tokens.lua161
-rw-r--r--script/core/signature.lua106
-rw-r--r--script/core/workspace-symbol.lua69
49 files changed, 6048 insertions, 0 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
new file mode 100644
index 00000000..69304f98
--- /dev/null
+++ b/script/core/code-action.lua
@@ -0,0 +1,269 @@
+local files = require 'files'
+local lang = require 'language'
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+local util = require 'utility'
+local sp = require 'bee.subprocess'
+
+local function disableDiagnostic(uri, code, results)
+ results[#results+1] = {
+ title = lang.script('ACTION_DISABLE_DIAG', code),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_DISABLE_DIAG,
+ command = 'lua.config',
+ arguments = {
+ {
+ key = 'Lua.diagnostics.disable',
+ action = 'add',
+ value = code,
+ uri = uri,
+ }
+ }
+ }
+ }
+end
+
+local function markGlobal(uri, name, results)
+ results[#results+1] = {
+ title = lang.script('ACTION_MARK_GLOBAL', name),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_MARK_GLOBAL,
+ command = 'lua.config',
+ arguments = {
+ {
+ key = 'Lua.diagnostics.globals',
+ action = 'add',
+ value = name,
+ uri = uri,
+ }
+ }
+ }
+ }
+end
+
+local function changeVersion(uri, version, results)
+ results[#results+1] = {
+ title = lang.script('ACTION_RUNTIME_VERSION', version),
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_RUNTIME_VERSION,
+ command = 'lua.config',
+ arguments = {
+ {
+ key = 'Lua.runtime.version',
+ action = 'set',
+ value = version,
+ uri = uri,
+ }
+ }
+ },
+ }
+end
+
+local function solveUndefinedGlobal(uri, diag, results)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ local offset = define.offsetOfWord(lines, text, diag.range.start)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type ~= 'getglobal' then
+ return
+ end
+
+ local name = guide.getName(source)
+ markGlobal(uri, name, results)
+
+ -- TODO check other version
+ end)
+end
+
+local function solveLowercaseGlobal(uri, diag, results)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ local offset = define.offsetOfWord(lines, text, diag.range.start)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type ~= 'setglobal' then
+ return
+ end
+
+ local name = guide.getName(source)
+ markGlobal(uri, name, results)
+ end)
+end
+
+local function findSyntax(uri, diag)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ for _, err in ipairs(ast.errs) do
+ if err.type:lower():gsub('_', '-') == diag.code then
+ local range = define.range(lines, text, err.start, err.finish)
+ if util.equal(range, diag.range) then
+ return err
+ end
+ end
+ end
+ return nil
+end
+
+local function solveSyntaxByChangeVersion(uri, err, results)
+ if type(err.version) == 'table' then
+ for _, version in ipairs(err.version) do
+ changeVersion(uri, version, results)
+ end
+ else
+ changeVersion(uri, err.version, results)
+ end
+end
+
+local function solveSyntaxByAddDoEnd(uri, err, results)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ results[#results+1] = {
+ title = lang.script.ACTION_ADD_DO_END,
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = define.range(lines, text, err.start, err.finish),
+ newText = ('do %s end'):format(text:sub(err.start, err.finish)),
+ },
+ }
+ }
+ }
+ }
+end
+
+local function solveSyntaxByFix(uri, err, results)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ local changes = {}
+ for _, fix in ipairs(err.fix) do
+ changes[#changes+1] = {
+ range = define.range(lines, text, fix.start, fix.finish),
+ newText = fix.text,
+ }
+ end
+ results[#results+1] = {
+ title = lang.script['ACTION_' .. err.fix.title],
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = changes,
+ }
+ }
+ }
+end
+
+local function solveSyntax(uri, diag, results)
+ local err = findSyntax(uri, diag)
+ if not err then
+ return
+ end
+ if err.version then
+ solveSyntaxByChangeVersion(uri, err, results)
+ end
+ if err.type == 'ACTION_AFTER_BREAK' or err.type == 'ACTION_AFTER_RETURN' then
+ solveSyntaxByAddDoEnd(uri, err, results)
+ end
+ if err.fix then
+ solveSyntaxByFix(uri, err, results)
+ end
+end
+
+local function solveNewlineCall(uri, diag, results)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ results[#results+1] = {
+ title = lang.script.ACTION_ADD_SEMICOLON,
+ kind = 'quickfix',
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = {
+ start = diag.range.start,
+ ['end'] = diag.range.start,
+ },
+ newText = ';',
+ }
+ }
+ }
+ }
+ }
+end
+
+local function solveAmbiguity1(uri, diag, results)
+ results[#results+1] = {
+ title = lang.script.ACTION_ADD_BRACKETS,
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_ADD_BRACKETS,
+ command = 'lua.solve:' .. sp:get_id(),
+ arguments = {
+ {
+ name = 'ambiguity-1',
+ uri = uri,
+ range = diag.range,
+ }
+ }
+ },
+ }
+end
+
+local function solveTrailingSpace(uri, diag, results)
+ results[#results+1] = {
+ title = lang.script.ACTION_REMOVE_SPACE,
+ kind = 'quickfix',
+ command = {
+ title = lang.script.COMMAND_REMOVE_SPACE,
+ command = 'lua.removeSpace:' .. sp:get_id(),
+ arguments = {
+ {
+ uri = uri,
+ }
+ }
+ },
+ }
+end
+
+local function solveDiagnostic(uri, diag, results)
+ if diag.source == lang.script.DIAG_SYNTAX_CHECK then
+ solveSyntax(uri, diag, results)
+ return
+ end
+ if not diag.code then
+ return
+ end
+ if diag.code == 'undefined-global' then
+ solveUndefinedGlobal(uri, diag, results)
+ elseif diag.code == 'lowercase-global' then
+ solveLowercaseGlobal(uri, diag, results)
+ elseif diag.code == 'newline-call' then
+ solveNewlineCall(uri, diag, results)
+ elseif diag.code == 'ambiguity-1' then
+ solveAmbiguity1(uri, diag, results)
+ elseif diag.code == 'trailing-space' then
+ solveTrailingSpace(uri, diag, results)
+ end
+ disableDiagnostic(uri, diag.code, results)
+end
+
+return function (uri, range, diagnostics)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local results = {}
+
+ for _, diag in ipairs(diagnostics) do
+ solveDiagnostic(uri, diag, results)
+ end
+
+ return results
+end
diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua
new file mode 100644
index 00000000..e8b09932
--- /dev/null
+++ b/script/core/command/removeSpace.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
+
+local function isInString(ast, offset)
+ return guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'string' then
+ return true
+ end
+ end) or false
+end
+
+return function (data)
+ local uri = data.uri
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local ast = files.getAst(uri)
+ if not lines then
+ return
+ end
+
+ local textEdit = {}
+ for i = 1, #lines do
+ local line = guide.lineContent(lines, text, i, true)
+ local pos = line:find '[ \t]+$'
+ if pos then
+ local start, finish = guide.lineRange(lines, i, true)
+ start = start + pos - 1
+ if isInString(ast, start) then
+ goto NEXT_LINE
+ end
+ textEdit[#textEdit+1] = {
+ range = define.range(lines, text, start, finish),
+ newText = '',
+ }
+ goto NEXT_LINE
+ end
+
+ ::NEXT_LINE::
+ end
+
+ if #textEdit == 0 then
+ return
+ end
+
+ proto.awaitRequest('workspace/applyEdit', {
+ label = lang.script.COMMAND_REMOVE_SPACE,
+ edit = {
+ changes = {
+ [uri] = textEdit,
+ }
+ },
+ })
+end
diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua
new file mode 100644
index 00000000..d3b8f94e
--- /dev/null
+++ b/script/core/command/solve.lua
@@ -0,0 +1,96 @@
+local files = require 'files'
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
+
+local opMap = {
+ ['+'] = true,
+ ['-'] = true,
+ ['*'] = true,
+ ['/'] = true,
+ ['//'] = true,
+ ['^'] = true,
+ ['<<'] = true,
+ ['>>'] = true,
+ ['&'] = true,
+ ['|'] = true,
+ ['~'] = true,
+ ['..'] = true,
+}
+
+local literalMap = {
+ ['number'] = true,
+ ['boolean'] = true,
+ ['string'] = true,
+ ['table'] = true,
+}
+
+return function (data)
+ local uri = data.uri
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local start = define.offsetOfWord(lines, text, data.range.start)
+ local finish = define.offsetOfWord(lines, text, data.range['end'])
+
+ local result = guide.eachSourceContain(ast.ast, start, function (source)
+ if source.start ~= start
+ or source.finish ~= finish then
+ return
+ end
+ if not source.op or source.op.type ~= 'or' then
+ return
+ end
+ local first = source[1]
+ local second = source[2]
+ -- a + b or 0 --> a + (b or 0)
+ do
+ if first.op
+ and opMap[first.op.type]
+ and first.type ~= 'unary'
+ and not second.op
+ and literalMap[second.type] then
+ return {
+ start = source[1][2].start,
+ finish = source[2].finish,
+ }
+ end
+ end
+ -- a or b + c --> (a or b) + c
+ do
+ if second.op
+ and opMap[second.op.type]
+ and second.type ~= 'unary'
+ and not first.op
+ and literalMap[second[1].type] then
+ return {
+ start = source[1].start,
+ finish = source[2][1].finish,
+ }
+ end
+ end
+ end)
+
+ if not result then
+ return
+ end
+
+ proto.awaitRequest('workspace/applyEdit', {
+ label = lang.script.COMMAND_REMOVE_SPACE,
+ edit = {
+ changes = {
+ [uri] = {
+ {
+ range = define.range(lines, text, result.start, result.finish),
+ newText = ('(%s)'):format(text:sub(result.start, result.finish)),
+ }
+ },
+ }
+ },
+ })
+end
diff --git a/script/core/completion.lua b/script/core/completion.lua
new file mode 100644
index 00000000..44874b39
--- /dev/null
+++ b/script/core/completion.lua
@@ -0,0 +1,1284 @@
+local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local matchKey = require 'core.matchkey'
+local vm = require 'vm'
+local getLabel = require 'core.hover.label'
+local getName = require 'core.hover.name'
+local getArg = require 'core.hover.arg'
+local getDesc = require 'core.hover.description'
+local getHover = require 'core.hover'
+local config = require 'config'
+local util = require 'utility'
+local markdown = require 'provider.markdown'
+local findSource = require 'core.find-source'
+local await = require 'await'
+local parser = require 'parser'
+local keyWordMap = require 'core.keyword'
+local workspace = require 'workspace'
+local furi = require 'file-uri'
+local rpath = require 'workspace.require-path'
+local lang = require 'language'
+
+local stackID = 0
+local stacks = {}
+local function stack(callback)
+ stackID = stackID + 1
+ stacks[stackID] = callback
+ return stackID
+end
+
+local function clearStack()
+ stacks = {}
+end
+
+local function resolveStack(id)
+ local callback = stacks[id]
+ if not callback then
+ return nil
+ end
+
+ -- 当进行新的 resolve 时,放弃当前的 resolve
+ await.close('completion.resove')
+ return await.await(callback, 'completion.resove')
+end
+
+local function trim(str)
+ return str:match '^%s*(%S+)%s*$'
+end
+
+local function isSpace(char)
+ if char == ' '
+ or char == '\n'
+ or char == '\r'
+ or char == '\t' then
+ return true
+ end
+ return false
+end
+
+local function skipSpace(text, offset)
+ for i = offset, 1, -1 do
+ local char = text:sub(i, i)
+ if not isSpace(char) then
+ return i
+ end
+ end
+ return 0
+end
+
+local function findWord(text, offset)
+ for i = offset, 1, -1 do
+ if not text:sub(i, i):match '[%w_]' then
+ if i == offset then
+ return nil
+ end
+ return text:sub(i+1, offset), i+1
+ end
+ end
+ return text:sub(1, offset), 1
+end
+
+local function findSymbol(text, offset)
+ for i = offset, 1, -1 do
+ local char = text:sub(i, i)
+ if isSpace(char) then
+ goto CONTINUE
+ end
+ if char == '.'
+ or char == ':'
+ or char == '(' then
+ return char, i
+ else
+ return nil
+ end
+ ::CONTINUE::
+ end
+ return nil
+end
+
+local function findAnyPos(text, offset)
+ for i = offset, 1, -1 do
+ if not isSpace(text:sub(i, i)) then
+ return i
+ end
+ end
+ return nil
+end
+
+local function findParent(ast, text, offset)
+ for i = offset, 1, -1 do
+ local char = text:sub(i, i)
+ if isSpace(char) then
+ goto CONTINUE
+ end
+ local oop
+ if char == '.' then
+ -- `..` 的情况
+ if text:sub(i-1, i-1) == '.' then
+ return nil, nil
+ end
+ oop = false
+ elseif char == ':' then
+ oop = true
+ else
+ return nil, nil
+ end
+ local anyPos = findAnyPos(text, i-1)
+ if not anyPos then
+ return nil, nil
+ end
+ local parent = guide.eachSourceContain(ast.ast, anyPos, function (source)
+ if source.finish == anyPos then
+ return source
+ end
+ end)
+ if parent then
+ return parent, oop
+ end
+ ::CONTINUE::
+ end
+ return nil, nil
+end
+
+local function findParentInStringIndex(ast, text, offset)
+ local near, nearStart
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ local start = guide.getStartFinish(source)
+ if not start then
+ return
+ end
+ if not nearStart or nearStart < start then
+ near = source
+ nearStart = start
+ end
+ end)
+ if not near or near.type ~= 'string' then
+ return
+ end
+ local parent = near.parent
+ if not parent or parent.index ~= near then
+ return
+ end
+ -- index不可能是oop模式
+ return parent.node, false
+end
+
+local function buildFunctionSnip(source, oop)
+ local name = getName(source):gsub('^.-[$.:]', '')
+ local defs = vm.getDefs(source, 'deep')
+ local args = ''
+ for _, def in ipairs(defs) do
+ local defArgs = getArg(def, oop)
+ if defArgs ~= '' then
+ args = defArgs
+ break
+ end
+ end
+ local id = 0
+ args = args:gsub('[^,]+', function (arg)
+ id = id + 1
+ return arg:gsub('^(%s*)(.+)', function (sp, word)
+ return ('%s${%d:%s}'):format(sp, id, word)
+ end)
+ end)
+ return ('%s(%s)'):format(name, args)
+end
+
+local function buildDetail(source)
+ local types = vm.getInferType(source, 'deep')
+ local literals = vm.getInferLiteral(source, 'deep')
+ if literals then
+ return types .. ' = ' .. literals
+ else
+ return types
+ end
+end
+
+local function getSnip(source)
+ local context = config.config.completion.displayContext
+ if context <= 0 then
+ return nil
+ end
+ local defs = vm.getRefs(source, 'deep')
+ for _, def in ipairs(defs) do
+ def = guide.getObjectValue(def) or def
+ if def ~= source and def.type == 'function' then
+ local uri = guide.getUri(def)
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ if not text then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(uri) then
+ goto CONTINUE
+ end
+ local row = guide.positionOf(lines, def.start)
+ local firstRow = lines[row]
+ local lastRow = lines[math.min(row + context - 1, #lines)]
+ local snip = text:sub(firstRow.start, lastRow.finish)
+ return snip
+ end
+ ::CONTINUE::
+ end
+end
+
+local function buildDesc(source)
+ local hover = getHover.get(source)
+ local md = markdown()
+ md:add('lua', hover.label)
+ md:add('md', hover.description)
+ local snip = getSnip(source)
+ if snip then
+ md:add('md', '-------------')
+ md:add('lua', snip)
+ end
+ return md:string()
+end
+
+local function buildFunction(results, source, oop, data)
+ local snipType = config.config.completion.callSnippet
+ if snipType == 'Disable' or snipType == 'Both' then
+ results[#results+1] = data
+ end
+ if snipType == 'Both' or snipType == 'Replace' then
+ local snipData = util.deepCopy(data)
+ snipData.kind = define.CompletionItemKind.Snippet
+ snipData.label = snipData.label .. '()'
+ snipData.insertText = buildFunctionSnip(source, oop)
+ snipData.insertTextFormat = 2
+ snipData.id = stack(function ()
+ return {
+ detail = buildDetail(source),
+ description = buildDesc(source),
+ }
+ end)
+ results[#results+1] = snipData
+ end
+end
+
+local function isSameSource(ast, source, pos)
+ if not files.eq(guide.getUri(source), guide.getUri(ast.ast)) then
+ return false
+ end
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ return source.start <= pos and source.finish >= pos
+end
+
+local function checkLocal(ast, word, offset, results)
+ local locals = guide.getVisibleLocals(ast.ast, offset)
+ for name, source in pairs(locals) do
+ if isSameSource(ast, source, offset) then
+ goto CONTINUE
+ end
+ if not matchKey(word, name) then
+ goto CONTINUE
+ end
+ if vm.hasType(source, 'function') then
+ buildFunction(results, source, false, {
+ label = name,
+ kind = define.CompletionItemKind.Function,
+ id = stack(function ()
+ return {
+ detail = buildDetail(source),
+ description = buildDesc(source),
+ }
+ end),
+ })
+ else
+ results[#results+1] = {
+ label = name,
+ kind = define.CompletionItemKind.Variable,
+ id = stack(function ()
+ return {
+ detail = buildDetail(source),
+ description = buildDesc(source),
+ }
+ end),
+ }
+ end
+ ::CONTINUE::
+ end
+end
+
+local function checkFieldFromFieldToIndex(name, parent, word, start, offset)
+ if name:match '^[%a_][%w_]*$' then
+ return nil
+ end
+ local textEdit, additionalTextEdits
+ local uri = guide.getUri(parent)
+ local text = files.getText(uri)
+ local wordStart
+ if word == '' then
+ wordStart = text:match('()%S', start + 1) or (offset + 1)
+ else
+ wordStart = offset - #word + 1
+ end
+ textEdit = {
+ start = wordStart,
+ finish = offset,
+ newText = ('[%q]'):format(name),
+ }
+ local nxt = parent.next
+ if nxt then
+ local dotStart
+ if nxt.type == 'setfield'
+ or nxt.type == 'getfield'
+ or nxt.type == 'tablefield' then
+ dotStart = nxt.dot.start
+ elseif nxt.type == 'setmethod'
+ or nxt.type == 'getmethod' then
+ dotStart = nxt.colon.start
+ end
+ if dotStart then
+ additionalTextEdits = {
+ {
+ start = dotStart,
+ finish = dotStart,
+ newText = '',
+ }
+ }
+ end
+ else
+ if config.config.runtime.version == 'Lua 5.1'
+ or config.config.runtime.version == 'LuaJIT' then
+ textEdit.newText = '_G' .. textEdit.newText
+ else
+ textEdit.newText = '_ENV' .. textEdit.newText
+ end
+ end
+ return textEdit, additionalTextEdits
+end
+
+local function checkFieldThen(name, src, word, start, offset, parent, oop, results)
+ local value = guide.getObjectValue(src) or src
+ local kind = define.CompletionItemKind.Field
+ if value.type == 'function' then
+ if oop then
+ kind = define.CompletionItemKind.Method
+ else
+ kind = define.CompletionItemKind.Function
+ end
+ buildFunction(results, src, oop, {
+ label = name,
+ kind = kind,
+ deprecated = vm.isDeprecated(src) or nil,
+ id = stack(function ()
+ return {
+ detail = buildDetail(src),
+ description = buildDesc(src),
+ }
+ end),
+ })
+ return
+ end
+ if oop then
+ return
+ end
+ local literal = guide.getLiteral(value)
+ if literal ~= nil then
+ kind = define.CompletionItemKind.Enum
+ end
+ local textEdit, additionalTextEdits
+ if parent.next and parent.next.index then
+ local str = parent.next.index
+ textEdit = {
+ start = str.start + #str[2],
+ finish = offset,
+ newText = name,
+ }
+ else
+ textEdit, additionalTextEdits = checkFieldFromFieldToIndex(name, parent, word, start, offset)
+ end
+ results[#results+1] = {
+ label = name,
+ kind = kind,
+ textEdit = textEdit,
+ additionalTextEdits = additionalTextEdits,
+ id = stack(function ()
+ return {
+ detail = buildDetail(src),
+ description = buildDesc(src),
+ }
+ end)
+ }
+end
+
+local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, isGlobal)
+ local fields = {}
+ local count = 0
+ for _, src in ipairs(refs) do
+ local key = vm.getKeyName(src)
+ if not key or key:sub(1, 1) ~= 's' then
+ goto CONTINUE
+ end
+ if isSameSource(ast, src, start) then
+ -- 由于fastGlobal的优化,全局变量只会找出一个值,有可能找出自己
+ -- 所以遇到自己的时候重新找一下有没有其他定义
+ if not isGlobal then
+ goto CONTINUE
+ end
+ if #vm.getGlobals(key) <= 1 then
+ goto CONTINUE
+ elseif not vm.isSet(src) then
+ src = vm.getGlobalSets(key)[1] or src
+ end
+ end
+ local name = key:sub(3)
+ if locals and locals[name] then
+ goto CONTINUE
+ end
+ if not matchKey(word, name, count >= 100) then
+ goto CONTINUE
+ end
+ local last = fields[name]
+ if not last then
+ fields[name] = src
+ count = count + 1
+ goto CONTINUE
+ end
+ if src.type == 'tablefield'
+ or src.type == 'setfield'
+ or src.type == 'tableindex'
+ or src.type == 'setindex'
+ or src.type == 'setmethod'
+ or src.type == 'setglobal' then
+ fields[name] = src
+ goto CONTINUE
+ end
+ ::CONTINUE::
+ end
+ for name, src in util.sortPairs(fields) do
+ checkFieldThen(name, src, word, start, offset, parent, oop, results)
+ end
+end
+
+local function checkField(ast, word, start, offset, parent, oop, results)
+ local refs = vm.getFields(parent, 'deep')
+ checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
+end
+
+local function checkGlobal(ast, word, start, offset, parent, oop, results)
+ local locals = guide.getVisibleLocals(ast.ast, offset)
+ local refs = vm.getGlobalSets '*'
+ checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global')
+end
+
+local function checkTableField(ast, word, start, results)
+ local source = guide.eachSourceContain(ast.ast, start, function (source)
+ if source.start == start
+ and source.parent
+ and source.parent.type == 'table' then
+ return source
+ end
+ end)
+ if not source then
+ return
+ end
+ local used = {}
+ guide.eachSourceType(ast.ast, 'tablefield', function (src)
+ if not src.field then
+ return
+ end
+ local key = src.field[1]
+ if not used[key]
+ and matchKey(word, key)
+ and src ~= source then
+ used[key] = true
+ results[#results+1] = {
+ label = key,
+ kind = define.CompletionItemKind.Property,
+ }
+ end
+ end)
+end
+
+local function checkCommon(word, text, offset, results)
+ local used = {}
+ for _, result in ipairs(results) do
+ used[result.label] = true
+ end
+ for _, data in ipairs(keyWordMap) do
+ used[data[1]] = true
+ end
+ for str, pos in text:gmatch '([%a_][%w_]*)()' do
+ if not used[str] and pos - 1 ~= offset then
+ used[str] = true
+ if matchKey(word, str) then
+ results[#results+1] = {
+ label = str,
+ kind = define.CompletionItemKind.Text,
+ }
+ end
+ end
+ end
+end
+
+local function isInString(ast, offset)
+ return guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'string' then
+ return true
+ end
+ end)
+end
+
+local function checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results)
+ local snipType = config.config.completion.keywordSnippet
+ for _, data in ipairs(keyWordMap) do
+ local key = data[1]
+ local eq
+ if hasSpace then
+ eq = word == key
+ else
+ eq = matchKey(word, key)
+ end
+ if afterLocal and key ~= 'function' then
+ eq = false
+ end
+ if eq then
+ local replaced
+ local extra
+ if snipType == 'Both' or snipType == 'Replace' then
+ local func = data[2]
+ if func then
+ replaced = func(hasSpace, results)
+ extra = true
+ end
+ end
+ if snipType == 'Both' then
+ replaced = false
+ end
+ if not replaced then
+ if not hasSpace then
+ local item = {
+ label = key,
+ kind = define.CompletionItemKind.Keyword,
+ }
+ if extra then
+ table.insert(results, #results, item)
+ else
+ results[#results+1] = item
+ end
+ end
+ end
+ local checkStop = data[3]
+ if checkStop then
+ local stop = checkStop(ast, start)
+ if stop then
+ return true
+ end
+ end
+ end
+ end
+end
+
+local function checkProvideLocal(ast, word, start, results)
+ local block
+ guide.eachSourceContain(ast.ast, start, function (source)
+ if source.type == 'function'
+ or source.type == 'main' then
+ block = source
+ end
+ end)
+ if not block then
+ return
+ end
+ local used = {}
+ guide.eachSourceType(block, 'getglobal', function (source)
+ if source.start > start
+ and not used[source[1]]
+ and matchKey(word, source[1]) then
+ used[source[1]] = true
+ results[#results+1] = {
+ label = source[1],
+ kind = define.CompletionItemKind.Variable,
+ }
+ end
+ end)
+ guide.eachSourceType(block, 'getlocal', function (source)
+ if source.start > start
+ and not used[source[1]]
+ and matchKey(word, source[1]) then
+ used[source[1]] = true
+ results[#results+1] = {
+ label = source[1],
+ kind = define.CompletionItemKind.Variable,
+ }
+ end
+ end)
+end
+
+local function checkFunctionArgByDocParam(ast, word, start, results)
+ local func = guide.eachSourceContain(ast.ast, start, function (source)
+ if source.type == 'function' then
+ return source
+ end
+ end)
+ if not func then
+ return
+ end
+ local docs = func.bindDocs
+ if not docs then
+ return
+ end
+ local params = {}
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.param' then
+ params[#params+1] = doc
+ end
+ end
+ local firstArg = func.args and func.args[1]
+ if not firstArg
+ or firstArg.start <= start and firstArg.finish >= start then
+ local firstParam = params[1]
+ if firstParam and matchKey(word, firstParam.param[1]) then
+ local label = {}
+ for _, param in ipairs(params) do
+ label[#label+1] = param.param[1]
+ end
+ results[#results+1] = {
+ label = table.concat(label, ', '),
+ kind = define.CompletionItemKind.Snippet,
+ }
+ end
+ end
+ for _, doc in ipairs(params) do
+ if matchKey(word, doc.param[1]) then
+ results[#results+1] = {
+ label = doc.param[1],
+ kind = define.CompletionItemKind.Interface,
+ }
+ end
+ end
+end
+
+local function isAfterLocal(text, start)
+ local pos = skipSpace(text, start-1)
+ local word = findWord(text, pos)
+ return word == 'local'
+end
+
+local function checkUri(ast, text, offset, results)
+ local collect = {}
+ local myUri = guide.getUri(ast.ast)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type ~= 'string' then
+ return
+ end
+ local callargs = source.parent
+ if not callargs or callargs.type ~= 'callargs' then
+ return
+ end
+ if callargs[1] ~= source then
+ return
+ end
+ local call = callargs.parent
+ local func = call.node
+ local literal = guide.getLiteral(source)
+ local libName = vm.getLibraryName(func)
+ if not libName then
+ return
+ end
+ if libName == 'require' then
+ for uri in files.eachFile() do
+ uri = files.getOriginUri(uri)
+ if files.eq(myUri, uri) then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(uri) then
+ goto CONTINUE
+ end
+ local path = workspace.getRelativePath(uri)
+ local infos = rpath.getVisiblePath(path, config.config.runtime.path)
+ for _, info in ipairs(infos) do
+ if matchKey(literal, info.expect) then
+ if not collect[info.expect] then
+ collect[info.expect] = {
+ textEdit = {
+ start = source.start + #source[2],
+ finish = source.finish - #source[2],
+ }
+ }
+ end
+ collect[info.expect][#collect[info.expect]+1] = ([=[* [%s](%s) %s]=]):format(
+ path,
+ uri,
+ lang.script('HOVER_USE_LUA_PATH', info.searcher)
+ )
+ end
+ end
+ ::CONTINUE::
+ end
+ elseif libName == 'dofile'
+ or libName == 'loadfile' then
+ for uri in files.eachFile() do
+ uri = files.getOriginUri(uri)
+ if files.eq(myUri, uri) then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(uri) then
+ goto CONTINUE
+ end
+ local path = workspace.getRelativePath(uri)
+ if matchKey(literal, path) then
+ if not collect[path] then
+ collect[path] = {
+ textEdit = {
+ start = source.start + #source[2],
+ finish = source.finish - #source[2],
+ }
+ }
+ end
+ collect[path][#collect[path]+1] = ([=[[%s](%s)]=]):format(
+ path,
+ uri
+ )
+ end
+ ::CONTINUE::
+ end
+ end
+ end)
+ for label, infos in util.sortPairs(collect) do
+ local mark = {}
+ local des = {}
+ for _, info in ipairs(infos) do
+ if not mark[info] then
+ mark[info] = true
+ des[#des+1] = info
+ end
+ end
+ results[#results+1] = {
+ label = label,
+ kind = define.CompletionItemKind.Reference,
+ description = table.concat(des, '\n'),
+ textEdit = infos.textEdit,
+ }
+ end
+end
+
+local function checkLenPlusOne(ast, text, offset, results)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'getindex'
+ or source.type == 'setindex' then
+ local _, pos = text:find('%s*%[%s*%#', source.node.finish)
+ if not pos then
+ return
+ end
+ local nodeText = text:sub(source.node.start, source.node.finish)
+ local writingText = trim(text:sub(pos + 1, offset - 1)) or ''
+ if not matchKey(writingText, nodeText) then
+ return
+ end
+ if source.parent == guide.getParentBlock(source) then
+ -- state
+ local label = text:match('%#[ \t]*', pos) .. nodeText .. '+1'
+ local eq = text:find('^%s*%]?%s*%=', source.finish)
+ local newText = label .. ']'
+ if not eq then
+ newText = newText .. ' = '
+ end
+ results[#results+1] = {
+ label = label,
+ kind = define.CompletionItemKind.Snippet,
+ textEdit = {
+ start = pos,
+ finish = source.finish,
+ newText = newText,
+ },
+ }
+ else
+ -- exp
+ local label = text:match('%#[ \t]*', pos) .. nodeText
+ local newText = label .. ']'
+ results[#results+1] = {
+ label = label,
+ kind = define.CompletionItemKind.Snippet,
+ textEdit = {
+ start = pos,
+ finish = source.finish,
+ newText = newText,
+ },
+ }
+ end
+ end
+ end)
+end
+
+local function isFuncArg(ast, offset)
+ return guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'funcargs' then
+ return true
+ end
+ end)
+end
+
+local function trySpecial(ast, text, offset, results)
+ if isInString(ast, offset) then
+ checkUri(ast, text, offset, results)
+ return
+ end
+ -- x[#x+1]
+ checkLenPlusOne(ast, text, offset, results)
+end
+
+local function tryIndex(ast, text, offset, results)
+ local parent, oop = findParentInStringIndex(ast, text, offset)
+ if not parent then
+ return
+ end
+ local word = parent.next.index[1]
+ checkField(ast, word, offset, offset, parent, oop, results)
+end
+
+local function tryWord(ast, text, offset, results)
+ local finish = skipSpace(text, offset)
+ local word, start = findWord(text, finish)
+ if not word then
+ return nil
+ end
+ local hasSpace = finish ~= offset
+ if isInString(ast, offset) then
+ else
+ local parent, oop = findParent(ast, text, start - 1)
+ if parent then
+ if not hasSpace then
+ checkField(ast, word, start, offset, parent, oop, results)
+ end
+ elseif isFuncArg(ast, offset) then
+ checkProvideLocal(ast, word, start, results)
+ checkFunctionArgByDocParam(ast, word, start, results)
+ else
+ local afterLocal = isAfterLocal(text, start)
+ local stop = checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results)
+ if stop then
+ return
+ end
+ if not hasSpace then
+ if afterLocal then
+ checkProvideLocal(ast, word, start, results)
+ else
+ checkLocal(ast, word, start, results)
+ checkTableField(ast, word, start, results)
+ local env = guide.getENV(ast.ast, start)
+ checkGlobal(ast, word, start, offset, env, false, results)
+ end
+ end
+ end
+ if not hasSpace then
+ checkCommon(word, text, offset, results)
+ end
+ end
+end
+
+local function trySymbol(ast, text, offset, results)
+ local symbol, start = findSymbol(text, offset)
+ if not symbol then
+ return nil
+ end
+ if isInString(ast, offset) then
+ return nil
+ end
+ if symbol == '.'
+ or symbol == ':' then
+ local parent, oop = findParent(ast, text, start)
+ if parent then
+ checkField(ast, '', start, offset, parent, oop, results)
+ end
+ end
+ if symbol == '(' then
+ checkFunctionArgByDocParam(ast, '', start, results)
+ end
+end
+
+local function getCallEnums(source, index)
+ if source.type == 'function' and source.bindDocs then
+ if not source.args then
+ return
+ end
+ local arg
+ if index <= #source.args then
+ arg = source.args[index]
+ else
+ local lastArg = source.args[#source.args]
+ if lastArg.type == '...' then
+ arg = lastArg
+ else
+ return
+ end
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.param'
+ and doc.param[1] == arg[1] then
+ local enums = {}
+ for _, enum in ipairs(vm.getDocEnums(doc.extends)) do
+ enums[#enums+1] = {
+ label = enum[1],
+ description = enum.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ end
+ return enums
+ elseif doc.type == 'doc.vararg'
+ and arg.type == '...' then
+ local enums = {}
+ for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do
+ enums[#enums+1] = {
+ label = enum[1],
+ description = enum.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ end
+ return enums
+ end
+ end
+ end
+end
+
+local function tryLabelInString(label, arg)
+ if not arg or arg.type ~= 'string' then
+ return label
+ end
+ local str = parser:grammar(label, 'String')
+ if not str then
+ return label
+ end
+ if not matchKey(arg[1], str[1]) then
+ return nil
+ end
+ return util.viewString(str[1], arg[2])
+end
+
+local function mergeEnums(a, b, text, arg)
+ local mark = {}
+ for _, enum in ipairs(a) do
+ mark[enum.label] = true
+ end
+ for _, enum in ipairs(b) do
+ local label = tryLabelInString(enum.label, arg)
+ if label and not mark[label] then
+ mark[label] = true
+ local result = {
+ label = label,
+ kind = define.CompletionItemKind.EnumMember,
+ description = enum.description,
+ textEdit = arg and {
+ start = arg.start,
+ finish = arg.finish,
+ newText = label,
+ },
+ }
+ a[#a+1] = result
+ end
+ end
+end
+
+local function findCall(ast, text, offset)
+ local call
+ guide.eachSourceContain(ast.ast, offset, function (src)
+ if src.type == 'call' then
+ if not call or call.start < src.start then
+ call = src
+ end
+ end
+ end)
+ return call
+end
+
+local function getCallArgInfo(call, text, offset)
+ if not call.args then
+ return 1, nil
+ end
+ for index, arg in ipairs(call.args) do
+ if arg.start <= offset and arg.finish >= offset then
+ return index, arg
+ end
+ end
+ return #call.args + 1, nil
+end
+
+local function tryCallArg(ast, text, offset, results)
+ local call = findCall(ast, text, offset)
+ if not call then
+ return
+ end
+ local myResults = {}
+ local argIndex, arg = getCallArgInfo(call, text, offset)
+ if arg and arg.type == 'function' then
+ return
+ end
+ local defs = vm.getDefs(call.node, 'deep')
+ for _, def in ipairs(defs) do
+ def = guide.getObjectValue(def) or def
+ local enums = getCallEnums(def, argIndex)
+ if enums then
+ mergeEnums(myResults, enums, text, arg)
+ end
+ end
+ for _, enum in ipairs(myResults) do
+ results[#results+1] = enum
+ end
+end
+
+local function getComment(ast, offset)
+ for _, comm in ipairs(ast.comms) do
+ if offset >= comm.start and offset <= comm.finish then
+ return comm
+ end
+ end
+ return nil
+end
+
+local function tryLuaDocCate(line, results)
+ local word = line:sub(3)
+ for _, docType in ipairs {
+ 'class',
+ 'type',
+ 'alias',
+ 'param',
+ 'return',
+ 'field',
+ 'generic',
+ 'vararg',
+ 'overload',
+ 'deprecated',
+ 'meta',
+ 'version',
+ } do
+ if matchKey(word, docType) then
+ results[#results+1] = {
+ label = docType,
+ kind = define.CompletionItemKind.Event,
+ }
+ end
+ end
+end
+
+local function getLuaDocByContain(ast, offset)
+ local result
+ local range = math.huge
+ guide.eachSourceContain(ast.ast.docs, offset, function (src)
+ if not src.start then
+ return
+ end
+ if range >= offset - src.start
+ and offset <= src.finish then
+ range = offset - src.start
+ result = src
+ end
+ end)
+ return result
+end
+
+local function getLuaDocByErr(ast, text, start, offset)
+ local targetError
+ for _, err in ipairs(ast.errs) do
+ if err.finish <= offset
+ and err.start >= start then
+ if not text:sub(err.finish + 1, offset):find '%S' then
+ targetError = err
+ break
+ end
+ end
+ end
+ if not targetError then
+ return nil
+ end
+ local targetDoc
+ for i = #ast.ast.docs, 1, -1 do
+ local doc = ast.ast.docs[i]
+ if doc.finish <= targetError.start then
+ targetDoc = doc
+ break
+ end
+ end
+ return targetError, targetDoc
+end
+
+local function tryLuaDocBySource(ast, offset, source, results)
+ if source.type == 'doc.extends.name' then
+ if source.parent.type == 'doc.class' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if doc.type == 'doc.class.name'
+ and doc.parent ~= source.parent
+ and matchKey(source[1], doc[1]) then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ end
+ elseif source.type == 'doc.type.name' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name')
+ and doc.parent ~= source.parent
+ and matchKey(source[1], doc[1]) then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ elseif source.type == 'doc.param.name' then
+ local funcs = {}
+ guide.eachSourceBetween(ast.ast, offset, math.huge, function (src)
+ if src.type == 'function' and src.start > offset then
+ funcs[#funcs+1] = src
+ end
+ end)
+ table.sort(funcs, function (a, b)
+ return a.start < b.start
+ end)
+ local func = funcs[1]
+ if not func or not func.args then
+ return
+ end
+ for _, arg in ipairs(func.args) do
+ if arg[1] and matchKey(source[1], arg[1]) then
+ results[#results+1] = {
+ label = arg[1],
+ kind = define.CompletionItemKind.Interface,
+ }
+ end
+ end
+ end
+end
+
+local function tryLuaDocByErr(ast, offset, err, docState, results)
+ if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if doc.type == 'doc.class.name'
+ and doc.parent ~= docState then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ elseif err.type == 'LUADOC_MISS_TYPE_NAME' then
+ for _, doc in ipairs(vm.getDocTypes '*') do
+ if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then
+ results[#results+1] = {
+ label = doc[1],
+ kind = define.CompletionItemKind.Class,
+ }
+ end
+ end
+ elseif err.type == 'LUADOC_MISS_PARAM_NAME' then
+ local funcs = {}
+ guide.eachSourceBetween(ast.ast, offset, math.huge, function (src)
+ if src.type == 'function' and src.start > offset then
+ funcs[#funcs+1] = src
+ end
+ end)
+ table.sort(funcs, function (a, b)
+ return a.start < b.start
+ end)
+ local func = funcs[1]
+ if not func or not func.args then
+ return
+ end
+ local label = {}
+ local insertText = {}
+ for i, arg in ipairs(func.args) do
+ if arg[1] then
+ label[#label+1] = arg[1]
+ if i == 1 then
+ insertText[i] = ('%s ${%d:any}'):format(arg[1], i)
+ else
+ insertText[i] = ('---@param %s ${%d:any}'):format(arg[1], i)
+ end
+ end
+ end
+ results[#results+1] = {
+ label = table.concat(label, ', '),
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = table.concat(insertText, '\n'),
+ }
+ for i, arg in ipairs(func.args) do
+ if arg[1] then
+ results[#results+1] = {
+ label = arg[1],
+ kind = define.CompletionItemKind.Interface,
+ }
+ end
+ end
+ end
+end
+
+local function tryLuaDocFeatures(line, ast, comm, offset, results)
+end
+
+local function tryLuaDoc(ast, text, offset, results)
+ local comm = getComment(ast, offset)
+ local line = text:sub(comm.start, offset)
+ if not line then
+ return
+ end
+ if line:sub(1, 2) ~= '-@' then
+ return
+ end
+ -- 尝试 ---@$
+ local cate = line:match('%a*', 3)
+ if #cate + 2 >= #line then
+ tryLuaDocCate(line, results)
+ return
+ end
+ -- 尝试一些其他特征
+ if tryLuaDocFeatures(line, ast, comm, offset, results) then
+ return
+ end
+ -- 根据输入中的source来补全
+ local source = getLuaDocByContain(ast, offset)
+ if source then
+ tryLuaDocBySource(ast, offset, source, results)
+ return
+ end
+ -- 根据附近的错误消息来补全
+ local err, doc = getLuaDocByErr(ast, text, comm.start, offset)
+ if err then
+ tryLuaDocByErr(ast, offset, err, doc, results)
+ return
+ end
+end
+
+local function completion(uri, offset)
+ local ast = files.getAst(uri)
+ local text = files.getText(uri)
+ local results = {}
+ clearStack()
+ if ast then
+ if getComment(ast, offset) then
+ tryLuaDoc(ast, text, offset, results)
+ else
+ trySpecial(ast, text, offset, results)
+ tryWord(ast, text, offset, results)
+ tryIndex(ast, text, offset, results)
+ trySymbol(ast, text, offset, results)
+ tryCallArg(ast, text, offset, results)
+ end
+ else
+ local word = findWord(text, offset)
+ if word then
+ checkCommon(word, text, offset, results)
+ end
+ end
+
+ if #results == 0 then
+ return nil
+ end
+ return results
+end
+
+local function resolve(id)
+ return resolveStack(id)
+end
+
+return {
+ completion = completion,
+ resolve = resolve,
+}
diff --git a/script/core/definition.lua b/script/core/definition.lua
new file mode 100644
index 00000000..c143939d
--- /dev/null
+++ b/script/core/definition.lua
@@ -0,0 +1,156 @@
+local guide = require 'parser.guide'
+local workspace = require 'workspace'
+local files = require 'files'
+local vm = require 'vm'
+local findSource = require 'core.find-source'
+
+local function sortResults(results)
+ -- 先按照顺序排序
+ table.sort(results, function (a, b)
+ local u1 = guide.getUri(a.target)
+ local u2 = guide.getUri(b.target)
+ if u1 == u2 then
+ return a.target.start < b.target.start
+ else
+ return u1 < u2
+ end
+ end)
+ -- 如果2个结果处于嵌套状态,则取范围小的那个
+ local lf, lu
+ for i = #results, 1, -1 do
+ local res = results[i].target
+ local f = res.finish
+ local uri = guide.getUri(res)
+ if lf and f > lf and uri == lu then
+ table.remove(results, i)
+ else
+ lu = uri
+ lf = f
+ end
+ end
+end
+
+local accept = {
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['label'] = true,
+ ['goto'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['string'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+
+ ['doc.type.name'] = true,
+ ['doc.class.name'] = true,
+ ['doc.extends.name'] = true,
+ ['doc.alias.name'] = true,
+}
+
+local function checkRequire(source, offset)
+ if source.type ~= 'string' then
+ return nil
+ end
+ local callargs = source.parent
+ if callargs.type ~= 'callargs' then
+ return
+ end
+ if callargs[1] ~= source then
+ return
+ end
+ local call = callargs.parent
+ local func = call.node
+ local literal = guide.getLiteral(source)
+ local libName = vm.getLibraryName(func)
+ if not libName then
+ return nil
+ end
+ if libName == 'require' then
+ return workspace.findUrisByRequirePath(literal)
+ elseif libName == 'dofile'
+ or libName == 'loadfile' then
+ return workspace.findUrisByFilePath(literal)
+ end
+ return nil
+end
+
+local function convertIndex(source)
+ if not source then
+ return
+ end
+ if source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number' then
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return parent
+ end
+ end
+ return source
+end
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local source = convertIndex(findSource(ast, offset, accept))
+ if not source then
+ return nil
+ end
+
+ local results = {}
+ local uris = checkRequire(source)
+ if uris then
+ for i, uri in ipairs(uris) do
+ results[#results+1] = {
+ uri = files.getOriginUri(uri),
+ source = source,
+ target = {
+ start = 0,
+ finish = 0,
+ uri = uri,
+ }
+ }
+ end
+ end
+
+ for _, src in ipairs(vm.getDefs(source, 'deep')) do
+ local root = guide.getRoot(src)
+ if not root then
+ goto CONTINUE
+ end
+ src = src.field or src.method or src.index or src
+ if src.type == 'table' and src.parent.type ~= 'return' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.class.name'
+ and source.type ~= 'doc.type.name'
+ and source.type ~= 'doc.extends.name' then
+ goto CONTINUE
+ end
+ results[#results+1] = {
+ target = src,
+ uri = files.getOriginUri(root.uri),
+ source = source,
+ }
+ ::CONTINUE::
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ sortResults(results)
+
+ return results
+end
diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua
new file mode 100644
index 00000000..37815fb5
--- /dev/null
+++ b/script/core/diagnostics/ambiguity-1.lua
@@ -0,0 +1,69 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+local opMap = {
+ ['+'] = true,
+ ['-'] = true,
+ ['*'] = true,
+ ['/'] = true,
+ ['//'] = true,
+ ['^'] = true,
+ ['<<'] = true,
+ ['>>'] = true,
+ ['&'] = true,
+ ['|'] = true,
+ ['~'] = true,
+ ['..'] = true,
+}
+
+local literalMap = {
+ ['number'] = true,
+ ['boolean'] = true,
+ ['string'] = true,
+ ['table'] = true,
+}
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local text = files.getText(uri)
+ guide.eachSourceType(ast.ast, 'binary', function (source)
+ if source.op.type ~= 'or' then
+ return
+ end
+ local first = source[1]
+ local second = source[2]
+ -- a + (b or 0) --> (a + b) or 0
+ do
+ if opMap[first.op and first.op.type]
+ and first.type ~= 'unary'
+ and not second.op
+ and literalMap[second.type]
+ and not literalMap[first[2].type]
+ then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_AMBIGUITY_1', text:sub(first.start, first.finish))
+ }
+ end
+ end
+ -- (a or 0) + c --> a or (0 + c)
+ do
+ if opMap[second.op and second.op.type]
+ and second.type ~= 'unary'
+ and not first.op
+ and literalMap[second[1].type]
+ then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_AMBIGUITY_1', text:sub(second.start, second.finish))
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua
new file mode 100644
index 00000000..55179447
--- /dev/null
+++ b/script/core/diagnostics/circle-doc-class.lua
@@ -0,0 +1,54 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.class' then
+ if not doc.extends then
+ goto CONTINUE
+ end
+ local myName = guide.getName(doc)
+ local list = { doc }
+ local mark = {}
+ for i = 1, 999 do
+ local current = list[i]
+ if not current then
+ goto CONTINUE
+ end
+ if current.extends then
+ local newName = current.extends[1]
+ if newName == myName then
+ callback {
+ start = doc.start,
+ finish = doc.finish,
+ message = lang.script('DIAG_CIRCLE_DOC_CLASS', myName)
+ }
+ goto CONTINUE
+ end
+ if not mark[newName] then
+ mark[newName] = true
+ local docs = vm.getDocTypes(newName)
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name' then
+ list[#list+1] = otherDoc.parent
+ end
+ end
+ end
+ end
+ end
+ ::CONTINUE::
+ end
+ end
+end
diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua
new file mode 100644
index 00000000..a2bac8a4
--- /dev/null
+++ b/script/core/diagnostics/code-after-break.lua
@@ -0,0 +1,34 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ local mark = {}
+ guide.eachSourceType(state.ast, 'break', function (source)
+ local list = source.parent
+ if mark[list] then
+ return
+ end
+ mark[list] = true
+ for i = #list, 1, -1 do
+ local src = list[i]
+ if src == source then
+ if i == #list then
+ return
+ end
+ callback {
+ start = list[i+1].start,
+ finish = list[#list].range or list[#list].finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_CODE_AFTER_BREAK,
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/doc-field-no-class.lua b/script/core/diagnostics/doc-field-no-class.lua
new file mode 100644
index 00000000..f27bbb32
--- /dev/null
+++ b/script/core/diagnostics/doc-field-no-class.lua
@@ -0,0 +1,41 @@
+local files = require 'files'
+local lang = require 'language'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type ~= 'doc.field' then
+ goto CONTINUE
+ end
+ local bindGroup = doc.bindGroup
+ if not bindGroup then
+ goto CONTINUE
+ end
+ local ok
+ for _, other in ipairs(bindGroup) do
+ if other.type == 'doc.class' then
+ ok = true
+ break
+ end
+ if other == doc then
+ break
+ end
+ end
+ if not ok then
+ callback {
+ start = doc.start,
+ finish = doc.finish,
+ message = lang.script('DIAG_DOC_FIELD_NO_CLASS'),
+ }
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua
new file mode 100644
index 00000000..259c048b
--- /dev/null
+++ b/script/core/diagnostics/duplicate-doc-class.lua
@@ -0,0 +1,46 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ local cache = {}
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.alias' then
+ local name = guide.getName(doc)
+ if not cache[name] then
+ local docs = vm.getDocTypes(name)
+ cache[name] = {}
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name'
+ or otherDoc.type == 'doc.alias.name' then
+ cache[name][#cache[name]+1] = {
+ start = otherDoc.start,
+ finish = otherDoc.finish,
+ uri = guide.getUri(otherDoc),
+ }
+ end
+ end
+ end
+ if #cache[name] > 1 then
+ callback {
+ start = doc.start,
+ finish = doc.finish,
+ related = cache,
+ message = lang.script('DIAG_DUPLICATE_DOC_CLASS', name)
+ }
+ end
+ end
+ end
+end
diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua
new file mode 100644
index 00000000..b621fd9e
--- /dev/null
+++ b/script/core/diagnostics/duplicate-doc-field.lua
@@ -0,0 +1,34 @@
+local files = require 'files'
+local lang = require 'language'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ local mark
+ for _, group in ipairs(state.ast.docs.groups) do
+ for _, doc in ipairs(group) do
+ if doc.type == 'doc.class' then
+ mark = {}
+ elseif doc.type == 'doc.field' then
+ if mark then
+ local name = doc.field[1]
+ if mark[name] then
+ callback {
+ start = doc.field.start,
+ finish = doc.field.finish,
+ message = lang.script('DIAG_DUPLICATE_DOC_FIELD', name),
+ }
+ end
+ mark[name] = true
+ end
+ end
+ end
+ end
+end
diff --git a/script/core/diagnostics/duplicate-doc-param.lua b/script/core/diagnostics/duplicate-doc-param.lua
new file mode 100644
index 00000000..676a6fb4
--- /dev/null
+++ b/script/core/diagnostics/duplicate-doc-param.lua
@@ -0,0 +1,37 @@
+local files = require 'files'
+local lang = require 'language'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type ~= 'doc.param' then
+ goto CONTINUE
+ end
+ local name = doc.param[1]
+ local bindGroup = doc.bindGroup
+ if not bindGroup then
+ goto CONTINUE
+ end
+ for _, other in ipairs(bindGroup) do
+ if other ~= doc
+ and other.type == 'doc.param'
+ and other.param[1] == name then
+ callback {
+ start = doc.param.start,
+ finish = doc.param.finish,
+ message = lang.script('DIAG_DUPLICATE_DOC_PARAM', name)
+ }
+ goto CONTINUE
+ end
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua
new file mode 100644
index 00000000..dabe1b3c
--- /dev/null
+++ b/script/core/diagnostics/duplicate-index.lua
@@ -0,0 +1,63 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'table', function (source)
+ local mark = {}
+ for _, obj in ipairs(source) do
+ if obj.type == 'tablefield'
+ or obj.type == 'tableindex' then
+ local name = vm.getKeyName(obj)
+ if name then
+ if not mark[name] then
+ mark[name] = {}
+ end
+ mark[name][#mark[name]+1] = obj.field or obj.index
+ end
+ end
+ end
+
+ for name, defs in pairs(mark) do
+ local sname = name:match '^.|(.+)$'
+ if #defs > 1 and sname then
+ local related = {}
+ for i = 1, #defs do
+ local def = defs[i]
+ related[i] = {
+ start = def.start,
+ finish = def.finish,
+ uri = uri,
+ }
+ end
+ for i = 1, #defs - 1 do
+ local def = defs[i]
+ callback {
+ start = def.start,
+ finish = def.finish,
+ related = related,
+ message = lang.script('DIAG_DUPLICATE_INDEX', sname),
+ level = define.DiagnosticSeverity.Hint,
+ tags = { define.DiagnosticTag.Unnecessary },
+ }
+ end
+ for i = #defs, #defs do
+ local def = defs[i]
+ callback {
+ start = def.start,
+ finish = def.finish,
+ related = related,
+ message = lang.script('DIAG_DUPLICATE_INDEX', sname),
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua
new file mode 100644
index 00000000..2024f4e3
--- /dev/null
+++ b/script/core/diagnostics/empty-block.lua
@@ -0,0 +1,49 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+
+-- 检查空代码块
+-- 但是排除忙等待(repeat/while)
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'if', function (source)
+ for _, block in ipairs(source) do
+ if #block > 0 then
+ return
+ end
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+ guide.eachSourceType(ast.ast, 'loop', function (source)
+ if #source > 0 then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+ guide.eachSourceType(ast.ast, 'in', function (source)
+ if #source > 0 then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_EMPTY_BLOCK,
+ }
+ end)
+end
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
new file mode 100644
index 00000000..9a0d4f35
--- /dev/null
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -0,0 +1,66 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+-- TODO: 检查路径是否可达
+local function mayRun(path)
+ return true
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local root = guide.getRoot(ast.ast)
+ local env = guide.getENV(root)
+
+ local nilDefs = {}
+ if not env.ref then
+ return
+ end
+ for _, ref in ipairs(env.ref) do
+ if ref.type == 'setlocal' then
+ if ref.value and ref.value.type == 'nil' then
+ nilDefs[#nilDefs+1] = ref
+ end
+ end
+ end
+
+ if #nilDefs == 0 then
+ return
+ end
+
+ local function check(source)
+ local node = source.node
+ if node.tag == '_ENV' then
+ local ok
+ for _, nilDef in ipairs(nilDefs) do
+ local mode, pathA = guide.getPath(nilDef, source)
+ if mode == 'before'
+ and mayRun(pathA) then
+ ok = nilDef
+ break
+ end
+ end
+ if ok then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ uri = uri,
+ message = lang.script.DIAG_GLOBAL_IN_NIL_ENV,
+ related = {
+ {
+ start = ok.start,
+ finish = ok.finish,
+ uri = uri,
+ }
+ }
+ }
+ end
+ end
+ end
+
+ guide.eachSourceType(ast.ast, 'getglobal', check)
+ guide.eachSourceType(ast.ast, 'setglobal', check)
+end
diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua
new file mode 100644
index 00000000..a6b61e12
--- /dev/null
+++ b/script/core/diagnostics/init.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local define = require 'proto.define'
+local config = require 'config'
+local await = require 'await'
+
+-- 把耗时最长的诊断放到最后面
+local diagLevel = {
+ ['redundant-parameter'] = 100,
+}
+
+local diagList = {}
+for k in pairs(define.DiagnosticDefaultSeverity) do
+ diagList[#diagList+1] = k
+end
+table.sort(diagList, function (a, b)
+ return (diagLevel[a] or 0) < (diagLevel[b] or 0)
+end)
+
+local function check(uri, name, results)
+ if config.config.diagnostics.disable[name] then
+ return
+ end
+ local level = config.config.diagnostics.severity[name]
+ or define.DiagnosticDefaultSeverity[name]
+ if level == 'Hint' and not files.isOpen(uri) then
+ return
+ end
+ local severity = define.DiagnosticSeverity[level]
+ local clock = os.clock()
+ require('core.diagnostics.' .. name)(uri, function (result)
+ result.level = severity or result.level
+ result.code = name
+ results[#results+1] = result
+ end, name)
+ local passed = os.clock() - clock
+ if passed >= 0.5 then
+ log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed))
+ end
+end
+
+return function (uri, response)
+ local vm = require 'vm'
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local isOpen = files.isOpen(uri)
+
+ for _, name in ipairs(diagList) do
+ await.delay()
+ local results = {}
+ check(uri, name, results)
+ response(results)
+ end
+end
diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua
new file mode 100644
index 00000000..fe5d1eca
--- /dev/null
+++ b/script/core/diagnostics/lowercase-global.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local config = require 'config'
+local vm = require 'vm'
+
+local function isDocClass(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class' then
+ return true
+ end
+ end
+ return false
+end
+
+-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删)
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local definedGlobal = {}
+ for name in pairs(config.config.diagnostics.globals) do
+ definedGlobal[name] = true
+ end
+
+ guide.eachSourceType(ast.ast, 'setglobal', function (source)
+ local name = guide.getName(source)
+ if definedGlobal[name] then
+ return
+ end
+ local first = name:match '%w'
+ if not first then
+ return
+ end
+ if not first:match '%l' then
+ return
+ end
+ -- 如果赋值被标记为 doc.class ,则认为是允许的
+ if isDocClass(source) then
+ return
+ end
+ if vm.isGlobalLibraryName(name) then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script.DIAG_LOWERCASE_GLOBAL,
+ }
+ end)
+end
diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua
new file mode 100644
index 00000000..75681cbc
--- /dev/null
+++ b/script/core/diagnostics/newfield-call.lua
@@ -0,0 +1,37 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+
+ guide.eachSourceType(ast.ast, 'table', function (source)
+ for i = 1, #source do
+ local field = source[i]
+ if field.type == 'call' then
+ local func = field.node
+ local args = field.args
+ if args then
+ local funcLine = guide.positionOf(lines, func.finish)
+ local argsLine = guide.positionOf(lines, args.start)
+ if argsLine > funcLine then
+ callback {
+ start = field.start,
+ finish = field.finish,
+ message = lang.script('DIAG_PREFIELD_CALL'
+ , text:sub(func.start, func.finish)
+ , text:sub(args.start, args.finish)
+ )
+ }
+ end
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua
new file mode 100644
index 00000000..cb318380
--- /dev/null
+++ b/script/core/diagnostics/newline-call.lua
@@ -0,0 +1,38 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local lines = files.getLines(uri)
+
+ guide.eachSourceType(ast.ast, 'call', function (source)
+ local node = source.node
+ local args = source.args
+ if not args then
+ return
+ end
+
+ -- 必须有其他人在继续使用当前对象
+ if not source.next then
+ return
+ end
+
+ local nodeRow = guide.positionOf(lines, node.finish)
+ local argRow = guide.positionOf(lines, args.start)
+ if nodeRow == argRow then
+ return
+ end
+
+ if #args == 1 then
+ callback {
+ start = args.start,
+ finish = args.finish,
+ message = lang.script.DIAG_PREVIOUS_CALL,
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua
new file mode 100644
index 00000000..5e53d837
--- /dev/null
+++ b/script/core/diagnostics/redefined-local.lua
@@ -0,0 +1,32 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'local', function (source)
+ local name = source[1]
+ if name == '_'
+ or name == ast.ENVMode then
+ return
+ end
+ local exist = guide.getLocal(source, name, source.start-1)
+ if exist then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_REDEFINED_LOCAL', name),
+ related = {
+ {
+ start = exist.start,
+ finish = exist.finish,
+ uri = uri,
+ }
+ },
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
new file mode 100644
index 00000000..2fae20e8
--- /dev/null
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -0,0 +1,82 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+local define = require 'proto.define'
+local await = require 'await'
+
+local function countCallArgs(source)
+ local result = 0
+ if not source.args then
+ return 0
+ end
+ if source.node and source.node.type == 'getmethod' then
+ result = result + 1
+ end
+ result = result + #source.args
+ return result
+end
+
+local function countFuncArgs(source)
+ local result = 0
+ if source.parent and source.parent.type == 'setmethod' then
+ result = result + 1
+ end
+ if not source.args then
+ return result
+ end
+ if source.args[#source.args].type == '...' then
+ return math.maxinteger
+ end
+ result = result + #source.args
+ return result
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'call', function (source)
+ local callArgs = countCallArgs(source)
+ if callArgs == 0 then
+ return
+ end
+
+ local func = source.node
+ local funcArgs
+ local defs = vm.getDefs(func)
+ for _, def in ipairs(defs) do
+ if def.value then
+ def = def.value
+ end
+ if def.type == 'function' then
+ local args = countFuncArgs(def)
+ if not funcArgs or args > funcArgs then
+ funcArgs = args
+ end
+ end
+ end
+
+ if not funcArgs then
+ return
+ end
+
+ local delta = callArgs - funcArgs
+ if delta <= 0 then
+ return
+ end
+ for i = #source.args - delta + 1, #source.args do
+ local arg = source.args[i]
+ if arg then
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/redundant-value.lua b/script/core/diagnostics/redundant-value.lua
new file mode 100644
index 00000000..be483448
--- /dev/null
+++ b/script/core/diagnostics/redundant-value.lua
@@ -0,0 +1,24 @@
+local files = require 'files'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback, code)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ local diags = ast.diags[code]
+ if not diags then
+ return
+ end
+
+ for _, info in ipairs(diags) do
+ callback {
+ start = info.start,
+ finish = info.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_OVER_MAX_VALUES', info.max, info.passed)
+ }
+ end
+end
diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua
new file mode 100644
index 00000000..e54a6e60
--- /dev/null
+++ b/script/core/diagnostics/trailing-space.lua
@@ -0,0 +1,55 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+
+local function isInString(ast, offset)
+ local result = false
+ guide.eachSourceType(ast, 'string', function (source)
+ if offset >= source.start and offset <= source.finish then
+ result = true
+ end
+ end)
+ return result
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local text = files.getText(uri)
+ local lines = files.getLines(uri)
+ for i = 1, #lines do
+ local start = lines[i].start
+ local range = lines[i].range
+ local lastChar = text:sub(range, range)
+ if lastChar ~= ' ' and lastChar ~= '\t' then
+ goto NEXT_LINE
+ end
+ if isInString(ast.ast, range) then
+ goto NEXT_LINE
+ end
+ local first = start
+ for n = range - 1, start, -1 do
+ local char = text:sub(n, n)
+ if char ~= ' ' and char ~= '\t' then
+ first = n + 1
+ break
+ end
+ end
+ if first == start then
+ callback {
+ start = first,
+ finish = range,
+ message = lang.script.DIAG_LINE_ONLY_SPACE,
+ }
+ else
+ callback {
+ start = first,
+ finish = range,
+ message = lang.script.DIAG_LINE_POST_SPACE,
+ }
+ end
+ ::NEXT_LINE::
+ end
+end
diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua
new file mode 100644
index 00000000..bbfdceec
--- /dev/null
+++ b/script/core/diagnostics/undefined-doc-class.lua
@@ -0,0 +1,46 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ local cache = {}
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type == 'doc.class' then
+ local ext = doc.extends
+ if not ext then
+ goto CONTINUE
+ end
+ local name = ext[1]
+ local docs = vm.getDocTypes(name)
+ if cache[name] == nil then
+ cache[name] = false
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name' then
+ cache[name] = true
+ break
+ end
+ end
+ end
+ if not cache[name] then
+ callback {
+ start = ext.start,
+ finish = ext.finish,
+ related = cache,
+ message = lang.script('DIAG_UNDEFINED_DOC_CLASS', name)
+ }
+ end
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
new file mode 100644
index 00000000..5c1e8fbf
--- /dev/null
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -0,0 +1,60 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+local function hasNameOfClassOrAlias(name)
+ local docs = vm.getDocTypes(name)
+ for _, otherDoc in ipairs(docs) do
+ if otherDoc.type == 'doc.class.name'
+ or otherDoc.type == 'doc.alias.name' then
+ return true
+ end
+ end
+ return false
+end
+
+local function hasNameOfGeneric(name, source)
+ if not source.typeGeneric then
+ return false
+ end
+ if not source.typeGeneric[name] then
+ return false
+ end
+ return true
+end
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ guide.eachSource(state.ast.docs, function (source)
+ if source.type ~= 'doc.extends.name'
+ and source.type ~= 'doc.type.name' then
+ return
+ end
+ if source.parent.type == 'doc.class' then
+ return
+ end
+ local name = source[1]
+ if name == '...' then
+ return
+ end
+ if hasNameOfClassOrAlias(name)
+ or hasNameOfGeneric(name, source) then
+ return
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_UNDEFINED_DOC_NAME', name)
+ }
+ end)
+end
diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua
new file mode 100644
index 00000000..af3e07bc
--- /dev/null
+++ b/script/core/diagnostics/undefined-doc-param.lua
@@ -0,0 +1,52 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
+
+local function hasParamName(func, name)
+ if not func.args then
+ return false
+ end
+ for _, arg in ipairs(func.args) do
+ if arg[1] == name then
+ return true
+ end
+ end
+ return false
+end
+
+return function (uri, callback)
+ local state = files.getAst(uri)
+ if not state then
+ return
+ end
+
+ if not state.ast.docs then
+ return
+ end
+
+ for _, doc in ipairs(state.ast.docs) do
+ if doc.type ~= 'doc.param' then
+ goto CONTINUE
+ end
+ local binds = doc.bindSources
+ if not binds then
+ goto CONTINUE
+ end
+ local param = doc.param
+ local name = param[1]
+ for _, source in ipairs(binds) do
+ if source.type == 'function' then
+ if not hasParamName(source, name) then
+ callback {
+ start = param.start,
+ finish = param.finish,
+ message = lang.script('DIAG_UNDEFINED_DOC_PARAM', name)
+ }
+ end
+ end
+ end
+ ::CONTINUE::
+ end
+end
diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua
new file mode 100644
index 00000000..6b8c62f0
--- /dev/null
+++ b/script/core/diagnostics/undefined-env-child.lua
@@ -0,0 +1,27 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'getglobal', function (source)
+ -- 单独验证自己是否在重载过的 _ENV 中有定义
+ if source.node.tag == '_ENV' then
+ return
+ end
+ local defs = guide.requestDefinition(source)
+ if #defs > 0 then
+ return
+ end
+ local key = source[1]
+ callback {
+ start = source.start,
+ finish = source.finish,
+ message = lang.script('DIAG_UNDEF_ENV_CHILD', key),
+ }
+ end)
+end
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
new file mode 100644
index 00000000..778fc1f1
--- /dev/null
+++ b/script/core/diagnostics/undefined-global.lua
@@ -0,0 +1,40 @@
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local config = require 'config'
+local guide = require 'parser.guide'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ -- 遍历全局变量,检查所有没有 set 模式的全局变量
+ guide.eachSourceType(ast.ast, 'getglobal', function (src)
+ local key = guide.getName(src)
+ if not key then
+ return
+ end
+ if config.config.diagnostics.globals[key] then
+ return
+ end
+ if #vm.getGlobalSets(guide.getKeyName(src)) > 0 then
+ return
+ end
+ local message = lang.script('DIAG_UNDEF_GLOBAL', key)
+ -- TODO check other version
+ local otherVersion
+ local customVersion
+ if otherVersion then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version))
+ elseif customVersion then
+ message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customVersion, '/')))
+ end
+ callback {
+ start = src.start,
+ finish = src.finish,
+ message = message,
+ }
+ end)
+end
diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua
new file mode 100644
index 00000000..f0bca613
--- /dev/null
+++ b/script/core/diagnostics/unused-function.lua
@@ -0,0 +1,40 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local define = require 'proto.define'
+local lang = require 'language'
+local await = require 'await'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ -- 只检查局部函数
+ guide.eachSourceType(ast.ast, 'function', function (source)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type ~= 'local'
+ and parent.type ~= 'setlocal' then
+ return
+ end
+ local hasGet
+ local refs = vm.getRefs(source)
+ for _, src in ipairs(refs) do
+ if vm.isGet(src) then
+ hasGet = true
+ break
+ end
+ end
+ if not hasGet then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_FUNCTION,
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua
new file mode 100644
index 00000000..e6d998ba
--- /dev/null
+++ b/script/core/diagnostics/unused-label.lua
@@ -0,0 +1,22 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'label', function (source)
+ if not source.ref then
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LABEL', source[1]),
+ }
+ end
+ end)
+end
diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua
new file mode 100644
index 00000000..873a70f2
--- /dev/null
+++ b/script/core/diagnostics/unused-local.lua
@@ -0,0 +1,93 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+local function hasGet(loc)
+ if not loc.ref then
+ return false
+ end
+ local weak
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'getlocal' then
+ if not ref.next then
+ return 'strong'
+ end
+ local nextType = ref.next.type
+ if nextType ~= 'setmethod'
+ and nextType ~= 'setfield'
+ and nextType ~= 'setindex' then
+ return 'strong'
+ else
+ weak = true
+ end
+ end
+ end
+ if weak then
+ return 'weak'
+ else
+ return nil
+ end
+end
+
+local function isMyTable(loc)
+ local value = loc.value
+ if value and value.type == 'table' then
+ return true
+ end
+ return false
+end
+
+local function isClose(source)
+ if not source.attrs then
+ return false
+ end
+ for _, attr in ipairs(source.attrs) do
+ if attr[1] == 'close' then
+ return true
+ end
+ end
+ return false
+end
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ guide.eachSourceType(ast.ast, 'local', function (source)
+ local name = source[1]
+ if name == '_'
+ or name == ast.ENVMode then
+ return
+ end
+ if isClose(source) then
+ return
+ end
+ local data = hasGet(source)
+ if data == 'strong' then
+ return
+ end
+ if data == 'weak' then
+ if not isMyTable(source) then
+ return
+ end
+ end
+ callback {
+ start = source.start,
+ finish = source.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LOCAL', name),
+ }
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback {
+ start = ref.start,
+ finish = ref.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script('DIAG_UNUSED_LOCAL', name),
+ }
+ end
+ end
+ end)
+end
diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua
new file mode 100644
index 00000000..74cc08e7
--- /dev/null
+++ b/script/core/diagnostics/unused-vararg.lua
@@ -0,0 +1,31 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local lang = require 'language'
+
+return function (uri, callback)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSourceType(ast.ast, 'function', function (source)
+ local args = source.args
+ if not args then
+ return
+ end
+
+ for _, arg in ipairs(args) do
+ if arg.type == '...' then
+ if not arg.ref then
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ tags = { define.DiagnosticTag.Unnecessary },
+ message = lang.script.DIAG_UNUSED_VARARG,
+ }
+ end
+ end
+ end
+ end)
+end
diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua
new file mode 100644
index 00000000..7392b337
--- /dev/null
+++ b/script/core/document-symbol.lua
@@ -0,0 +1,307 @@
+local await = require 'await'
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local util = require 'utility'
+
+local function buildName(source, text)
+ if source.type == 'setmethod'
+ or source.type == 'getmethod' then
+ if source.method then
+ return text:sub(source.start, source.method.finish)
+ end
+ end
+ if source.type == 'setfield'
+ or source.type == 'tablefield'
+ or source.type == 'getfield' then
+ if source.field then
+ return text:sub(source.start, source.field.finish)
+ end
+ end
+ return text:sub(source.start, source.finish)
+end
+
+local function buildFunctionParams(func)
+ if not func.args then
+ return ''
+ end
+ local params = {}
+ for i, arg in ipairs(func.args) do
+ if arg.type == '...' then
+ params[i] = '...'
+ else
+ params[i] = arg[1] or ''
+ end
+ end
+ return table.concat(params, ', ')
+end
+
+local function buildFunction(source, text, symbols)
+ local name = buildName(source, text)
+ local func = source.value
+ if source.type == 'tablefield'
+ or source.type == 'setfield' then
+ source = source.field
+ if not source then
+ return
+ end
+ end
+ local range, kind
+ if func.start > source.finish then
+ -- a = function()
+ range = { source.start, func.finish }
+ else
+ -- function f()
+ range = { func.start, func.finish }
+ end
+ if source.type == 'setmethod' then
+ kind = define.SymbolKind.Method
+ else
+ kind = define.SymbolKind.Function
+ end
+ symbols[#symbols+1] = {
+ name = name,
+ detail = ('function (%s)'):format(buildFunctionParams(func)),
+ kind = kind,
+ range = range,
+ selectionRange = { source.start, source.finish },
+ valueRange = { func.start, func.finish },
+ }
+end
+
+local function buildTable(tbl)
+ local buf = {}
+ for i = 1, 3 do
+ local field = tbl[i]
+ if not field then
+ break
+ end
+ if field.type == 'tablefield' then
+ buf[i] = ('%s'):format(field.field[1])
+ end
+ end
+ return table.concat(buf, ', ')
+end
+
+local function buildValue(source, text, symbols)
+ local name = buildName(source, text)
+ local range, sRange, valueRange, kind
+ local details = {}
+ if source.type == 'local' then
+ if source.parent.type == 'funcargs' then
+ details[1] = 'param'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Constant
+ else
+ details[1] = 'local'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Variable
+ end
+ elseif source.type == 'setlocal' then
+ details[1] = 'setlocal'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Variable
+ elseif source.type == 'setglobal' then
+ details[1] = 'global'
+ range = { source.start, source.finish }
+ sRange = { source.start, source.finish }
+ kind = define.SymbolKind.Class
+ elseif source.type == 'tablefield' then
+ if not source.field then
+ return
+ end
+ details[1] = 'field'
+ range = { source.field.start, source.field.finish }
+ sRange = { source.field.start, source.field.finish }
+ kind = define.SymbolKind.Property
+ elseif source.type == 'setfield' then
+ if not source.field then
+ return
+ end
+ details[1] = 'field'
+ range = { source.field.start, source.field.finish }
+ sRange = { source.field.start, source.field.finish }
+ kind = define.SymbolKind.Field
+ else
+ return
+ end
+ if source.value then
+ local literal = source.value[1]
+ if source.value.type == 'boolean' then
+ details[2] = ' boolean'
+ if literal ~= nil then
+ details[3] = ' = '
+ details[4] = util.viewLiteral(source.value[1])
+ end
+ elseif source.value.type == 'string' then
+ details[2] = ' string'
+ if literal ~= nil then
+ details[3] = ' = '
+ details[4] = util.viewLiteral(source.value[1])
+ end
+ elseif source.value.type == 'number' then
+ details[2] = ' number'
+ if literal ~= nil then
+ details[3] = ' = '
+ details[4] = util.viewLiteral(source.value[1])
+ end
+ elseif source.value.type == 'table' then
+ details[2] = ' {'
+ details[3] = buildTable(source.value)
+ details[4] = '}'
+ valueRange = { source.value.start, source.value.finish }
+ elseif source.value.type == 'select' then
+ if source.value.vararg and source.value.vararg.type == 'call' then
+ valueRange = { source.value.start, source.value.finish }
+ end
+ end
+ range = { range[1], source.value.finish }
+ end
+ symbols[#symbols+1] = {
+ name = name,
+ detail = table.concat(details),
+ kind = kind,
+ range = range,
+ selectionRange = sRange,
+ valueRange = valueRange,
+ }
+end
+
+local function buildSet(source, text, used, symbols)
+ local value = source.value
+ if value and value.type == 'function' then
+ used[value] = true
+ buildFunction(source, text, symbols)
+ else
+ buildValue(source, text, symbols)
+ end
+end
+
+local function buildAnonymousFunction(source, text, used, symbols)
+ if used[source] then
+ return
+ end
+ used[source] = true
+ local head = ''
+ local parent = source.parent
+ if parent.type == 'return' then
+ head = 'return '
+ elseif parent.type == 'callargs' then
+ local call = parent.parent
+ local node = call.node
+ head = buildName(node, text) .. ' -> '
+ end
+ symbols[#symbols+1] = {
+ name = '',
+ detail = ('%sfunction (%s)'):format(head, buildFunctionParams(source)),
+ kind = define.SymbolKind.Function,
+ range = { source.start, source.finish },
+ selectionRange = { source.start, source.start },
+ valueRange = { source.start, source.finish },
+ }
+end
+
+local function buildSource(source, text, used, symbols)
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal'
+ or source.type == 'setfield'
+ or source.type == 'setmethod'
+ or source.type == 'tablefield' then
+ await.delay()
+ buildSet(source, text, used, symbols)
+ elseif source.type == 'function' then
+ await.delay()
+ buildAnonymousFunction(source, text, used, symbols)
+ end
+end
+
+local function makeSymbol(uri)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local text = files.getText(uri)
+ local symbols = {}
+ local used = {}
+ guide.eachSource(ast.ast, function (source)
+ buildSource(source, text, used, symbols)
+ end)
+
+ return symbols
+end
+
+local function packChild(ranges, symbols)
+ await.delay()
+ table.sort(symbols, function (a, b)
+ return a.selectionRange[1] < b.selectionRange[1]
+ end)
+ await.delay()
+ local root = {
+ valueRange = { 0, math.maxinteger },
+ children = {},
+ }
+ local stacks = { root }
+ for _, symbol in ipairs(symbols) do
+ local parent = stacks[#stacks]
+ -- 移除已经超出生效范围的区间
+ while symbol.selectionRange[1] > parent.valueRange[2] do
+ stacks[#stacks] = nil
+ parent = stacks[#stacks]
+ end
+ -- 向后看,找出当前可能生效的区间
+ local nextRange
+ while #ranges > 0
+ and symbol.selectionRange[1] >= ranges[#ranges].valueRange[1] do
+ if symbol.selectionRange[1] <= ranges[#ranges].valueRange[2] then
+ nextRange = ranges[#ranges]
+ end
+ ranges[#ranges] = nil
+ end
+ if nextRange then
+ stacks[#stacks+1] = nextRange
+ parent = nextRange
+ end
+ if parent == symbol then
+ -- function f() end 的情况,selectionRange 在 valueRange 内部,
+ -- 当前区间置为上一层
+ parent = stacks[#stacks-1]
+ end
+ -- 把自己放到当前区间中
+ if not parent.children then
+ parent.children = {}
+ end
+ parent.children[#parent.children+1] = symbol
+ end
+ return root.children
+end
+
+local function packSymbols(symbols)
+ local ranges = {}
+ for _, symbol in ipairs(symbols) do
+ if symbol.valueRange then
+ ranges[#ranges+1] = symbol
+ end
+ end
+ await.delay()
+ table.sort(ranges, function (a, b)
+ return a.valueRange[1] > b.valueRange[1]
+ end)
+ -- 处理嵌套
+ return packChild(ranges, symbols)
+end
+
+return function (uri)
+ local symbols = makeSymbol(uri)
+ if not symbols then
+ return nil
+ end
+
+ local packedSymbols = packSymbols(symbols)
+
+ return packedSymbols
+end
diff --git a/script/core/find-source.lua b/script/core/find-source.lua
new file mode 100644
index 00000000..32de102c
--- /dev/null
+++ b/script/core/find-source.lua
@@ -0,0 +1,14 @@
+local guide = require 'parser.guide'
+
+return function (ast, offset, accept)
+ local len = math.huge
+ local result
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ local start, finish = guide.getStartFinish(source)
+ if finish - start < len and accept[source.type] then
+ result = source
+ len = finish - start
+ end
+ end)
+ return result
+end
diff --git a/script/core/highlight.lua b/script/core/highlight.lua
new file mode 100644
index 00000000..d7671df2
--- /dev/null
+++ b/script/core/highlight.lua
@@ -0,0 +1,252 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm'
+local define = require 'proto.define'
+local findSource = require 'core.find-source'
+
+local function eachRef(source, callback)
+ local results = guide.requestReference(source)
+ for i = 1, #results do
+ callback(results[i])
+ end
+end
+
+local function eachField(source, callback)
+ local isGlobal = guide.isGlobal(source)
+ local results = guide.requestReference(source)
+ for i = 1, #results do
+ local res = results[i]
+ if isGlobal == guide.isGlobal(res) then
+ callback(res)
+ end
+ end
+end
+
+local function eachLocal(source, callback)
+ callback(source)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback(ref)
+ end
+ end
+end
+
+local function find(source, uri, callback)
+ if source.type == 'local' then
+ eachLocal(source, callback)
+ elseif source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ eachLocal(source.node, callback)
+ elseif source.type == 'field'
+ or source.type == 'method' then
+ eachField(source.parent, callback)
+ elseif source.type == 'getindex'
+ or source.type == 'setindex'
+ or source.type == 'tableindex' then
+ eachField(source, callback)
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ eachField(source, callback)
+ elseif source.type == 'goto'
+ or source.type == 'label' then
+ eachRef(source, callback)
+ elseif source.type == 'string'
+ and source.parent.index == source then
+ eachField(source.parent, callback)
+ elseif source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number'
+ or source.type == 'nil' then
+ callback(source)
+ end
+end
+
+local function checkInIf(source, text, offset)
+ -- 检查 end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if offset >= endA
+ and offset <= endB
+ and text:sub(endA, endB) == 'end' then
+ return true
+ end
+ -- 检查每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ if offset >= start and offset <= finish then
+ return true
+ end
+ end
+ end
+ return false
+end
+
+local function makeIf(source, text, callback)
+ -- end
+ local endA = source.finish - #'end' + 1
+ local endB = source.finish
+ if text:sub(endA, endB) == 'end' then
+ callback(endA, endB)
+ end
+ -- 每个子模块
+ for _, block in ipairs(source) do
+ for i = 1, #block.keyword, 2 do
+ local start = block.keyword[i]
+ local finish = block.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ return false
+end
+
+local function findKeyWord(ast, text, offset, callback)
+ guide.eachSourceContain(ast.ast, offset, function (source)
+ if source.type == 'do'
+ or source.type == 'function'
+ or source.type == 'loop'
+ or source.type == 'in'
+ or source.type == 'while'
+ or source.type == 'repeat' then
+ local ok
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ if offset >= start and offset <= finish then
+ ok = true
+ break
+ end
+ end
+ if ok then
+ for i = 1, #source.keyword, 2 do
+ local start = source.keyword[i]
+ local finish = source.keyword[i+1]
+ callback(start, finish)
+ end
+ end
+ elseif source.type == 'if' then
+ local ok = checkInIf(source, text, offset)
+ if ok then
+ makeIf(source, text, callback)
+ end
+ end
+ end)
+end
+
+local accept = {
+ ['label'] = true,
+ ['goto'] = true,
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['tablefield'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['string'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+ ['nil'] = true,
+}
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local text = files.getText(uri)
+ local results = {}
+ local mark = {}
+
+ local source = findSource(ast, offset, accept)
+ if source then
+ find(source, uri, function (target)
+ local kind
+ if target.type == 'getfield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setfield'
+ or target.type == 'tablefield' then
+ target = target.field
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setmethod' then
+ target = target.method
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'field' then
+ if target.parent.type == 'getfield' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'method' then
+ if target.parent.type == 'getmethod' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'index' then
+ if target.parent.type == 'getindex' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'index' then
+ if target.parent.type == 'getindex' then
+ kind = define.DocumentHighlightKind.Read
+ else
+ kind = define.DocumentHighlightKind.Write
+ end
+ elseif target.type == 'setindex'
+ or target.type == 'tableindex' then
+ target = target.index
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'getlocal'
+ or target.type == 'getglobal'
+ or target.type == 'goto' then
+ kind = define.DocumentHighlightKind.Read
+ elseif target.type == 'setlocal'
+ or target.type == 'local'
+ or target.type == 'setglobal'
+ or target.type == 'label' then
+ kind = define.DocumentHighlightKind.Write
+ elseif target.type == 'string'
+ or target.type == 'boolean'
+ or target.type == 'number'
+ or target.type == 'nil' then
+ kind = define.DocumentHighlightKind.Text
+ else
+ return
+ end
+ if mark[target] then
+ return
+ end
+ mark[target] = true
+ results[#results+1] = {
+ start = target.start,
+ finish = target.finish,
+ kind = kind,
+ }
+ end)
+ end
+
+ findKeyWord(ast, text, offset, function (start, finish)
+ results[#results+1] = {
+ start = start,
+ finish = finish,
+ kind = define.DocumentHighlightKind.Write
+ }
+ end)
+
+ if #results == 0 then
+ return nil
+ end
+ return results
+end
diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua
new file mode 100644
index 00000000..9cd19f02
--- /dev/null
+++ b/script/core/hover/arg.lua
@@ -0,0 +1,71 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local function optionalArg(arg)
+ if not arg.bindDocs then
+ return false
+ end
+ local name = arg[1]
+ for _, doc in ipairs(arg.bindDocs) do
+ if doc.type == 'doc.param' and doc.param[1] == name then
+ return doc.optional
+ end
+ end
+end
+
+local function asFunction(source, oop)
+ if not source.args then
+ return ''
+ end
+ local args = {}
+ for i = 1, #source.args do
+ local arg = source.args[i]
+ local name = arg.name or guide.getName(arg)
+ if name then
+ args[i] = ('%s%s: %s'):format(
+ name,
+ optionalArg(arg) and '?' or '',
+ vm.getInferType(arg)
+ )
+ else
+ args[i] = ('%s'):format(vm.getInferType(arg))
+ end
+ end
+ local methodDef
+ local parent = source.parent
+ if parent and parent.type == 'setmethod' then
+ methodDef = true
+ end
+ if not methodDef and oop then
+ return table.concat(args, ', ', 2)
+ else
+ return table.concat(args, ', ')
+ end
+end
+
+local function asDocFunction(source)
+ if not source.args then
+ return ''
+ end
+ local args = {}
+ for i = 1, #source.args do
+ local arg = source.args[i]
+ local name = arg.name[1]
+ args[i] = ('%s%s: %s'):format(
+ name,
+ arg.optional and '?' or '',
+ vm.getInferType(arg.extends)
+ )
+ end
+ return table.concat(args, ', ')
+end
+
+return function (source, oop)
+ if source.type == 'function' then
+ return asFunction(source, oop)
+ end
+ if source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ end
+ return ''
+end
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
new file mode 100644
index 00000000..7d89ee6c
--- /dev/null
+++ b/script/core/hover/description.lua
@@ -0,0 +1,204 @@
+local vm = require 'vm'
+local ws = require 'workspace'
+local furi = require 'file-uri'
+local files = require 'files'
+local guide = require 'parser.guide'
+local markdown = require 'provider.markdown'
+local config = require 'config'
+local lang = require 'language'
+
+local function asStringInRequire(source, literal)
+ local rootPath = ws.path or ''
+ local parent = source.parent
+ if parent and parent.type == 'callargs' then
+ local result, searchers
+ local call = parent.parent
+ local func = call.node
+ local libName = vm.getLibraryName(func)
+ if not libName then
+ return
+ end
+ if libName == 'require' then
+ result, searchers = ws.findUrisByRequirePath(literal)
+ elseif libName == 'dofile'
+ or libName == 'loadfile' then
+ result = ws.findUrisByFilePath(literal)
+ end
+ if result and #result > 0 then
+ for i, uri in ipairs(result) do
+ local searcher = searchers and furi.decode(searchers[uri])
+ uri = files.getOriginUri(uri)
+ local path = furi.decode(uri)
+ if files.eq(path:sub(1, #rootPath), rootPath) then
+ path = path:sub(#rootPath + 1)
+ end
+ path = path:gsub('^[/\\]*', '')
+ if vm.isMetaFile(uri) then
+ result[i] = ('* [[meta]](%s)'):format(uri)
+ elseif searcher then
+ searcher = searcher:sub(#rootPath + 1)
+ searcher = ws.normalize(searcher)
+ result[i] = ('* [%s](%s) %s'):format(path, uri, lang.script('HOVER_USE_LUA_PATH', searcher))
+ else
+ result[i] = ('* [%s](%s)'):format(path, uri)
+ end
+ end
+ table.sort(result)
+ local md = markdown()
+ md:add('md', table.concat(result, '\n'))
+ return md:string()
+ end
+ end
+end
+
+local function asStringView(source, literal)
+ -- 内部包含转义符?
+ local rawLen = source.finish - source.start - 2 * #source[2] + 1
+ if config.config.hover.viewString
+ and (source[2] == '"' or source[2] == "'")
+ and rawLen > #literal then
+ local view = literal
+ local max = config.config.hover.viewStringMax
+ if #view > max then
+ view = view:sub(1, max) .. '...'
+ end
+ local md = markdown()
+ md:add('txt', view)
+ return md:string()
+ end
+end
+
+local function asString(source)
+ local literal = guide.getLiteral(source)
+ if type(literal) ~= 'string' then
+ return nil
+ end
+ return asStringInRequire(source, literal)
+ or asStringView(source, literal)
+end
+
+local function getBindComment(docGroup, base)
+ local lines = {}
+ for _, doc in ipairs(docGroup) do
+ if doc.type == 'doc.comment' then
+ lines[#lines+1] = doc.comment.text:sub(2)
+ elseif #lines > 0 and not base then
+ break
+ elseif doc == base then
+ break
+ else
+ lines = {}
+ end
+ end
+ if #lines == 0 then
+ return nil
+ end
+ return table.concat(lines, '\n')
+end
+
+local function buildEnumChunk(docType, name)
+ local enums = vm.getDocEnums(docType)
+ if #enums == 0 then
+ return
+ end
+ local types = {}
+ for _, tp in ipairs(docType.types) do
+ types[#types+1] = tp[1]
+ end
+ local lines = {}
+ lines[#lines+1] = ('%s: %s'):format(name, table.concat(types))
+ for _, enum in ipairs(enums) do
+ lines[#lines+1] = (' %s %s%s'):format(
+ (enum.default and '->')
+ or (enum.additional and '+>')
+ or ' |',
+ enum[1],
+ enum.comment and (' -- %s'):format(enum.comment) or ''
+ )
+ end
+ return table.concat(lines, '\n')
+end
+
+local function getBindEnums(docGroup)
+ local mark = {}
+ local chunks = {}
+ local returnIndex = 0
+ for _, doc in ipairs(docGroup) do
+ if doc.type == 'doc.param' then
+ local name = doc.param[1]
+ if mark[name] then
+ goto CONTINUE
+ end
+ mark[name] = true
+ chunks[#chunks+1] = buildEnumChunk(doc.extends, name)
+ elseif doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ returnIndex = returnIndex + 1
+ local name = rtn.name and rtn.name[1] or ('(return %d)'):format(returnIndex)
+ if mark[name] then
+ goto CONTINUE
+ end
+ mark[name] = true
+ chunks[#chunks+1] = buildEnumChunk(rtn, name)
+ end
+ end
+ ::CONTINUE::
+ end
+ if #chunks == 0 then
+ return nil
+ end
+ return table.concat(chunks, '\n\n')
+end
+
+local function tryDocFieldUpComment(source)
+ if source.type ~= 'doc.field' then
+ return
+ end
+ if not source.bindGroup then
+ return
+ end
+ local comment = getBindComment(source.bindGroup, source)
+ return comment
+end
+
+local function tryDocComment(source)
+ if not source.bindDocs then
+ return
+ end
+ local comment = getBindComment(source.bindDocs)
+ local enums = getBindEnums(source.bindDocs)
+ local md = markdown()
+ if comment then
+ md:add('md', comment)
+ end
+ if enums then
+ md:add('lua', enums)
+ end
+ return md:string()
+end
+
+local function tryDocOverloadToComment(source)
+ if source.type ~= 'doc.type.function' then
+ return
+ end
+ local doc = source.parent
+ if doc.type ~= 'doc.overload'
+ or not doc.bindSources then
+ return
+ end
+ for _, src in ipairs(doc.bindSources) do
+ local md = tryDocComment(src)
+ if md then
+ return md
+ end
+ end
+end
+
+return function (source)
+ if source.type == 'string' then
+ return asString(source)
+ end
+ return tryDocOverloadToComment(source)
+ or tryDocFieldUpComment(source)
+ or tryDocComment(source)
+end
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
new file mode 100644
index 00000000..96e01ab5
--- /dev/null
+++ b/script/core/hover/init.lua
@@ -0,0 +1,164 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local getLabel = require 'core.hover.label'
+local getDesc = require 'core.hover.description'
+local util = require 'utility'
+local findSource = require 'core.find-source'
+local lang = require 'language'
+
+local function eachFunctionAndOverload(value, callback)
+ callback(value)
+ if not value.bindDocs then
+ return
+ end
+ for _, doc in ipairs(value.bindDocs) do
+ if doc.type == 'doc.overload' then
+ callback(doc.overload)
+ end
+ end
+end
+
+local function getHoverAsFunction(source)
+ local values = vm.getDefs(source, 'deep')
+ local desc = getDesc(source)
+ local labels = {}
+ local defs = 0
+ local protos = 0
+ local other = 0
+ local oop = source.type == 'method'
+ or source.type == 'getmethod'
+ or source.type == 'setmethod'
+ local mark = {}
+ for _, def in ipairs(values) do
+ def = guide.getObjectValue(def) or def
+ if def.type == 'function'
+ or def.type == 'doc.type.function' then
+ eachFunctionAndOverload(def, function (value)
+ if mark[value] then
+ return
+ end
+ mark[value] =true
+ local label = getLabel(value, oop)
+ if label then
+ defs = defs + 1
+ labels[label] = (labels[label] or 0) + 1
+ if labels[label] == 1 then
+ protos = protos + 1
+ end
+ end
+ desc = desc or getDesc(value)
+ end)
+ elseif def.type == 'table'
+ or def.type == 'boolean'
+ or def.type == 'string'
+ or def.type == 'number' then
+ other = other + 1
+ desc = desc or getDesc(def)
+ end
+ end
+
+ if defs == 1 and other == 0 then
+ return {
+ label = next(labels),
+ source = source,
+ description = desc,
+ }
+ end
+
+ local lines = {}
+ if defs > 1 then
+ lines[#lines+1] = lang.script('HOVER_MULTI_DEF_PROTO', defs, protos)
+ end
+ if other > 0 then
+ lines[#lines+1] = lang.script('HOVER_MULTI_PROTO_NOT_FUNC', other)
+ end
+ if defs > 1 then
+ for label, count in util.sortPairs(labels) do
+ lines[#lines+1] = ('(%d) %s'):format(count, label)
+ end
+ else
+ lines[#lines+1] = next(labels)
+ end
+ local label = table.concat(lines, '\n')
+ return {
+ label = label,
+ source = source,
+ description = desc,
+ }
+end
+
+local function getHoverAsValue(source)
+ local oop = source.type == 'method'
+ or source.type == 'getmethod'
+ or source.type == 'setmethod'
+ local label = getLabel(source, oop)
+ local desc = getDesc(source)
+ if not desc then
+ local values = vm.getDefs(source, 'deep')
+ for _, def in ipairs(values) do
+ desc = getDesc(def)
+ if desc then
+ break
+ end
+ end
+ end
+ return {
+ label = label,
+ source = source,
+ description = desc,
+ }
+end
+
+local function getHoverAsDocName(source)
+ local label = getLabel(source)
+ local desc = getDesc(source)
+ return {
+ label = label,
+ source = source,
+ description = desc,
+ }
+end
+
+local function getHover(source)
+ if source.type == 'doc.type.name' then
+ return getHoverAsDocName(source)
+ end
+ local isFunction = vm.hasInferType(source, 'function', 'deep')
+ if isFunction then
+ return getHoverAsFunction(source)
+ else
+ return getHoverAsValue(source)
+ end
+end
+
+local accept = {
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['string'] = true,
+ ['number'] = true,
+ ['doc.type.name'] = true,
+}
+
+local function getHoverByUri(uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local source = findSource(ast, offset, accept)
+ if not source then
+ return nil
+ end
+ local hover = getHover(source)
+ return hover
+end
+
+return {
+ get = getHover,
+ byUri = getHoverByUri,
+}
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
new file mode 100644
index 00000000..d785bc27
--- /dev/null
+++ b/script/core/hover/label.lua
@@ -0,0 +1,211 @@
+local buildName = require 'core.hover.name'
+local buildArg = require 'core.hover.arg'
+local buildReturn = require 'core.hover.return'
+local buildTable = require 'core.hover.table'
+local vm = require 'vm'
+local util = require 'utility'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local config = require 'config'
+local files = require 'files'
+
+local function asFunction(source, oop)
+ local name = buildName(source, oop)
+ local arg = buildArg(source, oop)
+ local rtn = buildReturn(source)
+ local lines = {}
+ lines[1] = ('function %s(%s)'):format(name, arg)
+ lines[2] = rtn
+ return table.concat(lines, '\n')
+end
+
+local function asDocFunction(source)
+ local name = buildName(source)
+ local arg = buildArg(source)
+ local rtn = buildReturn(source)
+ local lines = {}
+ lines[1] = ('function %s(%s)'):format(name, arg)
+ lines[2] = rtn
+ return table.concat(lines, '\n')
+end
+
+local function asDocTypeName(source)
+ for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ if doc.type == 'doc.class.name' then
+ return 'class ' .. source[1]
+ end
+ if doc.type == 'doc.alias.name' then
+ local extends = doc.parent.extends
+ return lang.script('HOVER_EXTENDS', vm.getInferType(extends))
+ end
+ end
+end
+
+local function asValue(source, title)
+ local name = buildName(source)
+ local infers = vm.getInfers(source, 'deep')
+ local type = vm.getInferType(source, 'deep')
+ local class = vm.getClass(source, 'deep')
+ local literal = vm.getInferLiteral(source, 'deep')
+ local cont
+ if type ~= 'string' and not type:find('%[%]$') then
+ if #vm.getFields(source, 'deep') > 0
+ or vm.hasInferType(source, 'table', 'deep') then
+ cont = buildTable(source)
+ end
+ end
+ local pack = {}
+ pack[#pack+1] = title
+ pack[#pack+1] = name .. ':'
+ if cont and type == 'table' then
+ type = nil
+ end
+ if class then
+ pack[#pack+1] = class
+ else
+ pack[#pack+1] = type
+ end
+ if literal then
+ pack[#pack+1] = '='
+ pack[#pack+1] = literal
+ end
+ if cont then
+ pack[#pack+1] = cont
+ end
+ return table.concat(pack, ' ')
+end
+
+local function asLocal(source)
+ return asValue(source, 'local')
+end
+
+local function asGlobal(source)
+ return asValue(source, 'global')
+end
+
+local function isGlobalField(source)
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ if source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod' then
+ local node = source.node
+ if node.type == 'setglobal'
+ or node.type == 'getglobal' then
+ return true
+ end
+ return isGlobalField(node)
+ elseif source.type == 'tablefield' then
+ local parent = source.parent
+ if parent.type == 'setglobal'
+ or parent.type == 'getglobal' then
+ return true
+ end
+ return isGlobalField(parent)
+ else
+ return false
+ end
+end
+
+local function asField(source)
+ if isGlobalField(source) then
+ return asGlobal(source)
+ end
+ return asValue(source, 'field')
+end
+
+local function asDocField(source)
+ local name = source.field[1]
+ local class
+ for _, doc in ipairs(source.bindGroup) do
+ if doc.type == 'doc.class' then
+ class = doc
+ break
+ end
+ end
+ if not class then
+ return ('field ?.%s: %s'):format(
+ name,
+ vm.getInferType(source.extends)
+ )
+ end
+ return ('field %s.%s: %s'):format(
+ class.class[1],
+ name,
+ vm.getInferType(source.extends)
+ )
+end
+
+local function asString(source)
+ local str = source[1]
+ if type(str) ~= 'string' then
+ return ''
+ end
+ local len = #str
+ local charLen = util.utf8Len(str, 1, -1)
+ if len == charLen then
+ return lang.script('HOVER_STRING_BYTES', len)
+ else
+ return lang.script('HOVER_STRING_CHARACTERS', len, charLen)
+ end
+end
+
+local function formatNumber(n)
+ local str = ('%.10f'):format(n)
+ str = str:gsub('%.?0*$', '')
+ return str
+end
+
+local function asNumber(source)
+ if not config.config.hover.viewNumber then
+ return nil
+ end
+ local num = source[1]
+ if type(num) ~= 'number' then
+ return nil
+ end
+ local uri = guide.getUri(source)
+ local text = files.getText(uri)
+ if not text then
+ return nil
+ end
+ local raw = text:sub(source.start, source.finish)
+ if not raw or not raw:find '[^%-%d%.]' then
+ return nil
+ end
+ return formatNumber(num)
+end
+
+return function (source, oop)
+ if source.type == 'function' then
+ return asFunction(source, oop)
+ elseif source.type == 'local'
+ or source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ return asLocal(source)
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return asGlobal(source)
+ elseif source.type == 'getfield'
+ or source.type == 'setfield'
+ or source.type == 'getmethod'
+ or source.type == 'setmethod'
+ or source.type == 'tablefield'
+ or source.type == 'field'
+ or source.type == 'method' then
+ return asField(source)
+ elseif source.type == 'string' then
+ return asString(source)
+ elseif source.type == 'number' then
+ return asNumber(source)
+ elseif source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ elseif source.type == 'doc.type.name' then
+ return asDocTypeName(source)
+ elseif source.type == 'doc.field' then
+ return asDocField(source)
+ end
+end
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
new file mode 100644
index 00000000..9ad32e09
--- /dev/null
+++ b/script/core/hover/name.lua
@@ -0,0 +1,101 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local buildName
+
+local function asLocal(source)
+ local name = guide.getName(source)
+ if not source.attrs then
+ return name
+ end
+ local label = {}
+ label[#label+1] = name
+ for _, attr in ipairs(source.attrs) do
+ label[#label+1] = ('<%s>'):format(attr[1])
+ end
+ return table.concat(label, ' ')
+end
+
+local function asField(source, oop)
+ local class
+ if source.node.type ~= 'getglobal' then
+ class = vm.getClass(source.node, 'deep')
+ end
+ local node = class or guide.getName(source.node) or '?'
+ local method = guide.getName(source)
+ if oop then
+ return ('%s:%s'):format(node, method)
+ else
+ return ('%s.%s'):format(node, method)
+ end
+end
+
+local function asTableField(source)
+ if not source.field then
+ return
+ end
+ return guide.getName(source.field)
+end
+
+local function asGlobal(source)
+ return guide.getName(source)
+end
+
+local function asDocFunction(source)
+ local doc = guide.getParentType(source, 'doc.type')
+ or guide.getParentType(source, 'doc.overload')
+ if not doc or not doc.bindSources then
+ return ''
+ end
+ for _, src in ipairs(doc.bindSources) do
+ local name = buildName(src)
+ if name ~= '' then
+ return name
+ end
+ end
+ return ''
+end
+
+local function asDocField(source)
+ return source.field[1]
+end
+
+function buildName(source, oop)
+ if oop == nil then
+ oop = source.type == 'setmethod'
+ or source.type == 'getmethod'
+ end
+ if source.type == 'local'
+ or source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ return asLocal(source) or ''
+ end
+ if source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return asGlobal(source) or ''
+ end
+ if source.type == 'setmethod'
+ or source.type == 'getmethod' then
+ return asField(source, true) or ''
+ end
+ if source.type == 'setfield'
+ or source.type == 'getfield' then
+ return asField(source, oop) or ''
+ end
+ if source.type == 'tablefield' then
+ return asTableField(source) or ''
+ end
+ if source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ end
+ if source.type == 'doc.field' then
+ return asDocField(source)
+ end
+ local parent = source.parent
+ if parent then
+ return buildName(parent, oop)
+ end
+ return ''
+end
+
+return buildName
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
new file mode 100644
index 00000000..3829dbed
--- /dev/null
+++ b/script/core/hover/return.lua
@@ -0,0 +1,125 @@
+local guide = require 'parser.guide'
+local vm = require 'vm'
+
+local function mergeTypes(returns)
+ if type(returns) == 'string' then
+ return returns
+ end
+ return guide.mergeTypes(returns)
+end
+
+local function getReturnDualByDoc(source)
+ local docs = source.bindDocs
+ if not docs then
+ return
+ end
+ local dual
+ for _, doc in ipairs(docs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ if not dual then
+ dual = {}
+ end
+ dual[#dual+1] = { rtn }
+ end
+ end
+ end
+ return dual
+end
+
+local function getReturnDualByGrammar(source)
+ if not source.returns then
+ return nil
+ end
+ local dual
+ for _, rtn in ipairs(source.returns) do
+ if not dual then
+ dual = {}
+ end
+ for n = 1, #rtn do
+ if not dual[n] then
+ dual[n] = {}
+ end
+ dual[n][#dual[n]+1] = rtn[n]
+ end
+ end
+ return dual
+end
+
+local function asFunction(source)
+ local dual = getReturnDualByDoc(source)
+ or getReturnDualByGrammar(source)
+ if not dual then
+ return
+ end
+ local returns = {}
+ for i, rtn in ipairs(dual) do
+ local line = {}
+ local types = {}
+ if i == 1 then
+ line[#line+1] = ' -> '
+ else
+ line[#line+1] = ('% 3d. '):format(i)
+ end
+ for n = 1, #rtn do
+ local values = vm.getInfers(rtn[n])
+ for _, value in ipairs(values) do
+ if value.type then
+ for tp in value.type:gmatch '[^|]+' do
+ types[#types+1] = tp
+ end
+ end
+ end
+ end
+ if #types > 0 or rtn[1] then
+ local tp = mergeTypes(types) or 'any'
+ if rtn[1].name then
+ line[#line+1] = ('%s%s: %s'):format(
+ rtn[1].name[1],
+ rtn[1].optional and '?' or '',
+ tp
+ )
+ else
+ line[#line+1] = ('%s%s'):format(
+ tp,
+ rtn[1].optional and '?' or ''
+ )
+ end
+ else
+ break
+ end
+ returns[i] = table.concat(line)
+ end
+ if #returns == 0 then
+ return nil
+ end
+ return table.concat(returns, '\n')
+end
+
+local function asDocFunction(source)
+ if not source.returns or #source.returns == 0 then
+ return nil
+ end
+ local returns = {}
+ for i, rtn in ipairs(source.returns) do
+ local rtnText = ('%s%s'):format(
+ vm.getInferType(rtn),
+ rtn.optional and '?' or ''
+ )
+ if i == 1 then
+ returns[#returns+1] = (' -> %s'):format(rtnText)
+ else
+ returns[#returns+1] = ('% 3d. %s'):format(i, rtnText)
+ end
+ end
+ return table.concat(returns, '\n')
+end
+
+return function (source)
+ if source.type == 'function' then
+ return asFunction(source)
+ end
+ if source.type == 'doc.type.function' then
+ return asDocFunction(source)
+ end
+end
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
new file mode 100644
index 00000000..02be5271
--- /dev/null
+++ b/script/core/hover/table.lua
@@ -0,0 +1,257 @@
+local vm = require 'vm'
+local util = require 'utility'
+local guide = require 'parser.guide'
+local config = require 'config'
+local lang = require 'language'
+
+local function getKey(src)
+ local key = vm.getKeyName(src)
+ if not key or #key <= 2 then
+ if not src.index then
+ return '[any]'
+ end
+ local class = vm.getClass(src.index)
+ if class then
+ return ('[%s]'):format(class)
+ end
+ local tp = vm.getInferType(src.index)
+ if tp then
+ return ('[%s]'):format(tp)
+ end
+ return '[any]'
+ end
+ local ktype = key:sub(1, 2)
+ key = key:sub(3)
+ if ktype == 's|' then
+ if key:match '^[%a_][%w_]*$' then
+ return key
+ else
+ return ('[%s]'):format(util.viewLiteral(key))
+ end
+ end
+ return ('[%s]'):format(key)
+end
+
+local function getFieldFast(src)
+ local value = guide.getObjectValue(src) or src
+ if not value then
+ return 'any'
+ end
+ if value.type == 'boolean' then
+ return value.type, util.viewLiteral(value[1])
+ end
+ if value.type == 'number'
+ or value.type == 'integer' then
+ if math.tointeger(value[1]) then
+ if config.config.runtime.version == 'Lua 5.3'
+ or config.config.runtime.version == 'Lua 5.4' then
+ return 'integer', util.viewLiteral(value[1])
+ end
+ end
+ return value.type, util.viewLiteral(value[1])
+ end
+ if value.type == 'table'
+ or value.type == 'function' then
+ return value.type
+ end
+ if value.type == 'string' then
+ local literal = value[1]
+ if type(literal) == 'string' and #literal >= 50 then
+ literal = literal:sub(1, 47) .. '...'
+ end
+ return value.type, util.viewLiteral(literal)
+ end
+end
+
+local function getFieldFull(src)
+ local tp = vm.getInferType(src)
+ --local class = vm.getClass(src)
+ local literal = vm.getInferLiteral(src)
+ if type(literal) == 'string' and #literal >= 50 then
+ literal = literal:sub(1, 47) .. '...'
+ end
+ return tp, literal
+end
+
+local function getField(src, timeUp, mark, key)
+ if src.type == 'table'
+ or src.type == 'function' then
+ return nil
+ end
+ if src.parent then
+ if src.type == 'string'
+ or src.type == 'boolean'
+ or src.type == 'number'
+ or src.type == 'integer' then
+ if src.parent.type == 'tableindex'
+ or src.parent.type == 'setindex'
+ or src.parent.type == 'getindex' then
+ if src.parent.index == src then
+ src = src.parent
+ end
+ end
+ end
+ end
+ local tp, literal
+ tp, literal = getFieldFast(src)
+ if tp then
+ return tp, literal
+ end
+ if timeUp or mark[key] then
+ return nil
+ end
+ mark[key] = true
+ tp, literal = getFieldFull(src)
+ if tp then
+ return tp, literal
+ end
+ return nil
+end
+
+local function buildAsHash(classes, literals)
+ local keys = {}
+ for k in pairs(classes) do
+ keys[#keys+1] = k
+ end
+ table.sort(keys)
+ local lines = {}
+ lines[#lines+1] = '{'
+ for _, key in ipairs(keys) do
+ local class = classes[key]
+ local literal = literals[key]
+ if literal then
+ lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ else
+ lines[#lines+1] = (' %s: %s,'):format(key, class)
+ end
+ end
+ lines[#lines+1] = '}'
+ return table.concat(lines, '\n')
+end
+
+local function buildAsConst(classes, literals)
+ local keys = {}
+ for k in pairs(classes) do
+ keys[#keys+1] = k
+ end
+ table.sort(keys, function (a, b)
+ return tonumber(literals[a]) < tonumber(literals[b])
+ end)
+ local lines = {}
+ lines[#lines+1] = '{'
+ for _, key in ipairs(keys) do
+ local class = classes[key]
+ local literal = literals[key]
+ if literal then
+ lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ else
+ lines[#lines+1] = (' %s: %s,'):format(key, class)
+ end
+ end
+ lines[#lines+1] = '}'
+ return table.concat(lines, '\n')
+end
+
+local function mergeLiteral(literals)
+ local results = {}
+ local mark = {}
+ for _, value in ipairs(literals) do
+ if not mark[value] then
+ mark[value] = true
+ results[#results+1] = value
+ end
+ end
+ if #results == 0 then
+ return nil
+ end
+ table.sort(results)
+ return table.concat(results, '|')
+end
+
+local function mergeTypes(types)
+ local results = {}
+ local mark = {
+ -- 讲道理table的keyvalue不会是nil
+ ['nil'] = true,
+ }
+ for _, tv in ipairs(types) do
+ for tp in tv:gmatch '[^|]+' do
+ if not mark[tp] then
+ mark[tp] = true
+ results[#results+1] = tp
+ end
+ end
+ end
+ return guide.mergeTypes(results)
+end
+
+local function clearClasses(classes)
+ classes['[nil]'] = nil
+ classes['[any]'] = nil
+ classes['[string]'] = nil
+end
+
+return function (source)
+ local literals = {}
+ local classes = {}
+ local clock = os.clock()
+ local timeUp
+ local mark = {}
+ local fields = vm.getFields(source, 'deep')
+ local keyCount = 0
+ for _, src in ipairs(fields) do
+ local key = getKey(src)
+ if not key then
+ goto CONTINUE
+ end
+ if not classes[key] then
+ classes[key] = {}
+ keyCount = keyCount + 1
+ end
+ if not literals[key] then
+ literals[key] = {}
+ end
+ if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then
+ timeUp = true
+ end
+ local class, literal = getField(src, timeUp, mark, key)
+ if literal == 'nil' then
+ literal = nil
+ end
+ classes[key][#classes[key]+1] = class
+ literals[key][#literals[key]+1] = literal
+ if keyCount >= 1000 then
+ break
+ end
+ ::CONTINUE::
+ end
+
+ clearClasses(classes)
+
+ for key, class in pairs(classes) do
+ literals[key] = mergeLiteral(literals[key])
+ classes[key] = mergeTypes(class)
+ end
+
+ if not next(classes) then
+ return '{}'
+ end
+
+ local intValue = true
+ for key, class in pairs(classes) do
+ if class ~= 'integer' or not tonumber(literals[key]) then
+ intValue = false
+ break
+ end
+ end
+ local result
+ if intValue then
+ result = buildAsConst(classes, literals)
+ else
+ result = buildAsHash(classes, literals)
+ end
+ if timeUp then
+ result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ end
+ return result
+end
diff --git a/script/core/keyword.lua b/script/core/keyword.lua
new file mode 100644
index 00000000..1cbeb78d
--- /dev/null
+++ b/script/core/keyword.lua
@@ -0,0 +1,264 @@
+local define = require 'proto.define'
+local guide = require 'parser.guide'
+
+local keyWordMap = {
+ {'do', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'do .. end',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[$0 end]],
+ }
+ else
+ results[#results+1] = {
+ label = 'do .. end',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+do
+ $0
+end]],
+ }
+ end
+ return true
+ end, function (ast, start)
+ return guide.eachSourceContain(ast.ast, start, function (source)
+ if source.type == 'while'
+ or source.type == 'in'
+ or source.type == 'loop' then
+ for i = 1, #source.keyword do
+ if start == source.keyword[i] then
+ return true
+ end
+ end
+ end
+ end)
+ end},
+ {'and'},
+ {'break'},
+ {'else'},
+ {'elseif', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'elseif .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[$1 then]],
+ }
+ else
+ results[#results+1] = {
+ label = 'elseif .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[elseif $1 then]],
+ }
+ end
+ return true
+ end},
+ {'end'},
+ {'false'},
+ {'for', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'for .. in',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:key, value} in ${2:pairs(${3:t})} do
+ $0
+end]]
+ }
+ results[#results+1] = {
+ label = 'for i = ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:i} = ${2:1}, ${3:10, 1} do
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'for .. in',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+for ${1:key, value} in ${2:pairs(${3:t})} do
+ $0
+end]]
+ }
+ results[#results+1] = {
+ label = 'for i = ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+for ${1:i} = ${2:1}, ${3:10, 1} do
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'function', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'function ()',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+$1($2)
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'function ()',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+function $1($2)
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'goto'},
+ {'if', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'if .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+$1 then
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'if .. then',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+if $1 then
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'in', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'in ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:pairs(${2:t})} do
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'in ..',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+in ${1:pairs(${2:t})} do
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+ {'local', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'local function',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+function $1($2)
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'local function',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+local function $1($2)
+ $0
+end]]
+ }
+ end
+ return false
+ end},
+ {'nil'},
+ {'not'},
+ {'or'},
+ {'repeat', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'repeat .. until',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[$0 until $1]]
+ }
+ else
+ results[#results+1] = {
+ label = 'repeat .. until',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+repeat
+ $0
+until $1]]
+ }
+ end
+ return true
+ end},
+ {'return', function (hasSpace, results)
+ if not hasSpace then
+ results[#results+1] = {
+ label = 'do return end',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[do return $1end]]
+ }
+ end
+ return false
+ end},
+ {'then'},
+ {'true'},
+ {'until'},
+ {'while', function (hasSpace, results)
+ if hasSpace then
+ results[#results+1] = {
+ label = 'while .. do',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+${1:true} do
+ $0
+end]]
+ }
+ else
+ results[#results+1] = {
+ label = 'while .. do',
+ kind = define.CompletionItemKind.Snippet,
+ insertTextFormat = 2,
+ insertText = [[
+while ${1:true} do
+ $0
+end]]
+ }
+ end
+ return true
+ end},
+}
+
+return keyWordMap
diff --git a/script/core/matchkey.lua b/script/core/matchkey.lua
new file mode 100644
index 00000000..45c86eff
--- /dev/null
+++ b/script/core/matchkey.lua
@@ -0,0 +1,33 @@
+return function (me, other, fast)
+ if me == other then
+ return true
+ end
+ if me == '' then
+ return true
+ end
+ if #me > #other then
+ return false
+ end
+ local lMe = me:lower()
+ local lOther = other:lower()
+ if lMe == lOther:sub(1, #lMe) then
+ return true
+ end
+ if fast and me:sub(1, 1) ~= other:sub(1, 1) then
+ return false
+ end
+ local chars = {}
+ for i = 1, #lOther do
+ local c = lOther:sub(i, i)
+ chars[c] = (chars[c] or 0) + 1
+ end
+ for i = 1, #lMe do
+ local c = lMe:sub(i, i)
+ if chars[c] and chars[c] > 0 then
+ chars[c] = chars[c] - 1
+ else
+ return false
+ end
+ end
+ return true
+end
diff --git a/script/core/reference.lua b/script/core/reference.lua
new file mode 100644
index 00000000..d7d3df03
--- /dev/null
+++ b/script/core/reference.lua
@@ -0,0 +1,116 @@
+local guide = require 'parser.guide'
+local files = require 'files'
+local vm = require 'vm'
+local findSource = require 'core.find-source'
+
+local function isValidFunction(source, offset)
+ -- 必须点在 `function` 这个单词上才能查找函数引用
+ return offset >= source.start and offset < source.start + #'function'
+end
+
+local function sortResults(results)
+ -- 先按照顺序排序
+ table.sort(results, function (a, b)
+ local u1 = guide.getUri(a.target)
+ local u2 = guide.getUri(b.target)
+ if u1 == u2 then
+ return a.target.start < b.target.start
+ else
+ return u1 < u2
+ end
+ end)
+ -- 如果2个结果处于嵌套状态,则取范围小的那个
+ local lf, lu
+ for i = #results, 1, -1 do
+ local res = results[i].target
+ local f = res.finish
+ local uri = guide.getUri(res)
+ if lf and f > lf and uri == lu then
+ table.remove(results, i)
+ else
+ lu = uri
+ lf = f
+ end
+ end
+end
+
+local accept = {
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['label'] = true,
+ ['goto'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['setindex'] = true,
+ ['getindex'] = true,
+ ['tableindex'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['function'] = true,
+
+ ['doc.type.name'] = true,
+ ['doc.class.name'] = true,
+ ['doc.extends.name'] = true,
+ ['doc.alias.name'] = true,
+}
+
+return function (uri, offset)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+
+ local source = findSource(ast, offset, accept)
+ if not source then
+ return nil
+ end
+ if source.type == 'function' and not isValidFunction(source, offset) and not TEST then
+ return nil
+ end
+
+ local results = {}
+ for _, src in ipairs(vm.getRefs(source, 'deep')) do
+ local root = guide.getRoot(src)
+ if not root then
+ goto CONTINUE
+ end
+ if vm.isMetaFile(root.uri) then
+ goto CONTINUE
+ end
+ if ( src.type == 'doc.class.name'
+ or src.type == 'doc.type.name'
+ )
+ and source.type ~= 'doc.type.name'
+ and source.type ~= 'doc.class.name' then
+ goto CONTINUE
+ end
+ if src.type == 'setfield'
+ or src.type == 'getfield'
+ or src.type == 'tablefield' then
+ src = src.field
+ elseif src.type == 'setindex'
+ or src.type == 'getindex'
+ or src.type == 'tableindex' then
+ src = src.index
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ src = src.method
+ elseif src.type == 'table' and src.parent.type ~= 'return' then
+ goto CONTINUE
+ end
+ results[#results+1] = {
+ target = src,
+ uri = files.getOriginUri(root.uri),
+ }
+ ::CONTINUE::
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ sortResults(results)
+
+ return results
+end
diff --git a/script/core/rename.lua b/script/core/rename.lua
new file mode 100644
index 00000000..89298bdd
--- /dev/null
+++ b/script/core/rename.lua
@@ -0,0 +1,448 @@
+local files = require 'files'
+local vm = require 'vm'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local define = require 'proto.define'
+local util = require 'utility'
+local findSource = require 'core.find-source'
+local ws = require 'workspace'
+
+local Forcing
+
+local function askForcing(str)
+ -- TODO 总是可以替换
+ do return true end
+ if TEST then
+ return true
+ end
+ if Forcing ~= nil then
+ return Forcing
+ end
+ local version = files.globalVersion
+ -- TODO
+ local item = proto.awaitRequest('window/showMessageRequest', {
+ type = define.MessageType.Warning,
+ message = ('[%s]不是有效的标识符,是否强制替换?'):format(str),
+ actions = {
+ {
+ title = '强制替换',
+ },
+ {
+ title = '取消',
+ },
+ }
+ })
+ if version ~= files.globalVersion then
+ Forcing = false
+ proto.notify('window/showMessage', {
+ type = define.MessageType.Warning,
+ message = '文件发生了变化,替换取消。'
+ })
+ return false
+ end
+ if not item then
+ Forcing = false
+ return false
+ end
+ if item.title == '强制替换' then
+ Forcing = true
+ return true
+ else
+ Forcing = false
+ return false
+ end
+end
+
+local function askForMultiChange(results, newname)
+ -- TODO 总是可以替换
+ do return true end
+ if TEST then
+ return true
+ end
+ local uris = {}
+ for _, result in ipairs(results) do
+ local uri = result.uri
+ if not uris[uri] then
+ uris[uri] = 0
+ uris[#uris+1] = uri
+ end
+ uris[uri] = uris[uri] + 1
+ end
+ if #uris <= 1 then
+ return true
+ end
+
+ local version = files.globalVersion
+ -- TODO
+ local item = proto.awaitRequest('window/showMessageRequest', {
+ type = define.MessageType.Warning,
+ message = ('将修改 %d 个文件,共 %d 处。'):format(
+ #uris,
+ #results
+ ),
+ actions = {
+ {
+ title = '继续',
+ },
+ {
+ title = '放弃',
+ },
+ }
+ })
+ if version ~= files.globalVersion then
+ proto.notify('window/showMessage', {
+ type = define.MessageType.Warning,
+ message = '文件发生了变化,替换取消。'
+ })
+ return false
+ end
+ if item and item.title == '继续' then
+ local fileList = {}
+ for _, uri in ipairs(uris) do
+ fileList[#fileList+1] = ('%s (%d)'):format(uri, uris[uri])
+ end
+
+ log.debug(('Renamed [%s]\r\n%s'):format(newname, table.concat(fileList, '\r\n')))
+ return true
+ end
+ return false
+end
+
+local function trim(str)
+ return str:match '^%s*(%S+)%s*$'
+end
+
+local function isValidName(str)
+ return str:match '^[%a_][%w_]*$'
+end
+
+local function isValidGlobal(str)
+ for s in str:gmatch '[^%.]*' do
+ if not isValidName(trim(s)) then
+ return false
+ end
+ end
+ return true
+end
+
+local function isValidFunctionName(str)
+ if isValidGlobal(str) then
+ return true
+ end
+ local pos = str:find(':', 1, true)
+ if not pos then
+ return false
+ end
+ return isValidGlobal(trim(str:sub(1, pos-1)))
+ and isValidName(trim(str:sub(pos+1)))
+end
+
+local function isFunctionGlobalName(source)
+ local parent = source.parent
+ if parent.type ~= 'setglobal' then
+ return false
+ end
+ local value = parent.value
+ if not value.type ~= 'function' then
+ return false
+ end
+ return value.start <= parent.start
+end
+
+local function renameLocal(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return
+ end
+ if askForcing(newname) then
+ callback(source, source.start, source.finish, newname)
+ end
+end
+
+local function renameField(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ local parent = source.parent
+ if parent.type == 'setfield'
+ or parent.type == 'getfield' then
+ local dot = parent.dot
+ local newstr = '[' .. util.viewString(newname) .. ']'
+ callback(source, dot.start, source.finish, newstr)
+ elseif parent.type == 'tablefield' then
+ local newstr = '[' .. util.viewString(newname) .. ']'
+ callback(source, source.start, source.finish, newstr)
+ elseif parent.type == 'getmethod' then
+ if not askForcing(newname) then
+ return false
+ end
+ callback(source, source.start, source.finish, newname)
+ elseif parent.type == 'setmethod' then
+ local uri = guide.getUri(source)
+ local text = files.getText(uri)
+ local func = parent.value
+ -- function mt:name () end --> mt['newname'] = function (self) end
+ local newstr = string.format('%s[%s] = function '
+ , text:sub(parent.start, parent.node.finish)
+ , util.viewString(newname)
+ )
+ callback(source, func.start, parent.finish, newstr)
+ local pl = text:find('(', parent.finish, true)
+ if pl then
+ if func.args then
+ callback(source, pl + 1, pl, 'self, ')
+ else
+ callback(source, pl + 1, pl, 'self')
+ end
+ end
+ end
+ return true
+end
+
+local function renameGlobal(source, newname, callback)
+ if isValidGlobal(newname) then
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ if isValidFunctionName(newname) then
+ if not isFunctionGlobalName(source) then
+ askForcing(newname)
+ end
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ local newstr = '_ENV[' .. util.viewString(newname) .. ']'
+ -- function name () end --> _ENV['newname'] = function () end
+ if source.value and source.value.type == 'function'
+ and source.value.start < source.start then
+ callback(source, source.value.start, source.finish, newstr .. ' = function ')
+ return true
+ end
+ callback(source, source.start, source.finish, newstr)
+ return true
+end
+
+local function ofLocal(source, newname, callback)
+ renameLocal(source, newname, callback)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ renameLocal(ref, newname, callback)
+ end
+ end
+end
+
+local function ofFieldThen(key, src, newname, callback)
+ if vm.getKeyName(src) ~= key then
+ return
+ end
+ if src.type == 'tablefield'
+ or src.type == 'getfield'
+ or src.type == 'setfield' then
+ src = src.field
+ elseif src.type == 'tableindex'
+ or src.type == 'getindex'
+ or src.type == 'setindex' then
+ src = src.index
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ src = src.method
+ end
+ if src.type == 'string' then
+ local quo = src[2]
+ local text = util.viewString(newname, quo)
+ callback(src, src.start, src.finish, text)
+ return
+ elseif src.type == 'field'
+ or src.type == 'method' then
+ local suc = renameField(src, newname, callback)
+ if not suc then
+ return
+ end
+ elseif src.type == 'setglobal'
+ or src.type == 'getglobal' then
+ local suc = renameGlobal(src, newname, callback)
+ if not suc then
+ return
+ end
+ end
+end
+
+local function ofField(source, newname, callback)
+ local key = guide.getKeyName(source)
+ local node
+ if source.type == 'tablefield'
+ or source.type == 'tableindex' then
+ node = source.parent
+ else
+ node = source.node
+ end
+ for _, src in ipairs(vm.getFields(node, 'deep')) do
+ ofFieldThen(key, src, newname, callback)
+ end
+end
+
+local function ofGlobal(source, newname, callback)
+ local key = guide.getKeyName(source)
+ for _, src in ipairs(vm.getRefs(source, 'deep')) do
+ ofFieldThen(key, src, newname, callback)
+ end
+end
+
+local function ofLabel(source, newname, callback)
+ if not isValidName(newname) and not askForcing(newname)then
+ return false
+ end
+ for _, src in ipairs(vm.getRefs(source, 'deep')) do
+ callback(src, src.start, src.finish, newname)
+ end
+end
+
+local function rename(source, newname, callback)
+ if source.type == 'label'
+ or source.type == 'goto' then
+ return ofLabel(source, newname, callback)
+ elseif source.type == 'local' then
+ return ofLocal(source, newname, callback)
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ return ofLocal(source.node, newname, callback)
+ elseif source.type == 'field'
+ or source.type == 'method'
+ or source.type == 'index' then
+ return ofField(source.parent, newname, callback)
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return ofGlobal(source, newname, callback)
+ elseif source.type == 'string'
+ or source.type == 'number'
+ or source.type == 'boolean' then
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return ofField(parent, newname, callback)
+ end
+ end
+ return
+end
+
+local function prepareRename(source)
+ if source.type == 'label'
+ or source.type == 'goto'
+ or source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'field'
+ or source.type == 'method'
+ or source.type == 'tablefield'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return source, source[1]
+ elseif source.type == 'string'
+ or source.type == 'number'
+ or source.type == 'boolean' then
+ local parent = source.parent
+ if not parent then
+ return nil
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return source, source[1]
+ end
+ return nil
+ end
+ return nil
+end
+
+local accept = {
+ ['label'] = true,
+ ['goto'] = true,
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['tablefield'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['string'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+}
+
+local m = {}
+
+function m.rename(uri, pos, newname)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local source = findSource(ast, pos, accept)
+ if not source then
+ return nil
+ end
+ local results = {}
+ local mark = {}
+
+ rename(source, newname, function (target, start, finish, text)
+ local turi = files.getOriginUri(guide.getUri(target))
+ local uid = turi .. start
+ if mark[uid] then
+ return
+ end
+ mark[uid] = true
+ if files.isLibrary(turi) then
+ return
+ end
+ results[#results+1] = {
+ start = start,
+ finish = finish,
+ text = text,
+ uri = turi,
+ }
+ end)
+
+ if Forcing == false then
+ Forcing = nil
+ return nil
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ if not askForMultiChange(results, newname) then
+ return nil
+ end
+
+ return results
+end
+
+function m.prepareRename(uri, pos)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local source = findSource(ast, pos, accept)
+ if not source then
+ return
+ end
+
+ local res, text = prepareRename(source)
+ if not res then
+ return nil
+ end
+
+ return {
+ start = source.start,
+ finish = source.finish,
+ text = text,
+ }
+end
+
+return m
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
new file mode 100644
index 00000000..e6b35cdd
--- /dev/null
+++ b/script/core/semantic-tokens.lua
@@ -0,0 +1,161 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local await = require 'await'
+local define = require 'proto.define'
+local vm = require 'vm'
+local util = require 'utility'
+
+local Care = {}
+Care['setglobal'] = function (source, results)
+ local isLib = vm.isGlobalLibraryName(source[1])
+ if not isLib then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.namespace,
+ modifieres = define.TokenModifiers.deprecated,
+ }
+ end
+end
+Care['getglobal'] = function (source, results)
+ local isLib = vm.isGlobalLibraryName(source[1])
+ if not isLib then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.namespace,
+ modifieres = define.TokenModifiers.deprecated,
+ }
+ end
+end
+Care['tablefield'] = function (source, results)
+ local field = source.field
+ if not field then
+ return
+ end
+ results[#results+1] = {
+ start = field.start,
+ finish = field.finish,
+ type = define.TokenTypes.property,
+ modifieres = define.TokenModifiers.declaration,
+ }
+end
+Care['getlocal'] = function (source, results)
+ local loc = source.node
+ -- 1. 值为函数的局部变量
+ local hasFunc
+ local node = loc.node
+ if node then
+ for _, ref in ipairs(node.ref) do
+ local def = ref.value
+ if def.type == 'function' then
+ hasFunc = true
+ break
+ end
+ end
+ end
+ if hasFunc then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.interface,
+ modifieres = define.TokenModifiers.declaration,
+ }
+ return
+ end
+ -- 2. 对象
+ if source.parent.type == 'getmethod'
+ and source.parent.node == source then
+ return
+ end
+ -- 3. 函数的参数
+ if loc.parent and loc.parent.type == 'funcargs' then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.parameter,
+ modifieres = define.TokenModifiers.declaration,
+ }
+ return
+ end
+ -- 4. 特殊变量
+ if source[1] == '_ENV'
+ or source[1] == 'self' then
+ return
+ end
+ -- 5. 其他
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.variable,
+ }
+end
+Care['setlocal'] = Care['getlocal']
+Care['doc.return.name'] = function (source, results)
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.parameter,
+ }
+end
+
+local function buildTokens(results, text, lines)
+ local tokens = {}
+ local lastLine = 0
+ local lastStartChar = 0
+ for i, source in ipairs(results) do
+ local row, col = guide.positionOf(lines, source.start)
+ local start = guide.lineRange(lines, row)
+ local ucol = util.utf8Len(text, start, start + col - 1)
+ local line = row - 1
+ local startChar = ucol - 1
+ local deltaLine = line - lastLine
+ local deltaStartChar
+ if deltaLine == 0 then
+ deltaStartChar = startChar - lastStartChar
+ else
+ deltaStartChar = startChar
+ end
+ lastLine = line
+ lastStartChar = startChar
+ -- see https://microsoft.github.io/language-server-protocol/specifications/specification-3-16/#textDocument_semanticTokens
+ local len = i * 5 - 5
+ tokens[len + 1] = deltaLine
+ tokens[len + 2] = deltaStartChar
+ tokens[len + 3] = source.finish - source.start + 1 -- length
+ tokens[len + 4] = source.type
+ tokens[len + 5] = source.modifieres or 0
+ end
+ return tokens
+end
+
+return function (uri, start, finish)
+ local ast = files.getAst(uri)
+ local lines = files.getLines(uri)
+ local text = files.getText(uri)
+ if not ast then
+ return nil
+ end
+
+ local results = {}
+ local count = 0
+ guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ local method = Care[source.type]
+ if not method then
+ return
+ end
+ method(source, results)
+ count = count + 1
+ if count % 100 == 0 then
+ await.delay()
+ end
+ end)
+
+ table.sort(results, function (a, b)
+ return a.start < b.start
+ end)
+
+ local tokens = buildTokens(results, text, lines)
+
+ return tokens
+end
diff --git a/script/core/signature.lua b/script/core/signature.lua
new file mode 100644
index 00000000..dad38924
--- /dev/null
+++ b/script/core/signature.lua
@@ -0,0 +1,106 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local hoverLabel = require 'core.hover.label'
+local hoverDesc = require 'core.hover.description'
+
+local function findNearCall(uri, ast, pos)
+ local text = files.getText(uri)
+ -- 检查 `f()$` 的情况,注意要区别于 `f($`
+ if text:sub(pos, pos) == ')' then
+ return nil
+ end
+
+ local nearCall
+ guide.eachSourceContain(ast.ast, pos, function (src)
+ if src.type == 'call'
+ or src.type == 'table'
+ or src.type == 'function' then
+ if not nearCall or nearCall.start < src.start then
+ nearCall = src
+ end
+ end
+ end)
+ if not nearCall then
+ return nil
+ end
+ if nearCall.type ~= 'call' then
+ return nil
+ end
+ return nearCall
+end
+
+local function makeOneSignature(source, oop, index)
+ local label = hoverLabel(source, oop)
+ -- 去掉返回值
+ label = label:gsub('%s*->.+', '')
+ local params = {}
+ local i = 0
+ for start, finish in label:gmatch '[%(%)%,]%s*().-()%s*%f[%(%)%,%[%]]' do
+ i = i + 1
+ params[i] = {
+ label = {start, finish-1},
+ }
+ end
+ -- 不定参数
+ if index > i and i > 0 then
+ local lastLabel = params[i].label
+ local text = label:sub(lastLabel[1], lastLabel[2])
+ if text == '...' then
+ index = i
+ end
+ end
+ return {
+ label = label,
+ params = params,
+ index = index,
+ description = hoverDesc(source),
+ }
+end
+
+local function makeSignatures(call, pos)
+ local node = call.node
+ local oop = node.type == 'method'
+ or node.type == 'getmethod'
+ or node.type == 'setmethod'
+ local index
+ local args = call.args
+ if args then
+ for i, arg in ipairs(args) do
+ if arg.start <= pos and arg.finish >= pos then
+ index = i
+ break
+ end
+ end
+ if not index then
+ index = #args + 1
+ end
+ else
+ index = 1
+ end
+ local signs = {}
+ local defs = vm.getDefs(node, 'deep')
+ for _, src in ipairs(defs) do
+ if src.type == 'function'
+ or src.type == 'doc.type.function' then
+ signs[#signs+1] = makeOneSignature(src, oop, index)
+ end
+ end
+ return signs
+end
+
+return function (uri, pos)
+ local ast = files.getAst(uri)
+ if not ast then
+ return nil
+ end
+ local call = findNearCall(uri, ast, pos)
+ if not call then
+ return nil
+ end
+ local signs = makeSignatures(call, pos)
+ if not signs or #signs == 0 then
+ return nil
+ end
+ return signs
+end
diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua
new file mode 100644
index 00000000..4fc6a854
--- /dev/null
+++ b/script/core/workspace-symbol.lua
@@ -0,0 +1,69 @@
+local files = require 'files'
+local guide = require 'parser.guide'
+local matchKey = require 'core.matchkey'
+local define = require 'proto.define'
+local await = require 'await'
+
+local function buildSource(uri, source, key, results)
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal' then
+ local name = source[1]
+ if matchKey(key, name) then
+ results[#results+1] = {
+ name = name,
+ kind = define.SymbolKind.Variable,
+ uri = uri,
+ range = { source.start, source.finish },
+ }
+ end
+ elseif source.type == 'setfield'
+ or source.type == 'tablefield' then
+ local field = source.field
+ local name = field[1]
+ if matchKey(key, name) then
+ results[#results+1] = {
+ name = name,
+ kind = define.SymbolKind.Field,
+ uri = uri,
+ range = { field.start, field.finish },
+ }
+ end
+ elseif source.type == 'setmethod' then
+ local method = source.method
+ local name = method[1]
+ if matchKey(key, name) then
+ results[#results+1] = {
+ name = name,
+ kind = define.SymbolKind.Method,
+ uri = uri,
+ range = { method.start, method.finish },
+ }
+ end
+ end
+end
+
+local function searchFile(uri, key, results)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+
+ guide.eachSource(ast.ast, function (source)
+ buildSource(uri, source, key, results)
+ end)
+end
+
+return function (key)
+ local results = {}
+
+ for uri in files.eachFile() do
+ searchFile(files.getOriginUri(uri), key, results)
+ if #results > 1000 then
+ break
+ end
+ await.delay()
+ end
+
+ return results
+end