summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/core/code-action.lua27
-rw-r--r--script/core/collector.lua71
-rw-r--r--script/core/command/removeSpace.lua16
-rw-r--r--script/core/command/solve.lua11
-rw-r--r--script/core/completion.lua142
-rw-r--r--script/core/definition.lua20
-rw-r--r--script/core/diagnostics/ambiguity-1.lua4
-rw-r--r--script/core/diagnostics/circle-doc-class.lua14
-rw-r--r--script/core/diagnostics/close-non-object.lua9
-rw-r--r--script/core/diagnostics/code-after-break.lua10
-rw-r--r--script/core/diagnostics/count-down-loop.lua8
-rw-r--r--script/core/diagnostics/deprecated.lua22
-rw-r--r--script/core/diagnostics/doc-field-no-class.lua2
-rw-r--r--script/core/diagnostics/duplicate-doc-class.lua16
-rw-r--r--script/core/diagnostics/duplicate-doc-field.lua2
-rw-r--r--script/core/diagnostics/duplicate-doc-param.lua2
-rw-r--r--script/core/diagnostics/duplicate-index.lua12
-rw-r--r--script/core/diagnostics/duplicate-set-field.lua14
-rw-r--r--script/core/diagnostics/empty-block.lua4
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua4
-rw-r--r--script/core/diagnostics/init.lua8
-rw-r--r--script/core/diagnostics/lowercase-global.lua4
-rw-r--r--script/core/diagnostics/newfield-call.lua4
-rw-r--r--script/core/diagnostics/newline-call.lua4
-rw-r--r--script/core/diagnostics/no-implicit-any.lua9
-rw-r--r--script/core/diagnostics/redefined-local.lua7
-rw-r--r--script/core/diagnostics/redundant-parameter.lua14
-rw-r--r--script/core/diagnostics/redundant-value.lua2
-rw-r--r--script/core/diagnostics/trailing-space.lua4
-rw-r--r--script/core/diagnostics/unbalanced-assignments.lua4
-rw-r--r--script/core/diagnostics/undefined-doc-class.lua6
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua27
-rw-r--r--script/core/diagnostics/undefined-doc-param.lua4
-rw-r--r--script/core/diagnostics/undefined-env-child.lua13
-rw-r--r--script/core/diagnostics/undefined-field.lua71
-rw-r--r--script/core/diagnostics/undefined-global.lua21
-rw-r--r--script/core/diagnostics/unknown-diag-code.lua2
-rw-r--r--script/core/diagnostics/unused-function.lua4
-rw-r--r--script/core/diagnostics/unused-label.lua4
-rw-r--r--script/core/diagnostics/unused-local.lua7
-rw-r--r--script/core/diagnostics/unused-vararg.lua4
-rw-r--r--script/core/document-symbol.lua12
-rw-r--r--script/core/find-source.lua2
-rw-r--r--script/core/folding.lua4
-rw-r--r--script/core/generic.lua234
-rw-r--r--script/core/guide2.lua (renamed from script/core/guide.lua)5
-rw-r--r--script/core/highlight.lua64
-rw-r--r--script/core/hint.lua26
-rw-r--r--script/core/hover/arg.lua13
-rw-r--r--script/core/hover/description.lua17
-rw-r--r--script/core/hover/init.lua13
-rw-r--r--script/core/hover/label.lua59
-rw-r--r--script/core/hover/name.lua6
-rw-r--r--script/core/hover/return.lua28
-rw-r--r--script/core/hover/table.lua282
-rw-r--r--script/core/infer.lua639
-rw-r--r--script/core/keyword.lua2
-rw-r--r--script/core/noder.lua1007
-rw-r--r--script/core/reference.lua25
-rw-r--r--script/core/rename.lua25
-rw-r--r--script/core/searcher.lua838
-rw-r--r--script/core/semantic-tokens.lua5
-rw-r--r--script/core/signature.lua9
-rw-r--r--script/core/type-formatting.lua4
-rw-r--r--script/core/workspace-symbol.lua4
-rw-r--r--script/files.lua17
-rw-r--r--script/parser/ast.lua10
-rw-r--r--script/parser/compile.lua5
-rw-r--r--script/parser/guide.lua401
-rw-r--r--script/parser/luadoc.lua81
-rw-r--r--script/proto/define.lua19
-rw-r--r--script/provider/diagnostic.lua2
-rw-r--r--script/service/service.lua1
-rw-r--r--script/utility.lua4
-rw-r--r--script/vm/eachDef.lua50
-rw-r--r--script/vm/eachField.lua109
-rw-r--r--script/vm/eachRef.lua49
-rw-r--r--script/vm/getClass.lua64
-rw-r--r--script/vm/getDocs.lua167
-rw-r--r--script/vm/getGlobals.lua37
-rw-r--r--script/vm/getInfer.lua104
-rw-r--r--script/vm/getLibrary.lua7
-rw-r--r--script/vm/getLinks.lua20
-rw-r--r--script/vm/getMeta.lua53
-rw-r--r--script/vm/guideInterface.lua14
-rw-r--r--script/vm/init.lua4
-rw-r--r--script/vm/vm.lua11
87 files changed, 3825 insertions, 1360 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
index bae3df81..3fd58c81 100644
--- a/script/core/code-action.lua
+++ b/script/core/code-action.lua
@@ -1,14 +1,13 @@
-local files = require 'files'
-local lang = require 'language'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local util = require 'utility'
-local sp = require 'bee.subprocess'
-local vm = require 'vm'
+local files = require 'files'
+local lang = require 'language'
+local util = require 'utility'
+local sp = require 'bee.subprocess'
+local vm = require 'vm'
+local guide = require "parser.guide"
local function checkDisableByLuaDocExits(uri, row, mode, code)
local lines = files.getLines(uri)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local text = files.getOriginText(uri)
local line = lines[row]
if ast.ast.docs and line then
@@ -44,7 +43,7 @@ end
local function checkDisableByLuaDocInsert(uri, row, mode, code)
local lines = files.getLines(uri)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local text = files.getOriginText(uri)
-- 先看看上一行是不是已经有了
-- 没有的话就插入一行
@@ -135,7 +134,7 @@ local function changeVersion(uri, version, results)
end
local function solveUndefinedGlobal(uri, diag, results)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local offset = files.offsetOfWord(uri, diag.range.start)
guide.eachSourceContain(ast.ast, offset, function (source)
if source.type ~= 'getglobal' then
@@ -154,7 +153,7 @@ local function solveUndefinedGlobal(uri, diag, results)
end
local function solveLowercaseGlobal(uri, diag, results)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local offset = files.offsetOfWord(uri, diag.range.start)
guide.eachSourceContain(ast.ast, offset, function (source)
if source.type ~= 'setglobal' then
@@ -167,7 +166,7 @@ local function solveLowercaseGlobal(uri, diag, results)
end
local function findSyntax(uri, diag)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
for _, err in ipairs(ast.errs) do
if err.type:lower():gsub('_', '-') == diag.code then
local range = files.range(uri, err.start, err.finish)
@@ -351,7 +350,7 @@ local function checkQuickFix(results, uri, start, diagnostics)
end
local function checkSwapParams(results, uri, start, finish)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local text = files.getText(uri)
if not ast then
return
@@ -540,7 +539,7 @@ local function checkJsonToLua(results, uri, start, finish)
end
return function (uri, start, finish, diagnostics)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
diff --git a/script/core/collector.lua b/script/core/collector.lua
new file mode 100644
index 00000000..763d145b
--- /dev/null
+++ b/script/core/collector.lua
@@ -0,0 +1,71 @@
+local collect = {}
+local subscribed = {}
+
+local m = {}
+
+--- 订阅一个名字
+---@param uri uri
+---@param name string
+---@param value any
+function m.subscribe(uri, name, value)
+ -- 订阅部分
+ local uriSubscribed = subscribed[uri]
+ if not uriSubscribed then
+ uriSubscribed = {}
+ subscribed[uri] = uriSubscribed
+ end
+ uriSubscribed[name] = true
+ -- 收集部分
+ local nameCollect = collect[name]
+ if not nameCollect then
+ nameCollect = {}
+ collect[name] = nameCollect
+ end
+ if value == nil then
+ value = true
+ end
+ nameCollect[uri] = value
+end
+
+--- 丢弃掉某个 uri 中收集的所有信息
+---@param uri uri
+function m.dropUri(uri)
+ local uriSubscribed = subscribed[uri]
+ if not uriSubscribed then
+ return
+ end
+ subscribed[uri] = nil
+ for name in pairs(uriSubscribed) do
+ collect[name][uri] = nil
+ end
+end
+
+--- 是否包含某个名字的订阅
+---@param name string
+---@return boolean
+function m.has(name)
+ local nameCollect = collect[name]
+ if not nameCollect then
+ return false
+ end
+ if next(nameCollect) == nil then
+ return false
+ end
+ return true
+end
+
+--- 迭代某个名字的订阅
+---@param name string
+function m.each(name)
+ local nameCollect = collect[name]
+ if not nameCollect then
+ return function () end
+ end
+ local uri, value
+ return function ()
+ uri, value = next(nameCollect, uri)
+ return value
+ end
+end
+
+return m
diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua
index 527af8d5..6fb9669f 100644
--- a/script/core/command/removeSpace.lua
+++ b/script/core/command/removeSpace.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local proto = require 'proto'
-local lang = require 'language'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
local function isInString(ast, offset)
return guide.eachSourceContain(ast.ast, offset, function (source)
@@ -16,17 +16,17 @@ return function (data)
local uri = data.uri
local lines = files.getLines(uri)
local text = files.getText(uri)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not lines then
return
end
local textEdit = {}
for i = 1, #lines do
- local line = guide.lineContent(lines, text, i, true)
+ local line = searcher.lineContent(lines, text, i, true)
local pos = line:find '[ \t]+$'
if pos then
- local start, finish = guide.lineRange(lines, i, true)
+ local start, finish = searcher.lineRange(lines, i, true)
start = start + pos - 1
if isInString(ast, start) then
goto NEXT_LINE
diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua
index 995a2109..348c2646 100644
--- a/script/core/command/solve.lua
+++ b/script/core/command/solve.lua
@@ -1,8 +1,7 @@
-local files = require 'files'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local proto = require 'proto'
-local lang = require 'language'
+local files = require 'files'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
local opMap = {
['+'] = true,
@@ -29,7 +28,7 @@ local literalMap = {
return function (data)
local uri = data.uri
local text = files.getText(uri)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/completion.lua b/script/core/completion.lua
index e3980eca..d261b302 100644
--- a/script/core/completion.lua
+++ b/script/core/completion.lua
@@ -1,19 +1,14 @@
local define = require 'proto.define'
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local matchKey = require 'core.matchkey'
local vm = require 'vm'
-local getLabel = require 'core.hover.label'
local getName = require 'core.hover.name'
local getArg = require 'core.hover.arg'
-local getReturn = require 'core.hover.return'
-local getDesc = require 'core.hover.description'
local getHover = require 'core.hover'
local config = require 'config'
local util = require 'utility'
local markdown = require 'provider.markdown'
-local findSource = require 'core.find-source'
-local await = require 'await'
local parser = require 'parser'
local keyWordMap = require 'core.keyword'
local workspace = require 'workspace'
@@ -21,6 +16,8 @@ local furi = require 'file-uri'
local rpath = require 'workspace.require-path'
local lang = require 'language'
local lookBackward = require 'core.look-backward'
+local guide = require 'parser.guide'
+local infer = require 'core.infer'
local DiagnosticModes = {
'disable-next-line',
@@ -135,8 +132,11 @@ local function buildFunctionSnip(source, value, oop)
end
local function buildDetail(source)
- local types = vm.getInferType(source, 0)
- local literals = vm.getInferLiteral(source, 0)
+ if source.type == 'dummy' then
+ return
+ end
+ local types = infer.searchAndViewInfers(source)
+ local literals = infer.searchAndViewLiterals(source)
if literals then
return types .. ' = ' .. literals
else
@@ -149,9 +149,9 @@ local function getSnip(source)
if context <= 0 then
return nil
end
- local defs = vm.getRefs(source, 0)
+ local defs = vm.getRefs(source)
for _, def in ipairs(defs) do
- def = guide.getObjectValue(def) or def
+ def = searcher.getObjectValue(def) or def
if def ~= source and def.type == 'function' then
local uri = guide.getUri(def)
local text = files.getText(uri)
@@ -173,6 +173,9 @@ local function getSnip(source)
end
local function buildDesc(source)
+ if source.type == 'dummy' then
+ return
+ end
local hover = getHover.get(source)
local md = markdown()
md:add('lua', hover.label)
@@ -273,8 +276,8 @@ local function checkLocal(ast, word, offset, results)
if not matchKey(word, name) then
goto CONTINUE
end
- if vm.hasType(source, 'function') then
- for _, def in ipairs(vm.getDefs(source, 0)) do
+ if infer.hasType(source, 'function') then
+ for _, def in ipairs(vm.getDefs(source)) do
if def.type == 'function'
or def.type == 'doc.type.function' then
local funcLabel = name .. getParams(def, false)
@@ -325,7 +328,7 @@ local function checkModule(ast, word, offset, results)
and not config.config.diagnostics.globals[stemName]
and stemName:match '^[%a_][%w_]*$'
and matchKey(word, stemName) then
- local targetAst = files.getAst(uri)
+ local targetAst = files.getState(uri)
if not targetAst then
goto CONTINUE
end
@@ -417,7 +420,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset)
end
local function checkFieldThen(name, src, word, start, offset, parent, oop, results)
- local value = guide.getObjectValue(src) or src
+ local value = searcher.getObjectValue(src) or src
local kind = define.CompletionItemKind.Field
if value.type == 'function'
or value.type == 'doc.type.function' then
@@ -492,7 +495,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res
end
local funcLabel
if config.config.completion.showParams then
- local value = guide.getObjectValue(src) or src
+ local value = searcher.getObjectValue(src) or src
if value.type == 'function'
or value.type == 'doc.type.function' then
funcLabel = name .. getParams(value, oop)
@@ -539,16 +542,16 @@ end
local function checkGlobal(ast, word, start, offset, parent, oop, results)
local locals = guide.getVisibleLocals(ast.ast, offset)
- local refs = vm.getGlobalSets '*'
- checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global')
+ local globals = vm.getGlobalSets '*'
+ checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results, locals, 'global')
end
local function checkField(ast, word, start, offset, parent, oop, results)
if parent.tag == '_ENV' or parent.special == '_G' then
- local refs = vm.getGlobalSets '*'
- checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
+ local globals = vm.getGlobalSets '*'
+ checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results)
else
- local refs = vm.getFields(parent, 0)
+ local refs = vm.getRefs(parent, '*')
checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
end
end
@@ -1043,14 +1046,14 @@ local function mergeEnums(a, b, source)
end
end
-local function checkTypingEnum(ast, text, offset, infers, str, results)
+local function checkTypingEnum(ast, text, offset, defs, str, results)
local enums = {}
- for _, infer in ipairs(infers) do
- if infer.source.type == 'doc.type.enum'
- or infer.source.type == 'doc.resume' then
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
enums[#enums+1] = {
- label = infer.source[1],
- description = infer.source.comment and infer.source.comment.text,
+ label = def[1],
+ description = def.comment and def.comment.text,
kind = define.CompletionItemKind.EnumMember,
}
end
@@ -1074,8 +1077,8 @@ local function checkEqualEnumLeft(ast, text, offset, source, results)
return src
end
end)
- local infers = vm.getInfers(source, 0)
- checkTypingEnum(ast, text, offset, infers, str, results)
+ local defs = vm.getDefs(source)
+ checkTypingEnum(ast, text, offset, defs, str, results)
end
local function checkEqualEnum(ast, text, offset, results)
@@ -1247,7 +1250,30 @@ function (%s)\
end"):format(table.concat(args, ', '))
end
-local function getCallEnums(source, index)
+local function pushCallEnumsAndFuncs(defs)
+ local results = {}
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
+ results[#results+1] = {
+ label = def[1],
+ description = def.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ end
+ if def.type == 'doc.type.function' then
+ results[#results+1] = {
+ label = infer.viewDocFunction(def),
+ description = def.comment,
+ kind = define.CompletionItemKind.Function,
+ insertText = buildInsertDocFunction(def),
+ }
+ end
+ end
+ return results
+end
+
+local function getCallEnumsAndFuncs(source, index)
if source.type == 'function' and source.bindDocs then
if not source.args then
return
@@ -1266,37 +1292,10 @@ local function getCallEnums(source, index)
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.param'
and doc.param[1] == arg[1] then
- local enums = {}
- for _, enum in ipairs(vm.getDocEnums(doc.extends) or {}) do
- enums[#enums+1] = {
- label = enum[1],
- description = enum.comment,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
- for _, unit in ipairs(vm.getDocTypeUnits(doc.extends) or {}) do
- if unit.type == 'doc.type.function' then
- local text = files.getText(guide.getUri(unit))
- enums[#enums+1] = {
- label = text:sub(unit.start, unit.finish),
- description = doc.comment,
- kind = define.CompletionItemKind.Function,
- insertText = buildInsertDocFunction(unit),
- }
- end
- end
- return enums
+ return pushCallEnumsAndFuncs(vm.getDefs(doc.extends))
elseif doc.type == 'doc.vararg'
and arg.type == '...' then
- local enums = {}
- for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do
- enums[#enums+1] = {
- label = enum[1],
- description = enum.comment,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
- return enums
+ return pushCallEnumsAndFuncs(vm.getDefs(doc.vararg))
end
end
end
@@ -1403,12 +1402,12 @@ local function checkTableLiteralFieldByCall(ast, text, offset, call, defs, index
return
end
for _, def in ipairs(defs) do
- local func = guide.getObjectValue(def) or def
+ local func = searcher.getObjectValue(def) or def
local param = getFuncParamByCallIndex(func, index)
if not param then
goto CONTINUE
end
- local defs = vm.getDefFields(param, 0)
+ local defs = vm.getDefs(param, '*')
for _, field in ipairs(defs) do
local name = guide.getKeyName(field)
if name and not mark[name] then
@@ -1431,10 +1430,10 @@ local function tryCallArg(ast, text, offset, results)
if arg and arg.type == 'function' then
return
end
- local defs = vm.getDefs(call.node, 0)
+ local defs = vm.getDefs(call.node)
for _, def in ipairs(defs) do
- def = guide.getObjectValue(def) or def
- local enums = getCallEnums(def, argIndex)
+ def = searcher.getObjectValue(def) or def
+ local enums = getCallEnumsAndFuncs(def, argIndex)
if enums then
mergeEnums(myResults, enums, arg)
end
@@ -1461,7 +1460,8 @@ local function tryTable(ast, text, offset, results)
if source.type ~= 'table' then
tbl = source.parent
end
- local defs = vm.getDefFields(tbl, 0)
+ local parent = tbl.parent
+ local defs = vm.getDefs(parent, '*')
for _, field in ipairs(defs) do
local name = guide.getKeyName(field)
if name and not mark[name] then
@@ -1560,7 +1560,7 @@ end
local function tryLuaDocBySource(ast, offset, source, results)
if source.type == 'doc.extends.name' then
if source.parent.type == 'doc.class' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if doc.type == 'doc.class.name'
and doc.parent ~= source.parent
and matchKey(source[1], doc[1]) then
@@ -1578,7 +1578,7 @@ local function tryLuaDocBySource(ast, offset, source, results)
end
return true
elseif source.type == 'doc.type.name' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name')
and doc.parent ~= source.parent
and matchKey(source[1], doc[1]) then
@@ -1652,7 +1652,7 @@ end
local function tryLuaDocByErr(ast, offset, err, docState, results)
if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if doc.type == 'doc.class.name'
and doc.parent ~= docState then
results[#results+1] = {
@@ -1662,7 +1662,7 @@ local function tryLuaDocByErr(ast, offset, err, docState, results)
end
end
elseif err.type == 'LUADOC_MISS_TYPE_NAME' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then
results[#results+1] = {
label = doc[1],
@@ -1735,14 +1735,14 @@ local function buildLuaDocOfFunction(func)
local returns = {}
if func.args then
for _, arg in ipairs(func.args) do
- args[#args+1] = vm.getInferType(arg)
+ args[#args+1] = infer.searchAndViewInfers(arg)
end
end
if func.returns then
for _, rtns in ipairs(func.returns) do
for n = 1, #rtns do
if not returns[n] then
- returns[n] = vm.getInferType(rtns[n])
+ returns[n] = infer.searchAndViewInfers(rtns[n])
end
end
end
@@ -1931,7 +1931,7 @@ local function completion(uri, offset, triggerCharacter)
return results
end
tracy.ZoneBeginN 'completion #1'
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local text = files.getText(uri)
results = {}
clearStack()
diff --git a/script/core/definition.lua b/script/core/definition.lua
index b26bb922..27a9e553 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -1,14 +1,15 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local workspace = require 'workspace'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
+local guide = require 'parser.guide'
local function sortResults(results)
-- 先按照顺序排序
table.sort(results, function (a, b)
- local u1 = guide.getUri(a.target)
- local u2 = guide.getUri(b.target)
+ local u1 = searcher.getUri(a.target)
+ local u2 = searcher.getUri(b.target)
if u1 == u2 then
return a.target.start < b.target.start
else
@@ -20,7 +21,7 @@ local function sortResults(results)
for i = #results, 1, -1 do
local res = results[i].target
local f = res.finish
- local uri = guide.getUri(res)
+ local uri = searcher.getUri(res)
if lf and f > lf and uri == lu then
table.remove(results, i)
else
@@ -101,7 +102,7 @@ local function convertIndex(source)
end
return function (uri, offset)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
@@ -127,11 +128,11 @@ return function (uri, offset)
end
end
- local defs = vm.getDefs(source, 0)
+ local defs = vm.getDefs(source)
local values = {}
for _, src in ipairs(defs) do
- local value = guide.getObjectValue(src)
- if value and value ~= src then
+ local value = searcher.getObjectValue(src)
+ if value and value ~= src and guide.isLiteral(value) then
values[value] = true
end
end
@@ -148,9 +149,6 @@ return function (uri, offset)
goto CONTINUE
end
src = src.field or src.method or src.index or src
- if src.type == 'table' and src.parent.type ~= 'return' then
- goto CONTINUE
- end
if src.type == 'doc.class.name'
and source.type ~= 'doc.type.name'
and source.type ~= 'doc.extends.name'
diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua
index 19bb4f97..bae39a03 100644
--- a/script/core/diagnostics/ambiguity-1.lua
+++ b/script/core/diagnostics/ambiguity-1.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local opMap = {
@@ -25,7 +25,7 @@ local literalMap = {
}
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua
index 702cd904..ae6d4d3b 100644
--- a/script/core/diagnostics/circle-doc-class.lua
+++ b/script/core/diagnostics/circle-doc-class.lua
@@ -1,11 +1,11 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local vm = require 'vm'
+local guide = require 'parser.guide'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
@@ -40,7 +40,7 @@ return function (uri, callback)
end
if not mark[newName] then
mark[newName] = true
- local docs = vm.getDocTypes(newName)
+ local docs = vm.getDocDefines(newName)
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name' then
list[#list+1] = otherDoc.parent
diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua
index d1983c42..afd259d0 100644
--- a/script/core/diagnostics/close-non-object.lua
+++ b/script/core/diagnostics/close-non-object.lua
@@ -1,10 +1,9 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua
index f23755ea..21f7e83a 100644
--- a/script/core/diagnostics/code-after-break.lua
+++ b/script/core/diagnostics/code-after-break.lua
@@ -1,10 +1,10 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua
index 65099af8..a16811ab 100644
--- a/script/core/diagnostics/count-down-loop.lua
+++ b/script/core/diagnostics/count-down-loop.lua
@@ -1,9 +1,9 @@
-local files = require "files"
-local guide = require "core.guide"
-local lang = require 'language'
+local files = require "files"
+local guide = require "parser.guide"
+local lang = require 'language'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua
index 60d60946..c60084fb 100644
--- a/script/core/diagnostics/deprecated.lua
+++ b/script/core/diagnostics/deprecated.lua
@@ -1,13 +1,13 @@
-local files = require 'files'
-local vm = require 'vm'
-local lang = require 'language'
-local guide = require 'core.guide'
-local config = require 'config'
-local define = require 'proto.define'
-local await = require 'await'
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local config = require 'config'
+local define = require 'proto.define'
+local await = require 'await'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -20,7 +20,7 @@ return function (uri, callback)
return
end
if src.type == 'getglobal' then
- local key = guide.getKeyName(src)
+ local key = src[1]
if not key then
return
end
@@ -34,11 +34,11 @@ return function (uri, callback)
await.delay()
- if not vm.isDeprecated(src, 0) then
+ if not vm.isDeprecated(src, true) then
return
end
- local defs = vm.getDefs(src, 0)
+ local defs = vm.getDefs(src)
local validVersions
for _, def in ipairs(defs) do
if def.bindDocs then
diff --git a/script/core/diagnostics/doc-field-no-class.lua b/script/core/diagnostics/doc-field-no-class.lua
index f27bbb32..97603c0b 100644
--- a/script/core/diagnostics/doc-field-no-class.lua
+++ b/script/core/diagnostics/doc-field-no-class.lua
@@ -2,7 +2,7 @@ local files = require 'files'
local lang = require 'language'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua
index 8c6696a9..20eedb5e 100644
--- a/script/core/diagnostics/duplicate-doc-class.lua
+++ b/script/core/diagnostics/duplicate-doc-class.lua
@@ -1,11 +1,11 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local vm = require 'vm'
+local guide = require 'parser.guide'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
@@ -20,7 +20,7 @@ return function (uri, callback)
or doc.type == 'doc.alias' then
local name = guide.getKeyName(doc)
if not cache[name] then
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
cache[name] = {}
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name'
@@ -28,7 +28,7 @@ return function (uri, callback)
cache[name][#cache[name]+1] = {
start = otherDoc.start,
finish = otherDoc.finish,
- uri = guide.getUri(otherDoc),
+ uri = searcher.getUri(otherDoc),
}
end
end
diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua
index b621fd9e..1ee27ff2 100644
--- a/script/core/diagnostics/duplicate-doc-field.lua
+++ b/script/core/diagnostics/duplicate-doc-field.lua
@@ -2,7 +2,7 @@ local files = require 'files'
local lang = require 'language'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
diff --git a/script/core/diagnostics/duplicate-doc-param.lua b/script/core/diagnostics/duplicate-doc-param.lua
index 676a6fb4..b54c1978 100644
--- a/script/core/diagnostics/duplicate-doc-param.lua
+++ b/script/core/diagnostics/duplicate-doc-param.lua
@@ -2,7 +2,7 @@ local files = require 'files'
local lang = require 'language'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua
index 5e63d39e..91a35212 100644
--- a/script/core/diagnostics/duplicate-index.lua
+++ b/script/core/diagnostics/duplicate-index.lua
@@ -1,11 +1,11 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua
index c1e2285a..492793b1 100644
--- a/script/core/diagnostics/duplicate-set-field.lua
+++ b/script/core/diagnostics/duplicate-set-field.lua
@@ -1,11 +1,11 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local define = require 'proto.define'
+local guide = require "parser.guide"
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -30,7 +30,7 @@ return function (uri, callback)
if not name then
goto CONTINUE
end
- local value = guide.getObjectValue(nxt)
+ local value = searcher.getObjectValue(nxt)
if not value or value.type ~= 'function' then
goto CONTINUE
end
diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua
index 690a4ca2..fc205d7e 100644
--- a/script/core/diagnostics/empty-block.lua
+++ b/script/core/diagnostics/empty-block.lua
@@ -1,12 +1,12 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
-- 检查空代码块
-- 但是排除忙等待(repeat/while)
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
index de23bc76..d95963e4 100644
--- a/script/core/diagnostics/global-in-nil-env.lua
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
-- TODO: 检查路径是否可达
@@ -8,7 +8,7 @@ local function mayRun(path)
end
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua
index a2b831f7..5446a7c3 100644
--- a/script/core/diagnostics/init.lua
+++ b/script/core/diagnostics/init.lua
@@ -59,16 +59,18 @@ local function check(uri, name, results)
if passed >= 0.5 then
log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed))
end
+ if DIAGTIMES then
+ DIAGTIMES[name] = (DIAGTIMES[name] or 0) + passed
+ end
end
return function (uri, response)
- local vm = require 'vm'
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
- local isOpen = files.isOpen(uri)
+ log.debug('do diagnostic @', uri)
for _, name in ipairs(diagList) do
await.delay()
diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua
index 9c094701..cba33459 100644
--- a/script/core/diagnostics/lowercase-global.lua
+++ b/script/core/diagnostics/lowercase-global.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local config = require 'config'
local vm = require 'vm'
@@ -18,7 +18,7 @@ end
-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删)
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua
index 0727c2fd..2cbc13ee 100644
--- a/script/core/diagnostics/newfield-call.lua
+++ b/script/core/diagnostics/newfield-call.lua
@@ -1,9 +1,9 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua
index 807f76a2..71dc33e2 100644
--- a/script/core/diagnostics/newline-call.lua
+++ b/script/core/diagnostics/newline-call.lua
@@ -1,9 +1,9 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local lines = files.getLines(uri)
local text = files.getText(uri)
if not ast or not lines then
diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua
index ffaab821..6ff17c81 100644
--- a/script/core/diagnostics/no-implicit-any.lua
+++ b/script/core/diagnostics/no-implicit-any.lua
@@ -1,11 +1,10 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local infer = require 'core.infer'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -21,7 +20,7 @@ return function (uri, callback)
and source.type ~= 'tableindex' then
return
end
- if vm.getInferType(source, 0) == 'any' then
+ if infer.searchAndViewInfers(source) == 'any' then
callback {
start = source.start,
finish = source.finish,
diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua
index 857d80d2..503347d0 100644
--- a/script/core/diagnostics/redefined-local.lua
+++ b/script/core/diagnostics/redefined-local.lua
@@ -1,9 +1,9 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -13,6 +13,9 @@ return function (uri, callback)
or name == ast.ENVMode then
return
end
+ if source.tag == 'self' then
+ return
+ end
local exist = guide.getLocal(source, name, source.start-1)
if exist then
callback {
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
index c5bcd5a5..b25ec77a 100644
--- a/script/core/diagnostics/redundant-parameter.lua
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -1,9 +1,8 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local define = require 'proto.define'
-local await = require 'await'
local function countCallArgs(source)
local result = 0
@@ -67,7 +66,7 @@ local function getFuncArgs(func)
end
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -81,14 +80,7 @@ return function (uri, callback)
end
local func = source.node
- local funcArgs = cache[func]
- if funcArgs == nil then
- funcArgs = getFuncArgs(func) or false
- local refs = vm.getRefs(func, 0)
- for _, ref in ipairs(refs) do
- cache[ref] = funcArgs
- end
- end
+ local funcArgs = getFuncArgs(func)
if not funcArgs then
return
diff --git a/script/core/diagnostics/redundant-value.lua b/script/core/diagnostics/redundant-value.lua
index be483448..d6cd97a7 100644
--- a/script/core/diagnostics/redundant-value.lua
+++ b/script/core/diagnostics/redundant-value.lua
@@ -3,7 +3,7 @@ local define = require 'proto.define'
local lang = require 'language'
return function (uri, callback, code)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua
index 0a4b1d57..824eb83f 100644
--- a/script/core/diagnostics/trailing-space.lua
+++ b/script/core/diagnostics/trailing-space.lua
@@ -1,6 +1,6 @@
local files = require 'files'
local lang = require 'language'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local function isInString(ast, offset)
local result = false
@@ -13,7 +13,7 @@ local function isInString(ast, offset)
end
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua
index b2b2800c..df71f0c9 100644
--- a/script/core/diagnostics/unbalanced-assignments.lua
+++ b/script/core/diagnostics/unbalanced-assignments.lua
@@ -1,10 +1,10 @@
local files = require 'files'
local define = require 'proto.define'
local lang = require 'language'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
return function (uri, callback, code)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua
index a91cfa7f..e7133ab9 100644
--- a/script/core/diagnostics/undefined-doc-class.lua
+++ b/script/core/diagnostics/undefined-doc-class.lua
@@ -1,11 +1,11 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
@@ -25,7 +25,7 @@ return function (uri, callback)
end
for _, ext in ipairs(doc.extends) do
local name = ext[1]
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
if cache[name] == nil then
cache[name] = false
for _, otherDoc in ipairs(docs) do
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
index d8a4363b..91d4b90e 100644
--- a/script/core/diagnostics/undefined-doc-name.lua
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -1,11 +1,10 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
-local define = require 'proto.define'
local vm = require 'vm'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
@@ -14,26 +13,6 @@ return function (uri, callback)
return
end
- local classCache = {
- ['any'] = true,
- ['nil'] = true,
- }
- local function hasNameOfClassOrAlias(name)
- if classCache[name] ~= nil then
- return classCache[name]
- end
- local docs = vm.getDocTypes(name)
- for _, otherDoc in ipairs(docs) do
- if otherDoc.type == 'doc.class.name'
- or otherDoc.type == 'doc.alias.name' then
- classCache[name] = true
- return true
- end
- end
- classCache[name] = false
- return false
- end
-
local function hasNameOfGeneric(name, source)
if not source.typeGeneric then
return false
@@ -56,7 +35,7 @@ return function (uri, callback)
if name == '...' then
return
end
- if hasNameOfClassOrAlias(name)
+ if vm.isDocDefined(name)
or hasNameOfGeneric(name, source) then
return
end
diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua
index 0bf371e5..6140b4f0 100644
--- a/script/core/diagnostics/undefined-doc-param.lua
+++ b/script/core/diagnostics/undefined-doc-param.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
@@ -17,7 +17,7 @@ local function hasParamName(func, name)
end
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua
index 89efb8c7..39c8de27 100644
--- a/script/core/diagnostics/undefined-env-child.lua
+++ b/script/core/diagnostics/undefined-env-child.lua
@@ -1,10 +1,11 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local vm = require 'vm'
-local lang = require 'language'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local vm = require "vm.vm"
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -13,7 +14,7 @@ return function (uri, callback)
if source.node.tag == '_ENV' then
return
end
- local defs = guide.requestDefinition(source)
+ local defs = vm.getDefs(source)
if #defs > 0 then
return
end
diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua
index b10c9ab0..9d1f696c 100644
--- a/script/core/diagnostics/undefined-field.lua
+++ b/script/core/diagnostics/undefined-field.lua
@@ -2,11 +2,18 @@ local files = require 'files'
local vm = require 'vm'
local lang = require 'language'
local config = require 'config'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
+local SkipCheckClass = {
+ ['unknown'] = true,
+ ['any'] = true,
+ ['table'] = true,
+ ['nil'] = true,
+}
+
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -18,7 +25,7 @@ return function (uri, callback)
if cache[src] == nil then
tracy.ZoneBeginN('undefined-field getInfers')
infers = vm.getInfers(src, 0) or false
- local refs = vm.getRefs(src, 0)
+ local refs = vm.getRefs(src)
for _, ref in ipairs(refs) do
cache[ref] = infers
end
@@ -47,7 +54,7 @@ return function (uri, callback)
elseif inferSource.type == 'doc.class.name' then
addTo(allDocClass, inferSource.parent)
elseif inferSource.type == 'doc.type.name' then
- local docTypes = vm.getDocTypes(inferSource[1])
+ local docTypes = vm.getDocDefines(inferSource[1])
for _, docType in ipairs(docTypes) do
if docType.type == 'doc.class.name' then
addTo(allDocClass, docType.parent)
@@ -65,7 +72,7 @@ return function (uri, callback)
local empty = true
for _, docClass in ipairs(allDocClass) do
tracy.ZoneBeginN('undefined-field getDefFields')
- local refs = vm.getDefFields(docClass)
+ local refs = vm.getDefs(docClass, '*')
tracy.ZoneEnd()
for _, ref in ipairs(refs) do
@@ -87,35 +94,37 @@ return function (uri, callback)
end
local function checkUndefinedField(src)
- local fieldName = guide.getKeyName(src)
-
- local allDocClass = getAllDocClassFromInfer(src.node)
- if (not allDocClass) or (#allDocClass == 0) then
- return
- end
-
- local fields = getAllFieldsFromAllDocClass(allDocClass)
-
- -- 没找到任何 field,跳过检查
- if not fields then
+ if #vm.getDefs(src) > 0 then
return
end
-
- if not fields[fieldName] then
- local message = lang.script('DIAG_UNDEF_FIELD', fieldName)
- if src.type == 'getfield' and src.field then
- callback {
- start = src.field.start,
- finish = src.field.finish,
- message = message,
- }
- elseif src.type == 'getmethod' and src.method then
- callback {
- start = src.method.start,
- finish = src.method.finish,
- message = message,
- }
+ local node = src.node
+ if node then
+ local defs = vm.getDefs(node)
+ local ok
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class.name'
+ and not SkipCheckClass[def[1]] then
+ ok = true
+ break
+ end
end
+ if not ok then
+ return
+ end
+ end
+ local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src))
+ if src.type == 'getfield' and src.field then
+ callback {
+ start = src.field.start,
+ finish = src.field.finish,
+ message = message,
+ }
+ elseif src.type == 'getmethod' and src.method then
+ callback {
+ start = src.method.start,
+ finish = src.method.finish,
+ message = message,
+ }
end
end
guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField);
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
index 161d8856..549a1922 100644
--- a/script/core/diagnostics/undefined-global.lua
+++ b/script/core/diagnostics/undefined-global.lua
@@ -1,9 +1,8 @@
-local files = require 'files'
-local vm = require 'vm'
-local lang = require 'language'
-local config = require 'config'
-local guide = require 'core.guide'
-local define = require 'proto.define'
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local config = require 'config'
+local guide = require 'parser.guide'
local requireLike = {
['include'] = true,
@@ -13,14 +12,14 @@ local requireLike = {
}
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
-- 遍历全局变量,检查所有没有 set 模式的全局变量
guide.eachSourceType(ast.ast, 'getglobal', function (src)
- local key = guide.getKeyName(src)
+ local key = src[1]
if not key then
return
end
@@ -30,7 +29,11 @@ return function (uri, callback)
if config.config.runtime.special[key] then
return
end
- if #vm.getGlobalSets(key) == 0 then
+ local node = src.node
+ if node.tag ~= '_ENV' then
+ return
+ end
+ if #vm.getDefs(src) == 0 then
local message = lang.script('DIAG_UNDEF_GLOBAL', key)
if requireLike[key:lower()] then
message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
diff --git a/script/core/diagnostics/unknown-diag-code.lua b/script/core/diagnostics/unknown-diag-code.lua
index 45d3b6db..013a702b 100644
--- a/script/core/diagnostics/unknown-diag-code.lua
+++ b/script/core/diagnostics/unknown-diag-code.lua
@@ -3,7 +3,7 @@ local lang = require 'language'
local define = require 'proto.define'
return function (uri, callback)
- local state = files.getAst(uri)
+ local state = files.getState(uri)
if not state then
return
end
diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua
index b6f92e60..59f27e59 100644
--- a/script/core/diagnostics/unused-function.lua
+++ b/script/core/diagnostics/unused-function.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local vm = require 'vm'
local define = require 'proto.define'
local lang = require 'language'
@@ -19,7 +19,7 @@ local function isToBeClosed(source)
end
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua
index e2d5e49a..8ee0bba3 100644
--- a/script/core/diagnostics/unused-label.lua
+++ b/script/core/diagnostics/unused-label.lua
@@ -1,10 +1,10 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua
index fde90cb8..072cbd31 100644
--- a/script/core/diagnostics/unused-local.lua
+++ b/script/core/diagnostics/unused-local.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
@@ -77,7 +77,7 @@ local function isDocParam(source)
end
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -87,6 +87,9 @@ return function (uri, callback)
or name == ast.ENVMode then
return
end
+ if source.tag == 'self' then
+ return
+ end
if isToBeClosed(source) then
return
end
diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua
index ec0a05fb..2e07e1ee 100644
--- a/script/core/diagnostics/unused-vararg.lua
+++ b/script/core/diagnostics/unused-vararg.lua
@@ -1,10 +1,10 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
return function (uri, callback)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua
index cc87e3ca..03169cfd 100644
--- a/script/core/document-symbol.lua
+++ b/script/core/document-symbol.lua
@@ -1,8 +1,8 @@
-local await = require 'await'
-local files = require 'files'
-local guide = require 'core.guide'
-local define = require 'proto.define'
-local util = require 'utility'
+local await = require 'await'
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local util = require 'utility'
local function buildName(source, text)
if source.type == 'setmethod'
@@ -228,7 +228,7 @@ local function buildSource(source, text, used, symbols)
end
local function makeSymbol(uri)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local text = files.getText(uri)
if not ast or not text then
return nil
diff --git a/script/core/find-source.lua b/script/core/find-source.lua
index b36306b6..edbb1e2c 100644
--- a/script/core/find-source.lua
+++ b/script/core/find-source.lua
@@ -1,4 +1,4 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local function isValidFunctionPos(source, offset)
for i = 1, #source.keyword // 2 do
diff --git a/script/core/folding.lua b/script/core/folding.lua
index 15678995..dad98422 100644
--- a/script/core/folding.lua
+++ b/script/core/folding.lua
@@ -1,5 +1,5 @@
local files = require "files"
-local guide = require "core.guide"
+local guide = require "parser.guide"
local util = require 'utility'
local Care = {
@@ -145,7 +145,7 @@ local Care = {
}
return function (uri)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local text = files.getText(uri)
if not ast or not text then
return nil
diff --git a/script/core/generic.lua b/script/core/generic.lua
new file mode 100644
index 00000000..15950974
--- /dev/null
+++ b/script/core/generic.lua
@@ -0,0 +1,234 @@
+local guide = require 'parser.guide'
+local noder = require "core.noder"
+
+---@class generic.value
+---@field type string
+---@field closure generic.closure
+---@field proto parser.guide.object
+---@field parent parser.guide.object
+
+---@class generic.closure
+---@field type string
+---@field proto parser.guide.object
+---@field upvalues table<string, generic.value[]>
+---@field params generic.value[]
+---@field returns generic.value[]
+
+local m = {}
+
+---@param closure generic.closure
+---@param proto parser.guide.object
+local function instantValue(closure, proto)
+ ---@type generic.value
+ local value = {
+ type = 'generic.value',
+ closure = closure,
+ proto = proto,
+ parent = proto.parent,
+ }
+ closure.values[#closure.values+1] = value
+ return value
+end
+
+---递归实例化对象
+---@param proto parser.guide.object
+---@return generic.value
+local function createValue(closure, proto, callback, road)
+ if callback then
+ road = road or {}
+ end
+ if proto.type == 'doc.type' then
+ local types = {}
+ local hasGeneric
+ for i, tp in ipairs(proto.types) do
+ local genericValue = createValue(closure, tp, callback, road)
+ if genericValue then
+ hasGeneric = true
+ types[i] = genericValue
+ else
+ types[i] = tp
+ end
+ end
+ if not hasGeneric then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.types = types
+ noder.compileNode(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.name' then
+ if not proto.typeGeneric then
+ return nil
+ end
+ local key = proto[1]
+ local value = instantValue(closure, proto)
+ if callback then
+ callback(road, key, proto)
+ end
+ noder.compileNode(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.function' then
+ local hasGeneric
+ local args = {}
+ local returns = {}
+ for i, arg in ipairs(proto.args) do
+ local value = createValue(closure, arg, callback, road)
+ if value then
+ hasGeneric = true
+ end
+ args[i] = value or arg
+ end
+ for i, rtn in ipairs(proto.returns) do
+ local value = createValue(closure, rtn, callback, road)
+ if value then
+ hasGeneric = true
+ end
+ returns[i] = value or rtn
+ end
+ if not hasGeneric then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.args = args
+ value.returns = returns
+ value.isGeneric = true
+ noder.pushSource(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.array' then
+ if road then
+ road[#road+1] = noder.ANY_FIELD
+ end
+ local node = createValue(closure, proto.node, callback, road)
+ if road then
+ road[#road] = nil
+ end
+ if not node then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.node = node
+ return value
+ end
+ if proto.type == 'doc.type.table' then
+ road[#road+1] = noder.TABLE_KEY
+ local tkey = createValue(closure, proto.tkey, callback, road)
+ road[#road] = nil
+
+ road[#road+1] = noder.ANY_FIELD
+ local tvalue = createValue(closure, proto.tvalue, callback, road)
+ road[#road] = nil
+
+ if not tkey and not tvalue then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.tkey = tkey or proto.tkey
+ value.tvalue = tvalue or proto.tvalue
+ return value
+ end
+end
+
+local function buildValue(road, key, proto, param, upvalues)
+ local paramID
+ if proto.literal then
+ local str = param.type == 'string' and param[1]
+ if not str then
+ return
+ end
+ paramID = 'dn:' .. str
+ else
+ paramID = noder.getID(param)
+ end
+ if not paramID then
+ return
+ end
+ local myUri = guide.getUri(param)
+ local myHead = noder.URI_CHAR .. myUri .. noder.URI_CHAR
+ paramID = myHead .. paramID
+ if not upvalues[key] then
+ upvalues[key] = {}
+ end
+ upvalues[key][#upvalues[key]+1] = paramID .. table.concat(road)
+end
+
+-- 为所有的 param 与 return 创建副本
+---@param closure generic.closure
+local function buildValues(closure)
+ local protoFunction = closure.proto
+ local upvalues = closure.upvalues
+ local params = closure.call.args
+
+ if protoFunction.type == 'function' then
+ for _, doc in ipairs(protoFunction.bindDocs) do
+ if doc.type == 'doc.param' then
+ local extends = doc.extends
+ local index = extends.paramIndex
+ if index then
+ local param = params and params[index]
+ closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ buildValue(road, key, proto, param, upvalues)
+ end) or extends
+ end
+ end
+ end
+ for _, doc in ipairs(protoFunction.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ closure.returns[rtn.returnIndex] = createValue(closure, rtn) or rtn
+ end
+ end
+ end
+ end
+ if protoFunction.type == 'doc.type.function' then
+ for index, arg in ipairs(protoFunction.args) do
+ local extends = arg.extends
+ local param = params and params[index]
+ closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ buildValue(road, key, proto, param, upvalues)
+ end) or extends
+ end
+ for index, rtn in ipairs(protoFunction.returns) do
+ closure.returns[index] = createValue(closure, rtn) or rtn
+ end
+ end
+end
+
+---创建一个闭包
+---@param proto parser.guide.object|generic.value # 原型函数|泛型值
+---@return generic.closure
+function m.createClosure(proto, call)
+ local protoFunction, parentClosure
+ if proto.type == 'function' then
+ protoFunction = proto
+ elseif proto.type == 'doc.type.function' then
+ protoFunction = proto
+ elseif proto.type == 'generic.value' then
+ protoFunction = proto.proto
+ parentClosure = proto.closure
+ end
+ ---@type generic.closure
+ local closure = {
+ type = 'generic.closure',
+ parent = protoFunction.parent,
+ proto = protoFunction,
+ call = call,
+ upvalues = parentClosure and parentClosure.upvalues or {},
+ params = {},
+ returns = {},
+ values = {},
+ }
+ buildValues(closure)
+
+ if #closure.returns == 0 then
+ return nil
+ end
+
+ noder.compileNode(noder.getNoders(proto), closure)
+
+ return closure
+end
+
+return m
diff --git a/script/core/guide.lua b/script/core/guide2.lua
index c7a784b7..183555b3 100644
--- a/script/core/guide.lua
+++ b/script/core/guide2.lua
@@ -292,8 +292,13 @@ end
---@param obj parser.guide.object
---@return parser.guide.object
function m.getRoot(obj)
+ local source = obj
+ if source._root then
+ return source._root
+ end
for _ = 1, 1000 do
if obj.type == 'main' then
+ source._root = obj
return obj
end
local parent = obj.parent
diff --git a/script/core/highlight.lua b/script/core/highlight.lua
index 12ec114f..d1f11906 100644
--- a/script/core/highlight.lua
+++ b/script/core/highlight.lua
@@ -1,31 +1,18 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local files = require 'files'
local vm = require 'vm'
local define = require 'proto.define'
local findSource = require 'core.find-source'
local util = require 'utility'
+local guide = require 'parser.guide'
local function eachRef(source, callback)
- local results = guide.requestReference(source)
+ local results = vm.getRefs(source)
for i = 1, #results do
callback(results[i])
end
end
-local function eachField(source, callback)
- if not source then
- return
- end
- local isGlobal = guide.isGlobal(source)
- local results = guide.requestReference(source)
- for i = 1, #results do
- local res = results[i]
- if isGlobal == guide.isGlobal(res) then
- callback(res)
- end
- end
-end
-
local function eachLocal(source, callback)
callback(source)
if source.ref then
@@ -43,21 +30,21 @@ local function find(source, uri, callback)
eachLocal(source.node, callback)
elseif source.type == 'field'
or source.type == 'method' then
- eachField(source.parent, callback)
+ eachRef(source.parent, callback)
elseif source.type == 'getindex'
or source.type == 'setindex'
or source.type == 'tableindex' then
- eachField(source, callback)
+ eachRef(source, callback)
elseif source.type == 'setglobal'
or source.type == 'getglobal' then
- eachField(source, callback)
+ eachRef(source, callback)
elseif source.type == 'goto'
or source.type == 'label' then
eachRef(source, callback)
elseif source.type == 'string'
and source.parent
and source.parent.index == source then
- eachField(source.parent, callback)
+ eachRef(source.parent, callback)
elseif source.type == 'string'
or source.type == 'boolean'
or source.type == 'number'
@@ -238,8 +225,18 @@ local accept = {
['nil'] = true,
}
+local function isLiteralValue(source)
+ if not guide.isLiteral(source) then
+ return false
+ end
+ if source.parent.index == source then
+ return false
+ end
+ return true
+end
+
return function (uri, offset)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
@@ -249,10 +246,28 @@ return function (uri, offset)
local source = findSource(ast, offset, accept)
if source then
+ local isGlobal = guide.isGlobal(source)
+ local isLiteral = isLiteralValue(source)
find(source, uri, function (target)
+ if not target then
+ return
+ end
if target.dummy then
return
end
+ if mark[target] then
+ return
+ end
+ mark[target] = true
+ if isGlobal ~= guide.isGlobal(target) then
+ return
+ end
+ if isLiteral ~= isLiteralValue(target) then
+ return
+ end
+ if not files.eq(uri, guide.getUri(target)) then
+ return
+ end
local kind
if target.type == 'getfield' then
target = target.field
@@ -315,13 +330,6 @@ return function (uri, offset)
else
return
end
- if not target then
- return
- end
- if mark[target] then
- return
- end
- mark[target] = true
results[#results+1] = {
start = target.start,
finish = target.finish,
diff --git a/script/core/hint.lua b/script/core/hint.lua
index 13d01dc7..67c725f7 100644
--- a/script/core/hint.lua
+++ b/script/core/hint.lua
@@ -1,10 +1,11 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local vm = require 'vm'
-local config = require 'config'
+local files = require 'files'
+local infer = require 'core.infer'
+local vm = require 'vm'
+local config = require 'config'
+local guide = require 'parser.guide'
local function typeHint(uri, edits, start, finish)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -18,6 +19,9 @@ local function typeHint(uri, edits, start, finish)
and source.type ~= 'setindex' then
return
end
+ if source.dummy then
+ return
+ end
if source[1] == '_' then
return
end
@@ -33,9 +37,9 @@ local function typeHint(uri, edits, start, finish)
return
end
end
- local infer = vm.getInferType(source, 0)
- if infer == 'any'
- or infer == 'nil' then
+ local view = infer.searchAndViewInfers(source)
+ if view == 'any'
+ or view == 'nil' then
return
end
local src = source
@@ -52,7 +56,7 @@ local function typeHint(uri, edits, start, finish)
end
mark[src] = true
edits[#edits+1] = {
- newText = (':%s'):format(infer),
+ newText = (':%s'):format(view),
start = src.finish,
finish = src.finish,
}
@@ -95,7 +99,7 @@ local function paramName(uri, edits, start, finish)
if not config.config.hint.paramName then
return
end
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
@@ -107,7 +111,7 @@ local function paramName(uri, edits, start, finish)
if not hasLiteralArgInCall(source) then
return
end
- local defs = vm.getDefs(source.node, 0)
+ local defs = vm.getDefs(source.node)
if not defs then
return
end
diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua
index 324d28af..822be2b6 100644
--- a/script/core/hover/arg.lua
+++ b/script/core/hover/arg.lua
@@ -1,4 +1,5 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
+local infer = require 'core.infer'
local vm = require 'vm'
local function optionalArg(arg)
@@ -21,7 +22,7 @@ local function asFunction(source, oop)
methodDef = true
end
if methodDef then
- args[#args+1] = ('self: %s'):format(vm.getInferType(parent.node))
+ args[#args+1] = ('self: %s'):format(infer.searchAndViewInfers(parent.node))
end
if source.args then
for i = 1, #source.args do
@@ -34,10 +35,12 @@ local function asFunction(source, oop)
args[#args+1] = ('%s%s: %s'):format(
name,
optionalArg(arg) and '?' or '',
- vm.getInferType(arg)
+ infer.searchAndViewInfers(arg)
)
+ elseif arg.type == '...' then
+ args[#args+1] = '...'
else
- args[#args+1] = ('%s'):format(vm.getInferType(arg))
+ args[#args+1] = ('%s'):format(infer.searchAndViewInfers(arg))
end
::CONTINUE::
end
@@ -61,7 +64,7 @@ local function asDocFunction(source)
args[i] = ('%s%s: %s'):format(
name,
arg.optional and '?' or '',
- vm.getInferType(arg.extends)
+ infer.searchAndViewInfers(arg.extends)
)
else
args[i] = ('%s%s'):format(
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index 401ca5a7..bcc3065a 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -2,11 +2,13 @@ local vm = require 'vm'
local ws = require 'workspace'
local furi = require 'file-uri'
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local markdown = require 'provider.markdown'
local config = require 'config'
local lang = require 'language'
local util = require 'utility'
+local guide = require 'parser.guide'
+local noder = require 'core.noder'
local function asStringInRequire(source, literal)
local rootPath = ws.path or ''
@@ -124,10 +126,10 @@ local function getBindComment(source, docGroup, base)
end
local function tryDocClassComment(source)
- for _, def in ipairs(vm.getDefs(source, 0)) do
+ for _, def in ipairs(vm.getDefs(source)) do
if def.type == 'doc.class.name'
or def.type == 'doc.alias.name' then
- local class = guide.getDocState(def)
+ local class = noder.getDocState(def)
local comment = getBindComment(class, class.bindGroup, class)
if comment then
return comment
@@ -180,7 +182,7 @@ local function isFunction(source)
if source.type == 'function' then
return true
end
- local value = guide.getObjectValue(source)
+ local value = searcher.getObjectValue(source)
if not value then
return false
end
@@ -223,13 +225,14 @@ local function getBindEnums(source, docGroup)
end
local function tryDocFieldUpComment(source)
- if source.type ~= 'doc.field' then
+ if source.type ~= 'doc.field.name' then
return
end
- if not source.bindGroup then
+ local docField = source.parent
+ if not docField.bindGroup then
return
end
- local comment = getBindComment(source, source.bindGroup, source)
+ local comment = getBindComment(docField, docField.bindGroup, docField)
return comment
end
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
index 0c8644ed..41616bc9 100644
--- a/script/core/hover/init.lua
+++ b/script/core/hover/init.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local vm = require 'vm'
local getLabel = require 'core.hover.label'
local getDesc = require 'core.hover.description'
@@ -7,6 +7,7 @@ local util = require 'utility'
local findSource = require 'core.find-source'
local lang = require 'language'
local markdown = require 'provider.markdown'
+local infer = require 'core.infer'
local function eachFunctionAndOverload(value, callback)
callback(value)
@@ -24,7 +25,7 @@ local function getHoverAsValue(source)
local label = getLabel(source)
local desc = getDesc(source)
if not desc then
- local values = vm.getDefs(source, 0)
+ local values = vm.getDefs(source)
for _, def in ipairs(values) do
desc = getDesc(def)
if desc then
@@ -40,7 +41,7 @@ local function getHoverAsValue(source)
end
local function getHoverAsFunction(source)
- local values = vm.getDefs(source, 0)
+ local values = vm.getDefs(source)
local desc = getDesc(source)
local labels = {}
local defs = 0
@@ -48,7 +49,7 @@ local function getHoverAsFunction(source)
local other = 0
local mark = {}
for _, def in ipairs(values) do
- def = guide.getObjectValue(def) or def
+ def = searcher.getObjectValue(def) or def
if def.type == 'function'
or def.type == 'doc.type.function' then
eachFunctionAndOverload(def, function (value)
@@ -123,7 +124,7 @@ local function getHover(source)
if source.type == 'doc.type.name' then
return getHoverAsDocName(source)
end
- local isFunction = vm.hasInferType(source, 'function', 0)
+ local isFunction = infer.hasType(source, 'function')
if isFunction then
return getHoverAsFunction(source)
else
@@ -146,7 +147,7 @@ local accept = {
}
local function getHoverByUri(uri, offset)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
index d93b14e3..d96b149c 100644
--- a/script/core/hover/label.lua
+++ b/script/core/hover/label.lua
@@ -2,9 +2,10 @@ local buildName = require 'core.hover.name'
local buildArg = require 'core.hover.arg'
local buildReturn = require 'core.hover.return'
local buildTable = require 'core.hover.table'
+local infer = require 'core.infer'
local vm = require 'vm'
local util = require 'utility'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local config = require 'config'
local files = require 'files'
@@ -31,29 +32,28 @@ local function asDocFunction(source)
end
local function asDocTypeName(source)
- for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ local defs = vm.getDefs(source)
+ for _, doc in ipairs(defs) do
if doc.type == 'doc.class.name' then
- return 'class ' .. source[1]
+ return 'class ' .. doc[1]
end
if doc.type == 'doc.alias.name' then
local extends = doc.parent.extends
- return lang.script('HOVER_EXTENDS', vm.getInferType(extends))
+ return lang.script('HOVER_EXTENDS', infer.searchAndViewInfers(extends))
end
end
end
local function asValue(source, title)
local name = buildName(source)
- local infers = vm.getInfers(source, 0)
- local type = vm.getInferType(source, 0)
- local class = vm.getClass(source, 0)
- local literal = vm.getInferLiteral(source, 0)
+ local type = infer.searchAndViewInfers(source)
+ local literal = infer.searchAndViewLiterals(source)
local cont
- if not vm.hasInferType(source, 'string', 0)
+ if not infer.hasType(source, 'string')
and not type:find('%[%]$')
and not type:find('%w%<') then
- if #vm.getFields(source, 0) > 0
- or vm.hasInferType(source, 'table', 0) then
+ if #vm.getRefs(source, '*') > 0
+ or infer.hasType(source, 'table') then
cont = buildTable(source)
end
end
@@ -66,11 +66,7 @@ local function asValue(source, title)
or type == 'nil') then
type = nil
end
- if class then
- pack[#pack+1] = class
- else
- pack[#pack+1] = type
- end
+ pack[#pack+1] = type
if literal then
pack[#pack+1] = '='
pack[#pack+1] = literal
@@ -123,30 +119,21 @@ local function asField(source)
return asValue(source, 'field')
end
-local function asDocField(source)
- local name = source.field[1]
+local function asDocFieldName(source)
+ local name = source[1]
+ local docField = source.parent
local class
- for _, doc in ipairs(source.bindGroup) do
+ for _, doc in ipairs(docField.bindGroup) do
if doc.type == 'doc.class' then
class = doc
break
end
end
- local infers = {}
- for _, infer in ipairs(vm.getInfers(source.extends) or {}) do
- infers[#infers+1] = infer
- end
+ local view = infer.searchAndViewInfers(docField.extends)
if not class then
- return ('field ?.%s: %s'):format(
- name,
- guide.viewInferType(infers)
- )
- end
- return ('field %s.%s: %s'):format(
- class.class[1],
- name,
- guide.viewInferType(infers)
- )
+ return ('field ?.%s: %s'):format(name, view)
+ end
+ return ('field %s.%s: %s'):format(class.class[1], name, view)
end
local function asString(source)
@@ -177,7 +164,7 @@ local function asNumber(source)
if type(num) ~= 'number' then
return nil
end
- local uri = guide.getUri(source)
+ local uri = searcher.getUri(source)
local text = files.getText(uri)
if not text then
return nil
@@ -215,7 +202,7 @@ return function (source, oop)
return asDocFunction(source)
elseif source.type == 'doc.type.name' then
return asDocTypeName(source)
- elseif source.type == 'doc.field' then
- return asDocField(source)
+ elseif source.type == 'doc.field.name' then
+ return asDocFieldName(source)
end
end
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
index d583f1e1..d2b9d30b 100644
--- a/script/core/hover/name.lua
+++ b/script/core/hover/name.lua
@@ -1,4 +1,6 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
+local infer = require 'core.infer'
+local guide = require 'parser.guide'
local vm = require 'vm'
local buildName
@@ -19,7 +21,7 @@ end
local function asField(source, oop)
local class
if source.node.type ~= 'getglobal' then
- class = vm.getClass(source.node, 0)
+ class = infer.getClass(source.node)
end
local node = class or guide.getKeyName(source.node) or '?'
local method = guide.getKeyName(source)
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
index c3e9656d..0f0d85e0 100644
--- a/script/core/hover/return.lua
+++ b/script/core/hover/return.lua
@@ -1,12 +1,4 @@
-local guide = require 'core.guide'
-local vm = require 'vm'
-
-local function mergeTypes(returns)
- if type(returns) == 'string' then
- return returns
- end
- return guide.mergeTypes(returns)
-end
+local infer = require 'core.infer'
local function getReturnDualByDoc(source)
local docs = source.bindDocs
@@ -55,24 +47,20 @@ local function asFunction(source)
local returns = {}
for i, rtn in ipairs(dual) do
local line = {}
- local types = {}
+ local infers = {}
if i == 1 then
line[#line+1] = ' -> '
else
line[#line+1] = ('% 3d. '):format(i)
end
for n = 1, #rtn do
- local values = vm.getInfers(rtn[n])
- for _, value in ipairs(values) do
- if value.type then
- for tp in value.type:gmatch '[^|]+' do
- types[tp] = true
- end
- end
+ local values = infer.searchInfers(rtn[n])
+ for tp in pairs(values) do
+ infers[tp] = true
end
end
- if next(types) or rtn[1] then
- local tp = mergeTypes(types) or 'any'
+ if next(infers) or rtn[1] then
+ local tp = infer.viewInfers(infers)
if rtn[1].name then
line[#line+1] = ('%s%s: %s'):format(
rtn[1].name[1],
@@ -103,7 +91,7 @@ local function asDocFunction(source)
local returns = {}
for i, rtn in ipairs(source.returns) do
local rtnText = ('%s%s'):format(
- vm.getInferType(rtn),
+ infer.searchAndViewInfers(rtn),
rtn.optional and '?' or ''
)
if i == 1 then
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
index edb7751b..159453e6 100644
--- a/script/core/hover/table.lua
+++ b/script/core/hover/table.lua
@@ -1,26 +1,12 @@
local vm = require 'vm'
local util = require 'utility'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local config = require 'config'
local lang = require 'language'
+local infer = require 'core.infer'
-local function getKey(src)
- local key = vm.getKeyName(src)
- if not key or #key <= 0 then
- if not src.index then
- return '[any]'
- end
- local class = vm.getClass(src.index)
- if class then
- return ('[%s]'):format(class)
- end
- local tp = vm.getInferType(src.index)
- if tp then
- return ('[%s]'):format(tp)
- end
- return '[any]'
- end
- if guide.getKeyType(src) == 'string' then
+local function formatKey(key)
+ if type(key) == 'string' then
if key:match '^[%a_][%w_]*$' then
return key
else
@@ -30,104 +16,16 @@ local function getKey(src)
return ('[%s]'):format(key)
end
-local function getFieldFull(src)
- local value = guide.getObjectValue(src) or src
- local tp = vm.getInferType(value, 0)
- --local class = vm.getClass(src)
- local literal = vm.getInferLiteral(value)
- if type(literal) == 'string' and #literal >= 50 then
- literal = literal:sub(1, 47) .. '...'
- end
- return tp, literal
-end
-
-local function getFieldFast(src)
- if src.bindDocs then
- return getFieldFull(src)
- end
- local value = guide.getObjectValue(src) or src
- if not value then
- return 'any'
- end
- if value.type == 'boolean' then
- return value.type, util.viewLiteral(value[1])
- end
- if value.type == 'number'
- or value.type == 'integer' then
- if math.tointeger(value[1]) then
- if config.config.runtime.version == 'Lua 5.3'
- or config.config.runtime.version == 'Lua 5.4' then
- return 'integer', util.viewLiteral(value[1])
- end
- end
- return value.type, util.viewLiteral(value[1])
- end
- if value.type == 'table'
- or value.type == 'function' then
- return value.type
- end
- if value.type == 'string' then
- local literal = value[1]
- if type(literal) == 'string' and #literal >= 50 then
- literal = literal:sub(1, 47) .. '...'
- end
- return value.type, util.viewLiteral(literal)
- end
- if value.type == 'doc.field' then
- return vm.getInferType(value)
- end
-end
-
-local function getField(src, timeUp, mark, key)
- if src.type == 'table'
- or src.type == 'function' then
- return nil
- end
- if src.parent then
- if src.type == 'string'
- or src.type == 'boolean'
- or src.type == 'number'
- or src.type == 'integer' then
- if src.parent.type == 'tableindex'
- or src.parent.type == 'setindex'
- or src.parent.type == 'getindex' then
- if src.parent.index == src then
- src = src.parent
- end
- end
- end
- end
- local tp, literal
- tp, literal = getFieldFast(src)
- if tp then
- return tp, literal
- end
- if timeUp or mark[key] then
- return nil
- end
- mark[key] = true
- tp, literal = getFieldFull(src)
- if tp then
- return tp, literal
- end
- return nil
-end
-
-local function buildAsHash(classes, literals, reachMax)
- local keys = {}
- for k in pairs(classes) do
- keys[#keys+1] = k
- end
- table.sort(keys)
+local function buildAsHash(keys, inferMap, literalMap, reachMax)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local class = classes[key]
- local literal = literals[key]
- if literal then
- lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ local inferView = inferMap[key]
+ local literalView = literalMap[key]
+ if literalView then
+ lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView)
else
- lines[#lines+1] = (' %s: %s,'):format(key, class)
+ lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView)
end
end
if reachMax then
@@ -137,23 +35,19 @@ local function buildAsHash(classes, literals, reachMax)
return table.concat(lines, '\n')
end
-local function buildAsConst(classes, literals, reachMax)
- local keys = {}
- for k in pairs(classes) do
- keys[#keys+1] = k
- end
+local function buildAsConst(keys, inferMap, literalMap, reachMax)
table.sort(keys, function (a, b)
- return tonumber(literals[a]) < tonumber(literals[b])
+ return tonumber(literalMap[a]) < tonumber(literalMap[b])
end)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local class = classes[key]
- local literal = literals[key]
- if literal then
- lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ local inferView = inferMap[key]
+ local literalView = literalMap[key]
+ if literalView then
+ lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView)
else
- lines[#lines+1] = (' %s: %s,'):format(key, class)
+ lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView)
end
end
if reachMax then
@@ -163,111 +57,79 @@ local function buildAsConst(classes, literals, reachMax)
return table.concat(lines, '\n')
end
-local function mergeLiteral(literals)
- local results = {}
+local typeSorter = {
+ ['string'] = 1,
+ ['number'] = 2,
+ ['boolean'] = 3,
+}
+
+local function getKeyMap(fields)
+ local keys = {}
local mark = {}
- for _, value in ipairs(literals) do
- if not mark[value] then
- mark[value] = true
- results[#results+1] = value
+ for _, field in ipairs(fields) do
+ local key = vm.getKeyName(field)
+ local tp = vm.getKeyType(field)
+ if tp == 'number' then
+ key = tonumber(key)
+ elseif tp == 'boolean' then
+ key = key == 'true'
end
- end
- if #results == 0 then
- return nil
- end
- table.sort(results)
- return table.concat(results, '|')
-end
-
-local function mergeTypes(types)
- local results = {}
- local mark = {
- -- 讲道理table的keyvalue不会是nil
- ['nil'] = true,
- }
- for _, tv in ipairs(types) do
- for tp in tv:gmatch '[^|]+' do
- if not mark[tp] then
- mark[tp] = true
- results[tp] = true
- end
+ if key and not mark[key] then
+ mark[key] = true
+ keys[#keys+1] = key
end
end
- return guide.mergeTypes(results)
-end
-
-local function clearClasses(classes)
- classes['[nil]'] = nil
- classes['[any]'] = nil
- classes['[string]'] = nil
+ table.sort(keys, function (a, b)
+ local ta = typeSorter[type(a)]
+ local tb = typeSorter[type(b)]
+ if ta == tb then
+ return tostring(a) < tostring(b)
+ else
+ return ta < tb
+ end
+ end)
+ return keys
end
return function (source)
- if config.config.hover.previewFields <= 0 then
+ local maxFields = config.config.hover.previewFields
+ if maxFields <= 0 then
return 'table'
end
- local literals = {}
- local classes = {}
- local clock = os.clock()
- local timeUp
- local mark = {}
- local fields = vm.getFields(source, 0)
- local keyCount = 0
- local reachMax
- for _, src in ipairs(fields) do
- local key = getKey(src)
- if not key then
- goto CONTINUE
- end
- if not classes[key] then
- classes[key] = {}
- keyCount = keyCount + 1
- end
- if not literals[key] then
- literals[key] = {}
- end
- if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then
- timeUp = true
- end
- local class, literal = getField(src, timeUp, mark, key)
- if literal == 'nil' then
- literal = nil
- end
- classes[key][#classes[key]+1] = class
- literals[key][#literals[key]+1] = literal
- if keyCount >= config.config.hover.previewFields then
- reachMax = true
- break
- end
- ::CONTINUE::
- end
-
- clearClasses(classes)
- for key, class in pairs(classes) do
- literals[key] = mergeLiteral(literals[key])
- classes[key] = mergeTypes(class)
- end
+ local fields = vm.getRefs(source, '*')
+ local keys = getKeyMap(fields)
- if not next(classes) then
+ if #keys == 0 then
return '{}'
end
- local intValue = true
- for key, class in pairs(classes) do
- if class ~= 'integer' or not tonumber(literals[key]) then
- intValue = false
- break
+ local inferMap = {}
+ local literalMap = {}
+
+ local reachMax = maxFields < #keys
+
+ local isConsts = true
+ for i = 1, math.min(maxFields, #keys) do
+ local key = keys[i]
+ inferMap[key] = infer.searchAndViewInfers(source, key)
+ literalMap[key] = infer.searchAndViewLiterals(source, key)
+ if not tonumber(literalMap[key]) then
+ isConsts = false
end
end
+
local result
- if intValue then
- result = buildAsConst(classes, literals, reachMax)
+
+ if isConsts then
+ result = buildAsConst(keys, inferMap, literalMap, reachMax)
else
- result = buildAsHash(classes, literals, reachMax)
- end
- if timeUp then
- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ result = buildAsHash(keys, inferMap, literalMap, reachMax)
end
+
+ --if timeUp then
+ -- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ --end
+
return result
end
diff --git a/script/core/infer.lua b/script/core/infer.lua
new file mode 100644
index 00000000..a2c12fba
--- /dev/null
+++ b/script/core/infer.lua
@@ -0,0 +1,639 @@
+local searcher = require 'core.searcher'
+local config = require 'config'
+local noder = require 'core.noder'
+local util = require 'utility'
+local vm = require "vm.vm"
+
+local STRING_OR_TABLE = {'STRING_OR_TABLE'}
+local BE_RETURN = {'BE_RETURN'}
+local BE_CONNACT = {'BE_CONNACT'}
+local CLASS = {'CLASS'}
+local TABLE = {'TABLE'}
+
+local TypeSort = {
+ ['boolean'] = 1,
+ ['string'] = 2,
+ ['integer'] = 3,
+ ['number'] = 4,
+ ['table'] = 5,
+ ['function'] = 6,
+ ['true'] = 101,
+ ['false'] = 102,
+ ['nil'] = 999,
+}
+
+local m = {}
+
+local function mergeTable(a, b)
+ if not b then
+ return
+ end
+ for v in pairs(b) do
+ a[v] = true
+ end
+end
+
+local function searchInferOfUnary(value, infers, mark)
+ local op = value.op.type
+ if op == 'not' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '#' then
+ infers['integer'] = true
+ return
+ end
+ if op == '-' then
+ if m.hasType(value[1], 'integer', mark) then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+end
+
+local function searchInferOfBinary(value, infers, mark)
+ local op = value.op.type
+ if op == 'and' then
+ if m.isTrue(value[1], mark) then
+ mergeTable(infers, m.searchInfers(value[2], nil, mark))
+ else
+ mergeTable(infers, m.searchInfers(value[1], nil, mark))
+ end
+ return
+ end
+ if op == 'or' then
+ if m.isTrue(value[1], mark) then
+ mergeTable(infers, m.searchInfers(value[1], nil, mark))
+ else
+ mergeTable(infers, m.searchInfers(value[2], nil, mark))
+ end
+ return
+ end
+ if op == '=='
+ or op == '~='
+ or op == '<'
+ or op == '>'
+ or op == '<='
+ or op == '>=' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '&'
+ or op == '|' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers['string'] = true
+ return
+ end
+ if op == '^'
+ or op == '/' then
+ infers['number'] = true
+ return
+ end
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '%'
+ or op == '//' then
+ if m.hasType(value[1], 'integer', mark)
+ and m.hasType(value[2], 'integer', mark) then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+end
+
+local function searchInferOfValue(value, infers, mark)
+ if value.type == 'string' then
+ infers['string'] = true
+ return true
+ end
+ if value.type == 'boolean' then
+ infers['boolean'] = true
+ return true
+ end
+ if value.type == 'table' then
+ if value.array then
+ local node = m.searchAndViewInfers(value.array, nil, mark)
+ local infer = node .. '[]'
+ infers[infer] = true
+ else
+ infers['table'] = true
+ end
+ return true
+ end
+ if value.type == 'number' then
+ if math.type(value[1]) == 'integer' then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return true
+ end
+ if value.type == 'nil' then
+ infers['nil'] = true
+ return true
+ end
+ if value.type == 'function' then
+ infers['function'] = true
+ return true
+ end
+ if value.type == 'unary' then
+ searchInferOfUnary(value, infers, mark)
+ return true
+ end
+ if value.type == 'binary' then
+ searchInferOfBinary(value, infers, mark)
+ return true
+ end
+ return false
+end
+
+local function searchLiteralOfValue(value, literals, mark)
+ if value.type == 'string'
+ or value.type == 'boolean'
+ or value.type == 'number'
+ or value.type == 'integer' then
+ local v = value[1]
+ if v ~= nil then
+ literals[v] = true
+ end
+ return
+ end
+ if value.type == 'unary' then
+ local op = value.op.type
+ if op == '-' then
+ local subLiterals = m.searchLiterals(value[1], nil, mark)
+ if subLiterals then
+ for subLiteral in pairs(subLiterals) do
+ local num = tonumber(subLiteral)
+ if num then
+ literals[-num] = true
+ end
+ end
+ end
+ end
+ if op == '~' then
+ local subLiterals = m.searchLiterals(value[1], nil, mark)
+ if subLiterals then
+ for subLiteral in pairs(subLiterals) do
+ local num = math.tointeger(subLiteral)
+ if num then
+ literals[~num] = true
+ end
+ end
+ end
+ end
+ end
+ return
+end
+
+local function bindClassOrType(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.type' then
+ return true
+ end
+ end
+ return false
+end
+
+local function cleanInfers(infers)
+ local version = config.config.runtime.version
+ local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4'
+ infers['unknown'] = nil
+ if infers['any'] and infers['nil'] then
+ infers['nil'] = nil
+ end
+ if infers['number'] then
+ enableInteger = false
+ end
+ if not enableInteger and infers['integer'] then
+ infers['integer'] = nil
+ infers['number'] = true
+ end
+ -- stringlib 就是 string
+ if infers['stringlib'] and infers['string'] then
+ infers['stringlib'] = nil
+ end
+ -- 如果是通过 .. 来推测的,且结果里没有 number 与 integer,则推测为string
+ if infers[BE_CONNACT] then
+ infers[BE_CONNACT] = nil
+ if not infers['number'] and not infers['integer'] then
+ infers['string'] = true
+ end
+ end
+ -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型
+ if infers[STRING_OR_TABLE] then
+ infers[STRING_OR_TABLE] = nil
+ if not infers['table'] and not infers['string'] then
+ infers['table'] = true
+ infers['string'] = true
+ end
+ end
+ -- 如果有doc标记,则先移除table类型
+ if infers[CLASS] then
+ infers[CLASS] = nil
+ infers['table'] = nil
+ end
+ -- 用doc标记的table,加入table类型
+ if infers[TABLE] then
+ infers[TABLE] = nil
+ infers['table'] = true
+ end
+ if infers[BE_RETURN] then
+ infers[BE_RETURN] = nil
+ infers['nil'] = nil
+ end
+ infers['any'] = nil
+end
+
+---合并对象的推断类型
+---@param infers string[]
+---@return string
+function m.viewInfers(infers)
+ if infers[0] then
+ return infers[0]
+ end
+ -- 如果有显性的 any ,则直接显示为 any
+ if infers['any'] then
+ infers[0] = 'any'
+ return 'any'
+ end
+ local result = {}
+ local count = 0
+ for infer in pairs(infers) do
+ count = count + 1
+ result[count] = infer
+ end
+ -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any
+ if count == 0 then
+ infers[0] = 'any'
+ return 'any'
+ end
+ table.sort(result, function (a, b)
+ local sa = TypeSort[a] or 100
+ local sb = TypeSort[b] or 100
+ if sa == sb then
+ return a < b
+ else
+ return sa < sb
+ end
+ end)
+ infers[0] = table.concat(result, '|')
+ return infers[0]
+end
+
+---合并对象的值
+---@param literals string[]
+---@return string
+function m.viewLiterals(literals)
+ local result = {}
+ local count = 0
+ for infer in pairs(literals) do
+ count = count + 1
+ result[count] = util.viewLiteral(infer)
+ end
+ if count == 0 then
+ return nil
+ end
+ table.sort(result)
+ local view = table.concat(result, '|')
+ return view
+end
+
+function m.viewDocName(doc)
+ if not doc then
+ return nil
+ end
+ if doc.type == 'doc.type' then
+ local list = {}
+ for _, tp in ipairs(doc.types) do
+ list[#list+1] = m.getDocName(tp)
+ end
+ for _, enum in ipairs(doc.enums) do
+ list[#list+1] = m.getDocName(enum)
+ end
+ return table.concat(list, '|')
+ end
+ return m.getDocName(doc)
+end
+
+function m.getDocName(doc)
+ if not doc then
+ return nil
+ end
+ if doc.type == 'doc.class.name'
+ or doc.type == 'doc.type.name' then
+ local name = doc[1] or '?'
+ if doc.typeGeneric then
+ return '<' .. name .. '>'
+ else
+ return name
+ end
+ end
+ if doc.type == 'doc.type.array' then
+ local nodeName = m.viewDocName(doc.node) or '?'
+ return nodeName .. '[]'
+ end
+ if doc.type == 'doc.type.table' then
+ local key = m.viewDocName(doc.tkey) or '?'
+ local value = m.viewDocName(doc.tvalue) or '?'
+ return ('table<%s, %s>'):format(key, value)
+ end
+ if doc.type == 'doc.type.function' then
+ return 'function'
+ end
+ if doc.type == 'doc.type.enum'
+ or doc.type == 'doc.resume' then
+ local value = doc[1] or '?'
+ return value
+ end
+end
+
+function m.viewDocFunction(doc)
+ if doc.type ~= 'doc.type.function' then
+ return ''
+ end
+ local args = {}
+ for i, arg in ipairs(doc.args) do
+ args[i] = ('%s: %s'):format(arg.name[1], m.viewDocName(arg.extends))
+ end
+ local label = ('fun(%s)'):format(table.concat(args, ', '))
+ if #doc.returns > 0 then
+ local returns = {}
+ for i, rtn in ipairs(doc.returns) do
+ returns[i] = m.viewDocName(rtn)
+ end
+ label = ('%s:%s'):format(label, table.concat(returns))
+ end
+ return label
+end
+
+---显示对象的推断类型
+---@param source parser.guide.object
+---@param mark table
+---@return string
+local function searchInfer(source, infers, mark)
+ if bindClassOrType(source) then
+ return
+ end
+ if searchInferOfValue(source, infers, mark) then
+ return
+ end
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchInferOfValue(value, infers, mark)
+ return
+ end
+ -- check LuaDoc
+ local docName = m.getDocName(source)
+ if docName then
+ infers[docName] = true
+ if docName ~= 'unknown' then
+ infers[CLASS] = true
+ end
+ if docName == 'table' then
+ infers[TABLE] = true
+ end
+ end
+ if source.parent.type == 'unary' then
+ local op = source.parent.op.type
+ -- # XX -> string | table
+ if op == '#' then
+ infers[STRING_OR_TABLE] = true
+ return
+ end
+ if op == '-' then
+ infers['number'] = true
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+ return
+ end
+ if source.parent.type == 'binary' then
+ local op = source.parent.op.type
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '/'
+ or op == '//'
+ or op == '^'
+ or op == '%' then
+ infers['number'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '|'
+ or op == '&' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers[BE_CONNACT] = true
+ return
+ end
+ end
+ -- X.a -> table
+ if source.next and source.next.node == source then
+ if source.next.type == 'setfield'
+ or source.next.type == 'setindex'
+ or source.next.type == 'setmethod'
+ or source.next.type == 'getfield'
+ or source.next.type == 'getindex' then
+ infers['table'] = true
+ end
+ if source.next.type == 'getmethod' then
+ infers[STRING_OR_TABLE] = true
+ end
+ end
+ -- return XX
+ if source.parent.type == 'return' then
+ infers[BE_RETURN] = true
+ end
+end
+
+local function searchLiteral(source, literals, mark)
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchLiteralOfValue(value, literals, mark)
+ return
+ end
+end
+
+---搜索对象的推断类型
+---@param source parser.guide.object
+---@param field? string
+---@param mark? table
+---@return string[]
+function m.searchInfers(source, field, mark)
+ if not source then
+ return nil
+ end
+ local defs = vm.getDefs(source, field)
+ local infers = {}
+ mark = mark or {}
+ if not field then
+ mark[source] = true
+ searchInfer(source, infers)
+ local id = noder.getID(source)
+ if id then
+ local node = noder.getNodeByID(source, id)
+ if node and node.sources then
+ for _, src in ipairs(node.sources) do
+ if not mark[src] then
+ mark[src] = true
+ searchInfer(src, infers, mark)
+ end
+ end
+ end
+ end
+ end
+ if source.type == 'field' or source.type == 'method' then
+ mark[source.parent] = true
+ searchInfer(source.parent, infers, mark)
+ end
+ for _, def in ipairs(defs) do
+ if not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers, mark)
+ end
+ end
+ if source.docParam then
+ local docType = source.docParam.extends
+ if docType.type == 'doc.type' then
+ for _, def in ipairs(docType.types) do
+ if def.typeGeneric and not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers, mark)
+ end
+ end
+ end
+ end
+ if source.type == 'doc.type' then
+ if source.type == 'doc.type' then
+ for _, def in ipairs(source.types) do
+ if def.typeGeneric and not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers, mark)
+ end
+ end
+ end
+ end
+ cleanInfers(infers)
+ return infers
+end
+
+---搜索对象的字面量值
+---@param source parser.guide.object
+---@param field? string
+---@param mark? table
+---@return table
+function m.searchLiterals(source, field, mark)
+ local defs = vm.getDefs(source, field)
+ local literals = {}
+ mark = mark or {}
+ if not field then
+ mark[source] = true
+ searchLiteral(source, literals, mark)
+ end
+ for _, def in ipairs(defs) do
+ if not mark[def] then
+ mark[def] = true
+ searchLiteral(def, literals, mark)
+ end
+ end
+ return literals
+end
+
+---搜索并显示推断值
+---@param source parser.guide.object
+---@param field? string
+---@return string
+function m.searchAndViewLiterals(source, field, mark)
+ if not source then
+ return nil
+ end
+ local literals = m.searchLiterals(source, field, mark)
+ local view = m.viewLiterals(literals)
+ return view
+end
+
+---判断对象的推断值是否是 true
+---@param source parser.guide.object
+---@param mark? table
+function m.isTrue(source, mark)
+ if not source then
+ return false
+ end
+ local literals = m.searchLiterals(source, nil, mark)
+ for literal in pairs(literals) do
+ if literal ~= false then
+ return true
+ end
+ end
+ return false
+end
+
+---判断对象的推断类型是否包含某个类型
+function m.hasType(source, tp, mark)
+ local infers = m.searchInfers(source, nil, mark)
+ return infers[tp] or false
+end
+
+---搜索并显示推断类型
+---@param source parser.guide.object
+---@param field? string
+---@return string
+function m.searchAndViewInfers(source, field, mark)
+ if not source then
+ return 'any'
+ end
+ local infers = m.searchInfers(source, field, mark)
+ local view = m.viewInfers(infers)
+ return view
+end
+
+---搜索并显示推断的class
+---@param source parser.guide.object
+---@return string?
+function m.getClass(source)
+ if not source then
+ return nil
+ end
+ local infers = {}
+ local defs = vm.getDefs(source)
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class.name' then
+ infers[def[1]] = true
+ end
+ end
+ local view = m.viewInfers(infers)
+ if view == 'any' then
+ return nil
+ end
+ return view
+end
+
+return m
diff --git a/script/core/keyword.lua b/script/core/keyword.lua
index 71ea4969..73892f18 100644
--- a/script/core/keyword.lua
+++ b/script/core/keyword.lua
@@ -1,6 +1,6 @@
local define = require 'proto.define'
-local guide = require 'core.guide'
local files = require 'files'
+local guide = require 'parser.guide'
local keyWordMap = {
{'do', function (info, results)
diff --git a/script/core/noder.lua b/script/core/noder.lua
new file mode 100644
index 00000000..43d349ee
--- /dev/null
+++ b/script/core/noder.lua
@@ -0,0 +1,1007 @@
+local util = require 'utility'
+local guide = require 'parser.guide'
+local collector = require 'core.collector'
+
+local LastIDCache = {}
+local FirstIDCache = {}
+local SPLIT_CHAR = '\x1F'
+local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$'
+local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*'
+local ANY_FIELD_CHAR = '*'
+local INDEX_CHAR = '['
+local RETURN_INDEX = SPLIT_CHAR .. '#'
+local PARAM_INDEX = SPLIT_CHAR .. '&'
+local TABLE_KEY = SPLIT_CHAR .. '<'
+local INDEX_FIELD = SPLIT_CHAR .. INDEX_CHAR
+local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR
+local URI_CHAR = '@'
+local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(.*)'
+
+---@class node
+-- 当前节点的id
+---@field id string
+-- 使用该ID的单元
+---@field sources parser.guide.object[]
+-- 前进的关联ID
+---@field forward string[]
+-- 后退的关联ID
+---@field backward string[]
+-- 函数调用参数信息(用于泛型)
+---@field call parser.guide.object
+
+---@alias noders table<string, node[]>
+
+---创建source的链接信息
+---@param noders noders
+---@param id string
+---@return node
+local function getNode(noders, id)
+ if not noders[id] then
+ noders[id] = {
+ id = id,
+ }
+ end
+ return noders[id]
+end
+
+---获取语法树单元的key
+---@param source parser.guide.object
+---@return string? key
+---@return parser.guide.object? node
+local function getKey(source)
+ if source.type == 'local' then
+ return tostring(source.start), nil
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ return tostring(source.node.start), nil
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ local node = source.node
+ if node.tag == '_ENV' then
+ return ('%q'):format(source[1] or ''), nil
+ else
+ return ('%q'):format(source[1] or ''), node
+ end
+ elseif source.type == 'getfield'
+ or source.type == 'setfield' then
+ return ('%q'):format(source.field and source.field[1] or ''), source.node
+ elseif source.type == 'tablefield' then
+ return ('%q'):format(source.field and source.field[1] or ''), source.parent
+ elseif source.type == 'getmethod'
+ or source.type == 'setmethod' then
+ return ('%q'):format(source.method and source.method[1] or ''), source.node
+ elseif source.type == 'setindex'
+ or source.type == 'getindex' then
+ local index = source.index
+ if not index then
+ return INDEX_CHAR, source.node
+ end
+ if index.type == 'string'
+ or index.type == 'boolean'
+ or index.type == 'number' then
+ return ('%q'):format(index[1] or ''), source.node
+ else
+ return INDEX_CHAR, source.node
+ end
+ elseif source.type == 'tableindex' then
+ local index = source.index
+ if not index then
+ return ANY_FIELD_CHAR, source.parent
+ end
+ if index.type == 'string'
+ or index.type == 'boolean'
+ or index.type == 'number' then
+ return ('%q'):format(index[1] or ''), source.parent
+ elseif index.type ~= 'function'
+ and index.type ~= 'table' then
+ return ANY_FIELD_CHAR, source.parent
+ end
+ elseif source.type == 'table' then
+ return source.start, nil
+ elseif source.type == 'label' then
+ return source.start, nil
+ elseif source.type == 'goto' then
+ if source.node then
+ return source.node.start, nil
+ end
+ return nil, nil
+ elseif source.type == 'function' then
+ return source.start, nil
+ elseif source.type == 'string' then
+ return '', nil
+ elseif source.type == 'integer'
+ or source.type == 'number'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return source.start, nil
+ elseif source.type == '...' then
+ return source.start, nil
+ elseif source.type == 'varargs' then
+ if source.node then
+ return source.node.start, nil
+ end
+ elseif source.type == 'select' then
+ return ('%d%s%d'):format(source.start, RETURN_INDEX, source.sindex)
+ elseif source.type == 'call' then
+ local node = source.node
+ if node.special == 'rawget'
+ or node.special == 'rawset' then
+ if not source.args then
+ return nil, nil
+ end
+ local tbl, key = source.args[1], source.args[2]
+ if not tbl or not key then
+ return nil, nil
+ end
+ if key.type == 'string' then
+ return ('%q'):format(key[1] or ''), tbl
+ else
+ return '', tbl
+ end
+ end
+ return source.finish, nil
+ elseif source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name'
+ or source.type == 'doc.see.name' then
+ local name = source[1]
+ return name, nil
+ elseif source.type == 'doc.type.name' then
+ local name = source[1]
+ if source.typeGeneric then
+ return source.typeGeneric[name][1].start, nil
+ else
+ return name, nil
+ end
+ elseif source.type == 'doc.class'
+ or source.type == 'doc.type'
+ or source.type == 'doc.param'
+ or source.type == 'doc.vararg'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.function' then
+ return source.start, nil
+ elseif source.type == 'doc.see.field' then
+ return ('%q'):format(source[1]), source.parent.name
+ elseif source.type == 'generic.closure' then
+ return source.call.start, nil
+ elseif source.type == 'generic.value' then
+ return ('%s|%s'):format(
+ source.closure.call.start,
+ getKey(source.proto)
+ )
+ end
+ return nil, nil
+end
+
+local function checkMode(source)
+ if source.type == 'table' then
+ return 't:'
+ end
+ if source.type == 'select' then
+ return 's:'
+ end
+ if source.type == 'function' then
+ return 'f:'
+ end
+ if source.type == 'string' then
+ return 'str:'
+ end
+ if source.type == 'number'
+ or source.type == 'integer'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return 'i:'
+ end
+ if source.type == 'call' then
+ return 'c:'
+ end
+ if source.type == '...'
+ or source.type == 'varargs' then
+ return 'va:'
+ end
+ if source.type == 'doc.class.name'
+ or source.type == 'doc.type.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name' then
+ if source.typeGeneric then
+ return 'dg:'
+ end
+ return 'dn:'
+ end
+ if source.type == 'doc.field.name' then
+ return 'dfn:'
+ end
+ if source.type == 'doc.see.name' then
+ return 'dsn:'
+ end
+ if source.type == 'doc.class' then
+ return 'dc:'
+ end
+ if source.type == 'doc.type' then
+ return 'dt:'
+ end
+ if source.type == 'doc.param' then
+ return 'dp:'
+ end
+ if source.type == 'doc.type.function' then
+ return 'dfun:'
+ end
+ if source.type == 'doc.type.table' then
+ return 'dtable:'
+ end
+ if source.type == 'doc.type.array' then
+ return 'darray:'
+ end
+ if source.type == 'doc.vararg' then
+ return 'dv:'
+ end
+ if source.type == 'doc.type.enum'
+ or source.type == 'doc.resume' then
+ return 'de:'
+ end
+ if source.type == 'generic.closure' then
+ return 'gc:'
+ end
+ if source.type == 'generic.value' then
+ local id = 'gv:'
+ if guide.getUri(source.closure.call) ~= guide.getUri(source.proto) then
+ id = id .. URI_CHAR .. guide.getUri(source.closure.call)
+ end
+ return id
+ end
+ if guide.isGlobal(source) then
+ return 'g:'
+ end
+ if source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ source = source.node
+ end
+ if source.parent.type == 'funcargs' then
+ return 'p:'
+ end
+ return 'l:'
+end
+
+local IDList = {}
+---获取语法树单元的字符串ID
+---@param source parser.guide.object
+---@return string? id
+local function getID(source)
+ if not source then
+ return nil
+ end
+ if source._id ~= nil then
+ return source._id or nil
+ end
+ if source.type == 'field'
+ or source.type == 'method' then
+ source._id = false
+ return nil
+ end
+ local current = source
+ local index = 0
+ while true do
+ if current.type == 'paren' then
+ current = current.exp
+ if not current then
+ break
+ end
+ goto CONTINUE
+ end
+ local id, node = getKey(current)
+ if not id then
+ break
+ end
+ index = index + 1
+ IDList[index] = id
+ if not node then
+ break
+ end
+ current = node
+ if current.special == '_G' then
+ for i = index, 2, -1 do
+ if IDList[i] == '"_G"' then
+ IDList[i] = nil
+ end
+ end
+ break
+ end
+ ::CONTINUE::
+ end
+ if index == 0 then
+ source._id = false
+ return nil
+ end
+ for i = index + 1, #IDList do
+ IDList[i] = nil
+ end
+ local mode = checkMode(current)
+ if not mode then
+ source._id = false
+ return nil
+ end
+ util.revertTable(IDList)
+ local id = mode .. table.concat(IDList, SPLIT_CHAR)
+ source._id = id
+ return id
+end
+
+---添加关联的前进ID
+---@param noders noders
+---@param id string
+---@param forwardID string
+local function pushForward(noders, id, forwardID, tag)
+ if not id
+ or not forwardID
+ or forwardID == ''
+ or id == forwardID then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.forward then
+ node.forward = {}
+ end
+ if node.forward[forwardID] ~= nil then
+ return
+ end
+ node.forward[forwardID] = tag or false
+ node.forward[#node.forward+1] = forwardID
+end
+
+---添加关联的后退ID
+---@param noders noders
+---@param id string
+---@param backwardID string
+local function pushBackward(noders, id, backwardID, tag)
+ if not id
+ or not backwardID
+ or backwardID == ''
+ or id == backwardID then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.backward then
+ node.backward = {}
+ end
+ if node.backward[backwardID] ~= nil then
+ return
+ end
+ node.backward[backwardID] = tag or false
+ node.backward[#node.backward+1] = backwardID
+end
+
+local m = {}
+
+m.SPLIT_CHAR = SPLIT_CHAR
+m.RETURN_INDEX = RETURN_INDEX
+m.PARAM_INDEX = PARAM_INDEX
+m.TABLE_KEY = TABLE_KEY
+m.ANY_FIELD = ANY_FIELD
+m.URI_CHAR = URI_CHAR
+m.INDEX_FIELD = INDEX_FIELD
+
+--- 寻找doc的主体
+---@param obj parser.guide.object
+---@return parser.guide.object
+local function getDocStateWithoutCrossFunction(obj)
+ for _ = 1, 1000 do
+ local parent = obj.parent
+ if not parent then
+ return obj
+ end
+ if parent.type == 'doc' then
+ return obj
+ end
+ if parent.type == 'doc.type.function' then
+ return nil
+ end
+ obj = parent
+ end
+ error('guide.getDocState overstack')
+end
+
+---添加关联单元
+---@param noders noders
+---@param source parser.guide.object
+function m.pushSource(noders, source)
+ local id = m.getID(source)
+ if not id then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.sources then
+ node.sources = {}
+ end
+ node.sources[#node.sources+1] = source
+end
+
+local function bindValue(noders, source, id)
+ local value = source.value
+ if not value then
+ return
+ end
+ local valueID = getID(value)
+ if not valueID then
+ return
+ end
+ if source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ source = source.node
+ end
+ if source.bindDocs and value.type ~= 'table' then
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.type' then
+ return
+ end
+ end
+ end
+ -- x = y : x -> y
+ pushForward(noders, id, valueID, 'set')
+ -- 参数/call禁止反向查找赋值
+ local valueType = valueID:match '^.-:'
+ if valueType ~= 'p:'
+ and valueType ~= 's:'
+ and valueType ~= 'c:' then
+ pushBackward(noders, valueID, id, 'set')
+ end
+end
+
+local function compileCall(noders, call, sourceID, returnIndex)
+ if not sourceID then
+ return
+ end
+ local node = call.node
+ local nodeID = getID(node)
+ if not nodeID then
+ return
+ end
+ local callID = getID(call)
+ if not callID then
+ return
+ end
+ -- 将setmetatable映射到 param1 以及 param2.__index 上
+ if node.special == 'setmetatable' then
+ local tblID = getID(call.args and call.args[1])
+ local metaID = getID(call.args and call.args[2])
+ local indexID
+ if metaID then
+ indexID = ('%s%s%q'):format(
+ metaID,
+ SPLIT_CHAR,
+ '__index'
+ )
+ end
+ pushForward(noders, sourceID, tblID)
+ pushForward(noders, sourceID, indexID)
+ pushBackward(noders, tblID, sourceID)
+ return
+ --pushBackward(noders, indexID, callID)
+ end
+ if node.special == 'require' then
+ local arg1 = call.args and call.args[1]
+ if arg1 and arg1.type == 'string' then
+ getNode(noders, sourceID).require = arg1[1]
+ end
+ return
+ end
+ if node.special == 'pcall'
+ or node.special == 'xpcall' then
+ local index = returnIndex - 1
+ if index <= 0 then
+ return
+ end
+ local funcID = call.args and getID(call.args[1])
+ if not funcID then
+ return
+ end
+ local pfuncXID = ('%s%s%s'):format(
+ funcID,
+ RETURN_INDEX,
+ index
+ )
+ pushForward(noders, sourceID, pfuncXID)
+ --pushBackward(noders, funcXID, id)
+ return
+ end
+ local funcXID = ('%s%s%s'):format(
+ nodeID,
+ RETURN_INDEX,
+ returnIndex
+ )
+ getNode(noders, sourceID).call = call
+ pushForward(noders, sourceID, funcXID)
+end
+
+---@param uri uri
+---@param noders noders
+---@param source parser.guide.object
+---@return parser.guide.object[]
+function m.compileNode(uri, noders, source)
+ local id = getID(source)
+ bindValue(noders, source, id)
+ if source.special == 'setmetatable'
+ or source.special == 'require'
+ or source.special == 'dofile'
+ or source.special == 'loadfile'
+ or source.special == 'rawset'
+ or source.special == 'rawget' then
+ local node = getNode(noders, id)
+ node.skip = true
+ end
+ -- self -> mt:xx
+ if source.type == 'local' and source[1] == 'self' then
+ local func = guide.getParentFunction(source)
+ if func.isGeneric then
+ return
+ end
+ if source.parent.type ~= 'funcargs' then
+ return
+ end
+ local setmethod = func.parent
+ -- guess `self`
+ if setmethod and ( setmethod.type == 'setmethod'
+ or setmethod.type == 'setfield'
+ or setmethod.type == 'setindex') then
+ pushForward(noders, id, getID(setmethod.node), 'method')
+ --pushBackward(noders, getID(setmethod.node), id, 'method')
+ end
+ end
+ -- 分解 @type
+ if source.type == 'doc.type' then
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
+ end
+ end
+ for _, enumUnit in ipairs(source.enums) do
+ pushForward(noders, id, getID(enumUnit))
+ end
+ for _, resumeUnit in ipairs(source.resumes) do
+ pushForward(noders, id, getID(resumeUnit))
+ end
+ for _, typeUnit in ipairs(source.types) do
+ local unitID = getID(typeUnit)
+ pushForward(noders, id, unitID)
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushBackward(noders, unitID, getID(src))
+ end
+ end
+ end
+ end
+ -- 分解 @alias
+ if source.type == 'doc.alias' then
+ pushForward(noders, getID(source.alias), getID(source.extends))
+ end
+ -- 分解 @class
+ if source.type == 'doc.class' then
+ pushForward(noders, id, getID(source.class))
+ pushForward(noders, getID(source.class), id)
+ if source.extends then
+ for _, ext in ipairs(source.extends) do
+ pushBackward(noders, id, getID(ext))
+ end
+ end
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
+ end
+ end
+ for _, field in ipairs(source.fields) do
+ local key = field.field[1]
+ if key then
+ local keyID = ('%s%s%q'):format(
+ id,
+ SPLIT_CHAR,
+ key
+ )
+ pushForward(noders, keyID, getID(field.field))
+ pushForward(noders, getID(field.field), keyID)
+ pushForward(noders, keyID, getID(field.extends))
+ pushBackward(noders, getID(field.extends), keyID)
+ end
+ end
+ end
+ if source.type == 'doc.param' then
+ pushForward(noders, id, getID(source.extends))
+ for _, src in ipairs(source.bindSources) do
+ if src.type == 'local' and src.parent.type == 'in' then
+ pushForward(noders, getID(src), id)
+ end
+ end
+ end
+ if source.type == 'doc.vararg' then
+ pushForward(noders, getID(source), getID(source.vararg))
+ end
+ if source.type == 'doc.see' then
+ local nameID = getID(source.name)
+ local classID = nameID:gsub('^dsn:', 'dn:')
+ pushForward(noders, nameID, classID)
+ if source.field then
+ local fieldID = getID(source.field)
+ local fieldClassID = fieldID:gsub('^dsn:', 'dn:')
+ pushForward(noders, fieldID, fieldClassID)
+ end
+ end
+ if source.type == 'call' then
+ if source.parent.type ~= 'select' then
+ compileCall(noders, source, id, 1)
+ end
+ end
+ if source.type == 'select' then
+ if source.vararg.type == 'call' then
+ local call = source.vararg
+ compileCall(noders, call, id, source.sindex)
+ end
+ if source.vararg.type == 'varargs' then
+ pushForward(noders, id, getID(source.vararg))
+ end
+ end
+ if source.type == 'doc.type.function' then
+ if source.returns then
+ for index, rtn in ipairs(source.returns) do
+ local returnID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ index
+ )
+ pushForward(noders, returnID, getID(rtn))
+ end
+ end
+ -- @type fun(x: T):T 的情况
+ local docType = getDocStateWithoutCrossFunction(source)
+ if docType and docType.type == 'doc.type' then
+ guide.eachSourceType(source, 'doc.type.name', function (typeName)
+ if typeName.typeGeneric then
+ source.isGeneric = true
+ return false
+ end
+ end)
+ end
+ end
+ if source.type == 'doc.type.table' then
+ if source.tkey then
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, getID(source.tkey))
+ end
+ if source.tvalue then
+ local valueID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, valueID, getID(source.tvalue))
+ end
+ end
+ if source.type == 'doc.type.array' then
+ if source.node then
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source.node))
+ end
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, 'dn:integer')
+ end
+ if source.type == 'doc.type.name' then
+ collector.subscribe(uri, id, getNode(noders, id))
+ end
+ if source.type == 'doc.class.name' then
+ collector.subscribe(uri, id, getNode(noders, id))
+ collector.subscribe(uri, 'def:' .. id, getNode(noders, id))
+ collector.subscribe(uri, 'def:dn', getNode(noders, id))
+ end
+ if source.type == 'doc.alias.name' then
+ collector.subscribe(uri, id, getNode(noders, id))
+ collector.subscribe(uri, 'def:' .. id, getNode(noders, id))
+ collector.subscribe(uri, 'def:dn', getNode(noders, id))
+ end
+ if guide.isGlobal(source) then
+ collector.subscribe(uri, id, getNode(noders, id))
+ if guide.isSet(source) then
+ collector.subscribe(uri, 'def:' .. id, getNode(noders, id))
+ collector.subscribe(uri, 'def:g', getNode(noders, id))
+ end
+ end
+ -- 将函数的返回值映射到具体的返回值上
+ if source.type == 'function' then
+ local hasDocReturn = {}
+ -- 检查 luadoc
+ if source.bindDocs then
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ local fullID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ rtn.returnIndex
+ )
+ pushForward(noders, fullID, getID(rtn))
+ hasDocReturn[rtn.returnIndex] = true
+ end
+ end
+ if doc.type == 'doc.param' then
+ local paramName = doc.param[1]
+ if source.docParamMap then
+ local paramIndex = source.docParamMap[paramName]
+ local param = source.args[paramIndex]
+ if param then
+ pushForward(noders, getID(param), getID(doc))
+ param.docParam = doc
+ end
+ end
+ end
+ if doc.type == 'doc.vararg' then
+ for _, param in ipairs(source.args) do
+ if param.type == '...' then
+ pushForward(noders, getID(param), getID(doc))
+ end
+ end
+ end
+ if doc.type == 'doc.generic' then
+ source.isGeneric = true
+ end
+ if doc.type == 'doc.overload' then
+ pushForward(noders, id, getID(doc.overload))
+ end
+ end
+ end
+ -- 检查实体返回值
+ if source.returns then
+ local returns = {}
+ for _, rtn in ipairs(source.returns) do
+ for index, rtnObj in ipairs(rtn) do
+ if not hasDocReturn[index] then
+ if not returns[index] then
+ returns[index] = {}
+ end
+ returns[index][#returns[index]+1] = rtnObj
+ end
+ end
+ end
+ for index, rtnObjs in ipairs(returns) do
+ local returnID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ index
+ )
+ for _, rtnObj in ipairs(rtnObjs) do
+ pushForward(noders, returnID, getID(rtnObj))
+ if rtnObj.type == 'function'
+ or rtnObj.type == 'call' then
+ --pushBackward(noders, getID(rtnObj), returnID)
+ end
+ end
+ end
+ end
+ end
+ if source.type == 'table' then
+ if #source == 1 and source[1].type == 'varargs' then
+ source.array = source[1]
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source[1]))
+ end
+ end
+ if source.type == 'main' then
+ if source.returns then
+ for _, rtn in ipairs(source.returns) do
+ local rtnObj = rtn[1]
+ if rtnObj then
+ pushForward(noders, 'mainreturn', getID(rtnObj))
+ --pushBackward(noders, getID(rtnObj), 'mainreturn')
+ end
+ end
+ end
+ end
+ if source.type == 'generic.closure' then
+ for i, rtn in ipairs(source.returns) do
+ local closureID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ i
+ )
+ local returnID = getID(rtn)
+ pushForward(noders, closureID, returnID)
+ end
+ end
+ if source.type == 'generic.value' then
+ local proto = source.proto
+ local closure = source.closure
+ local upvalues = closure.upvalues
+ if proto.type == 'doc.type.name' then
+ local key = proto[1]
+ if upvalues[key] then
+ for _, paramID in ipairs(upvalues[key]) do
+ pushForward(noders, id, paramID)
+ pushBackward(noders, paramID, id)
+ end
+ end
+ end
+ if proto.type == 'doc.type' then
+ for _, tp in ipairs(source.types) do
+ pushForward(noders, id, getID(tp))
+ pushBackward(noders, getID(tp), id)
+ end
+ end
+ if proto.type == 'doc.type.array' then
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source.node))
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, 'dn:integer')
+ end
+ if proto.type == 'doc.type.table' then
+ if source.tkey then
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, getID(source.tkey))
+ end
+ if source.tvalue then
+ local valueID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, valueID, getID(source.tvalue))
+ end
+ end
+ end
+end
+
+---根据ID来获取所有的node
+---@param root parser.guide.object
+---@param id string
+---@return node?
+function m.getNodeByID(root, id)
+ root = guide.getRoot(root)
+ local noders = root._noders
+ if not noders then
+ return nil
+ end
+ return noders[id]
+end
+
+---根据ID来获取第一个节点的ID
+---@param id string
+---@return string
+function m.getFirstID(id)
+ if FirstIDCache[id] then
+ return FirstIDCache[id] or nil
+ end
+ local firstID, count = id:match(FIRST_REGEX)
+ if count == 0 then
+ FirstIDCache[id] = false
+ return nil
+ end
+ FirstIDCache[id] = firstID
+ return firstID
+end
+
+---根据ID来获取上个节点的ID
+---@param id string
+---@return string
+function m.getLastID(id)
+ if LastIDCache[id] then
+ return LastIDCache[id] or nil
+ end
+ local lastID, count = id:gsub(LAST_REGEX, '')
+ if count == 0 then
+ LastIDCache[id] = false
+ return nil
+ end
+ LastIDCache[id] = lastID
+ return lastID
+end
+
+---测试id是否包含field,如果遇到函数调用则中断
+---@param id string
+---@return boolean
+function m.hasField(id)
+ local firstID = m.getFirstID(id)
+ if firstID == id then
+ return false
+ end
+ local nextChar = id:sub(#firstID + 1, #firstID + 1)
+ if nextChar ~= SPLIT_CHAR then
+ return false
+ end
+ local next2Char = id:sub(#firstID + 2, #firstID + 2)
+ if next2Char == RETURN_INDEX
+ or next2Char == PARAM_INDEX then
+ return false
+ end
+ return true
+end
+
+---把形如 `@file:\\\XXXXX@gv:1|1`拆分成uri与id
+---@param id string
+---@return uri? string
+---@return string id
+function m.getUriAndID(id)
+ local uri, newID = id:match(URI_REGEX)
+ return uri, newID
+end
+
+---获取source的ID
+---@param source parser.guide.object
+---@return string
+function m.getID(source)
+ return getID(source)
+end
+
+---获取source的key
+---@param source parser.guide.object
+---@return string
+function m.getKey(source)
+ return getKey(source)
+end
+
+---清除临时id(用于泛型的临时对象)
+---@param root parser.guide.object
+---@param id string
+function m.removeID(root, id)
+ root = guide.getRoot(root)
+ local noders = root._noders
+ noders[id] = nil
+end
+
+---寻找doc的主体
+---@param doc parser.guide.object
+function m.getDocState(doc)
+ return getDocStateWithoutCrossFunction(doc)
+end
+
+---获取对象的noders
+---@param source parser.guide.object
+---@return noders
+function m.getNoders(source)
+ local root = guide.getRoot(source)
+ if not root._noders then
+ root._noders = {}
+ end
+ return root._noders
+end
+
+---编译整个文件的node
+---@param source parser.guide.object
+---@return table
+function m.compileNodes(source)
+ local root = guide.getRoot(source)
+ local noders = m.getNoders(source)
+ if next(noders) then
+ return noders
+ end
+ local uri = guide.getUri(root)
+ collector.dropUri(uri)
+ log.debug('compileNodes:', guide.getUri(root))
+ guide.eachSource(root, function (src)
+ m.pushSource(noders, src)
+ m.compileNode(uri, noders, src)
+ end)
+ log.debug('compileNodes finish:', guide.getUri(root))
+ return noders
+end
+
+return m
diff --git a/script/core/reference.lua b/script/core/reference.lua
index 7620b09e..6ea79f5f 100644
--- a/script/core/reference.lua
+++ b/script/core/reference.lua
@@ -1,4 +1,5 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
+local guide = require 'parser.guide'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
@@ -6,8 +7,8 @@ local findSource = require 'core.find-source'
local function sortResults(results)
-- 先按照顺序排序
table.sort(results, function (a, b)
- local u1 = guide.getUri(a.target)
- local u2 = guide.getUri(b.target)
+ local u1 = searcher.getUri(a.target)
+ local u2 = searcher.getUri(b.target)
if u1 == u2 then
return a.target.start < b.target.start
else
@@ -19,7 +20,7 @@ local function sortResults(results)
for i = #results, 1, -1 do
local res = results[i].target
local f = res.finish
- local uri = guide.getUri(res)
+ local uri = searcher.getUri(res)
if lf and f > lf and uri == lu then
table.remove(results, i)
else
@@ -52,7 +53,7 @@ local accept = {
}
return function (uri, offset)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
@@ -64,11 +65,23 @@ return function (uri, offset)
local metaSource = vm.isMetaFile(uri)
+ local refs = vm.getRefs(source)
+ local values = {}
+ for _, src in ipairs(refs) do
+ local value = searcher.getObjectValue(src)
+ if value and value ~= src and guide.isLiteral(value) then
+ values[value] = true
+ end
+ end
+
local results = {}
- for _, src in ipairs(vm.getRefs(source, 5)) do
+ for _, src in ipairs(refs) do
if src.dummy then
goto CONTINUE
end
+ if values[src] then
+ goto CONTINUE
+ end
local root = guide.getRoot(src)
if not root then
goto CONTINUE
diff --git a/script/core/rename.lua b/script/core/rename.lua
index da82b0a6..bc85ac14 100644
--- a/script/core/rename.lua
+++ b/script/core/rename.lua
@@ -1,11 +1,11 @@
local files = require 'files'
local vm = require 'vm'
-local guide = require 'core.guide'
local proto = require 'proto'
local define = require 'proto.define'
local util = require 'utility'
local findSource = require 'core.find-source'
-local ws = require 'workspace'
+local guide = require 'parser.guide'
+local noder = require 'core.noder'
local Forcing
@@ -185,7 +185,7 @@ local function renameField(source, newname, callback)
end
callback(source, source.start, source.finish, newname)
elseif parent.type == 'setmethod' then
- local uri = guide.getUri(source)
+ local uri = guide.getUri(source)
local text = files.getText(uri)
local func = parent.value
-- function mt:name () end --> mt['newname'] = function (self) end
@@ -292,14 +292,14 @@ local function ofField(source, newname, callback)
else
node = source.node
end
- for _, src in ipairs(vm.getFields(node, 5)) do
+ for _, src in ipairs(vm.getRefs(node, '*')) do
ofFieldThen(key, src, newname, callback)
end
end
local function ofGlobal(source, newname, callback)
local key = guide.getKeyName(source)
- for _, src in ipairs(vm.getRefs(source, 0)) do
+ for _, src in ipairs(vm.getRefs(source)) do
ofFieldThen(key, src, newname, callback)
end
end
@@ -308,24 +308,27 @@ local function ofLabel(source, newname, callback)
if not isValidName(newname) and not askForcing(newname)then
return false
end
- for _, src in ipairs(vm.getRefs(source, 0)) do
+ for _, src in ipairs(vm.getRefs(source)) do
callback(src, src.start, src.finish, newname)
end
end
local function ofDocTypeName(source, newname, callback)
- for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ local oldname = source[1]
+ for _, doc in ipairs(vm.getRefs(source)) do
if doc.type == 'doc.class.name'
or doc.type == 'doc.type.name'
or doc.type == 'doc.alias.name' then
- callback(doc, doc.start, doc.finish, newname)
+ if oldname == doc[1] then
+ callback(doc, doc.start, doc.finish, newname)
+ end
end
end
end
local function ofDocParamName(source, newname, callback)
callback(source, source.start, source.finish, newname)
- local doc = guide.getDocState(source)
+ local doc = noder.getDocState(source)
if doc.bindSources then
for _, src in ipairs(doc.bindSources) do
if src.type == 'local'
@@ -440,7 +443,7 @@ function m.rename(uri, pos, newname)
if not newname then
return nil
end
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
@@ -489,7 +492,7 @@ function m.rename(uri, pos, newname)
end
function m.prepareRename(uri, pos)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
new file mode 100644
index 00000000..5a417765
--- /dev/null
+++ b/script/core/searcher.lua
@@ -0,0 +1,838 @@
+local noder = require 'core.noder'
+local guide = require 'parser.guide'
+local files = require 'files'
+local generic = require 'core.generic'
+local ws = require 'workspace'
+local vm = require 'vm.vm'
+local await = require 'await'
+local collector = require 'core.collector'
+
+local NONE = {'NONE'}
+local LAST = {'LAST'}
+
+local ignoredIDs = {
+ ['dn:unknown'] = true,
+ ['dn:nil'] = true,
+ ['dn:any'] = true,
+ ['dn:boolean'] = true,
+ ['dn:string'] = true,
+ ['dn:table'] = true,
+ ['dn:number'] = true,
+ ['dn:integer'] = true,
+ ['dn:userdata'] = true,
+ ['dn:lightuserdata'] = true,
+ ['dn:function'] = true,
+ ['dn:thread'] = true,
+}
+
+local m = {}
+
+---@alias guide.searchmode '"ref"'|'"def"'
+
+---添加结果
+---@param status guide.status
+---@param mode guide.searchmode
+---@param source parser.guide.object
+---@param force boolean
+function m.pushResult(status, mode, source, force)
+ if not source then
+ return
+ end
+ local results = status.results
+ local mark = status.mark
+ if mark[source] then
+ return
+ end
+ mark[source] = true
+ if force then
+ results[#results+1] = source
+ return
+ end
+ local parent = source.parent
+ if mode == 'def' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal'
+ or source.type == 'label'
+ or source.type == 'setfield'
+ or source.type == 'setmethod'
+ or source.type == 'setindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset' then
+ results[#results+1] = source
+ end
+ end
+ if parent.type == 'return' then
+ if noder.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
+ elseif mode == 'ref' or mode == 'field' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal'
+ or source.type == 'label'
+ or source.type == 'goto'
+ or source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod'
+ or source.type == 'setindex'
+ or source.type == 'getindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number'
+ or source.type == 'nil'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.type.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset'
+ or source.node.special == 'rawget' then
+ results[#results+1] = source
+ end
+ end
+ if parent.type == 'return' then
+ if noder.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
+ end
+end
+
+---获取uri
+---@param obj parser.guide.object
+---@return uri
+function m.getUri(obj)
+ if obj.uri then
+ return obj.uri
+ end
+ local root = guide.getRoot(obj)
+ if root then
+ return root.uri
+ end
+ return ''
+end
+
+---@param obj parser.guide.object
+---@return parser.guide.object?
+function m.getObjectValue(obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return nil
+ end
+ end
+ if obj.type == 'boolean'
+ or obj.type == 'number'
+ or obj.type == 'integer'
+ or obj.type == 'string' then
+ return obj
+ end
+ if obj.value then
+ return obj.value
+ end
+ if obj.type == 'field'
+ or obj.type == 'method' then
+ return obj.parent and obj.parent.value
+ end
+ if obj.type == 'call' then
+ if obj.node.special == 'rawset' then
+ return obj.args and obj.args[3]
+ else
+ return obj
+ end
+ end
+ if obj.type == 'select' then
+ return obj
+ end
+ return nil
+end
+
+local function crossSearch(status, uri, expect, mode, sourceUri)
+ if status.lock[uri] then
+ return
+ end
+ status.lock[uri] = true
+ await.delay()
+ if TRACE then
+ log.debug('crossSearch', uri, expect)
+ end
+ if FOOTPRINT then
+ status.footprint[#status.footprint+1] = ('cross search:%s %s'):format(uri, expect)
+ end
+ m.searchRefsByID(status, uri, expect, mode)
+ status.lock[uri] = nil
+ if FOOTPRINT then
+ status.footprint[#status.footprint+1] = ('cross search finish, back to: %s'):format(sourceUri)
+ end
+end
+
+local function checkCache(status, uri, expect, mode)
+ local cache = vm.getCache('search:' .. mode)
+ local fileCache = cache[uri]
+ if not fileCache then
+ fileCache = {}
+ cache[uri] = fileCache
+ end
+ if fileCache[expect] then
+ for _, res in ipairs(fileCache[expect]) do
+ m.pushResult(status, mode, res, true)
+ end
+ return true
+ end
+ fileCache[expect] = status.results
+ return false
+end
+
+function m.searchRefsByID(status, uri, expect, mode)
+ local ast = files.getState(uri)
+ if not ast then
+ return
+ end
+ local root = ast.ast
+ local searchStep
+ noder.compileNodes(root)
+
+ status.id = expect
+
+ local callStack = status.callStack
+
+ local mark = {}
+
+ local function search(id, field)
+ local firstID = noder.getFirstID(id)
+ if ignoredIDs[firstID] and (field or firstID ~= id) then
+ return
+ end
+ local cmark = mark[id]
+ if not cmark then
+ cmark = {}
+ mark[id] = cmark
+ end
+ if cmark[field or NONE] then
+ return
+ end
+ if TRACE then
+ log.debug('search:', id, field)
+ end
+ if FOOTPRINT then
+ if field then
+ status.footprint[#status.footprint+1] = 'search\t' .. id .. '\t' .. field
+ else
+ status.footprint[#status.footprint+1] = 'search\t' .. id
+ end
+ end
+ cmark[field or NONE] = true
+ searchStep(id, field)
+ if TRACE then
+ log.debug('pop:', id, field)
+ end
+ if FOOTPRINT then
+ if field then
+ status.footprint[#status.footprint+1] = 'pop\t' .. id .. '\t' .. field
+ else
+ status.footprint[#status.footprint+1] = 'pop\t' .. id
+ end
+ end
+ end
+
+ local function checkLastID(id, field)
+ local cmark = mark[id]
+ if not cmark then
+ cmark = {}
+ mark[id] = cmark
+ end
+ if cmark[LAST] then
+ return
+ end
+ local lastID = noder.getLastID(id)
+ if not lastID then
+ return
+ end
+ local newField = id:sub(#lastID + 1)
+ if field then
+ newField = newField .. field
+ end
+ cmark[LAST] = true
+ search(lastID, newField)
+ return lastID
+ end
+
+ local function searchID(id, field)
+ if not id then
+ return
+ end
+ if field then
+ id = id .. field
+ end
+ search(id, nil)
+ end
+
+ local function isCallID(field)
+ if not field then
+ return false
+ end
+ if field:sub(1, 2) == noder.RETURN_INDEX then
+ return true
+ end
+ return false
+ end
+
+ local function findLastCall()
+ for i = #callStack, 1, -1 do
+ local call = callStack[i]
+ if call then
+ -- 标记此处的call失效,等待在堆栈平衡时弹出
+ callStack[i] = false
+ return call
+ end
+ end
+ return nil
+ end
+
+ local genericCallArgs = {}
+ local closureCache = {}
+ local function checkGeneric(source, field)
+ if not source.isGeneric then
+ return
+ end
+ if not isCallID(field) then
+ return
+ end
+ local call = findLastCall()
+ if not call then
+ return
+ end
+
+ if call.args then
+ for _, arg in ipairs(call.args) do
+ genericCallArgs[arg] = true
+ end
+ end
+
+ local cacheID = noder.getID(source) .. noder.getID(call)
+ local closure = closureCache[cacheID]
+ if closure == false then
+ return
+ end
+ if not closure then
+ closure = generic.createClosure(source, call)
+ closureCache[cacheID] = closure or false
+ if not closure then
+ return
+ end
+ end
+ local id = noder.getID(closure)
+ searchID(id, field)
+ end
+
+ local function checkENV(source, field)
+ if not field then
+ return
+ end
+ if source.special ~= '_G' then
+ return
+ end
+ local newID = 'g:' .. field:sub(2)
+ searchID(newID)
+ end
+
+ local function checkThenPushTag(ward, tag)
+ if not tag then
+ return true
+ end
+ local checkTags
+ local pushTags
+ if ward == 'forward' then
+ checkTags = status.btag
+ pushTags = status.ftag
+ else
+ checkTags = status.ftag
+ pushTags = status.btag
+ end
+ if checkTags[tag] and checkTags[tag] > 0 then
+ return false
+ end
+ pushTags[tag] = (pushTags[tag] or 0) + 1
+ return true
+ end
+
+ local function popTag(ward, tag)
+ if not tag then
+ return
+ end
+ local popTags
+ if ward == 'forward' then
+ popTags = status.ftag
+ else
+ popTags = status.btag
+ end
+ popTags[tag] = popTags[tag] - 1
+ end
+
+ local function checkForward(id, node, field)
+ for _, forwardID in ipairs(node.forward) do
+ local tag = node.forward[forwardID]
+ if not checkThenPushTag('forward', tag) then
+ goto CONTINUE
+ end
+ local targetUri, targetID = noder.getUriAndID(forwardID)
+ if targetUri and not files.eq(targetUri, uri) then
+ crossSearch(status, targetUri, targetID .. (field or ''), mode, uri)
+ else
+ searchID(targetID or forwardID, field)
+ end
+ popTag('forward', tag)
+ ::CONTINUE::
+ end
+ end
+
+ local function checkBackward(id, node, field)
+ if mode ~= 'ref' and mode ~= 'field' and not field then
+ return
+ end
+ for _, backwardID in ipairs(node.backward) do
+ local tag = node.backward[backwardID]
+ if not checkThenPushTag('backward', tag) then
+ goto CONTINUE
+ end
+ local targetUri, targetID = noder.getUriAndID(backwardID)
+ if targetUri and not files.eq(targetUri, uri) then
+ crossSearch(status, targetUri, targetID .. (field or ''), mode, uri)
+ else
+ searchID(targetID or backwardID, field)
+ end
+ popTag('backward', tag)
+ ::CONTINUE::
+ end
+ end
+
+ local function checkSpecial(id, node, field)
+ -- Special rule: ('').XX -> stringlib.XX
+ if id == 'str:'
+ or id == 'dn:string' then
+ if field or mode == 'field' then
+ searchID('dn:stringlib', field)
+ end
+ return true
+ end
+ return false
+ end
+
+ local function checkRequire(requireName, field)
+ local tid = 'mainreturn' .. (field or '')
+ local uris = ws.findUrisByRequirePath(requireName)
+ if FOOTPRINT then
+ status.footprint[#status.footprint+1] = ('require %q:\n%s'):format(requireName, table.concat(uris, '\n'))
+ end
+ for _, ruri in ipairs(uris) do
+ if not files.eq(uri, ruri) then
+ crossSearch(status, ruri, tid, mode, uri)
+ end
+ end
+ end
+
+ local function checkGlobal(id, node, field)
+ if status.crossed[id] then
+ return
+ end
+ status.crossed[id] = true
+ --if not checkThenPushTag('forward', 'set') then
+ -- return
+ --end
+ local isCall = field and field:sub(2, 2) == noder.RETURN_INDEX
+ local tid = id .. (field or '')
+ if FOOTPRINT then
+ status.footprint[#status.footprint+1] = ('checkGlobal:%s + %s, isCall: %s'):format(id, field, isCall, tid)
+ end
+ for guri, def in collector.each(id) do
+ if def then
+ crossSearch(status, guri, tid, mode, uri)
+ goto CONTINUE
+ end
+ if isCall then
+ goto CONTINUE
+ end
+ if not field then
+ goto CONTINUE
+ end
+ if mode == 'def' then
+ goto CONTINUE
+ end
+ if not files.eq(uri, guri) then
+ goto CONTINUE
+ end
+ crossSearch(status, guri, tid, mode, uri)
+ ::CONTINUE::
+ end
+ --popTag('forward', 'set')
+ end
+
+ local function checkClass(id, node, field)
+ if status.crossed[id] then
+ return
+ end
+ status.crossed[id] = true
+ local tid = id .. (field or '')
+ for guri in collector.each(id) do
+ if not files.eq(uri, guri) then
+ crossSearch(status, guri, tid, mode, uri)
+ end
+ end
+ end
+
+ local function searchNode(id, node, field)
+ if node.call then
+ callStack[#callStack+1] = node.call
+ end
+ if field == nil and node.sources then
+ for _, source in ipairs(node.sources) do
+ local force = genericCallArgs[source]
+ m.pushResult(status, mode, source, force)
+ end
+ end
+
+ if node.require then
+ checkRequire(node.require, field)
+ return
+ end
+
+ local isSepcial = checkSpecial(id, node, field)
+ if not isSepcial then
+ if node.forward then
+ checkForward(id, node, field)
+ end
+ if node.backward then
+ checkBackward(id, node, field)
+ end
+ end
+
+ if node.sources then
+ checkGeneric(node.sources[1], field)
+ checkENV(node.sources[1], field)
+ end
+
+ --checkMainReturn(id, node, field)
+
+ if node.call then
+ callStack[#callStack] = nil
+ end
+
+ return false
+ end
+
+ local function checkAnyField(id, field)
+ if mode == 'ref' or mode == 'field' then
+ return
+ end
+ local lastID = noder.getLastID(id)
+ if not lastID then
+ return
+ end
+ local originField = id:sub(#lastID + 1)
+ if originField == noder.TABLE_KEY then
+ return
+ end
+ local anyFieldID = lastID .. noder.ANY_FIELD
+ local anyFieldNode = noder.getNodeByID(root, anyFieldID)
+ if anyFieldNode then
+ searchNode(anyFieldID, anyFieldNode, field)
+ end
+ end
+
+ local stepCount = 0
+ function searchStep(id, field)
+ stepCount = stepCount + 1
+ status.count = status.count + 1
+ if stepCount > 1000
+ or status.count > 10000 then
+ if TEST then
+ if FOOTPRINT then
+ log.debug(table.concat(status.footprint, '\n'))
+ end
+ error('too large!')
+ else
+ log.warn('too large!')
+ if FOOTPRINT then
+ log.debug(table.concat(status.footprint, '\n'))
+ end
+ return
+ end
+ end
+ local node = noder.getNodeByID(root, id)
+ if node then
+ searchNode(id, node, field)
+ if node.skip and field then
+ return
+ end
+ end
+ checkGlobal(id, node, field)
+ checkClass(id, node, field)
+ checkLastID(id, field)
+ checkAnyField(id, field)
+ end
+
+ search(expect)
+
+ --清除来自泛型的临时对象
+ for _, closure in pairs(closureCache) do
+ noder.removeID(root, noder.getID(closure))
+ if closure then
+ for _, value in ipairs(closure.values) do
+ noder.removeID(root, noder.getID(value))
+ end
+ end
+ end
+end
+
+local function prepareSearch(source)
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ local root = guide.getRoot(source)
+ noder.compileNodes(root)
+ local uri = guide.getUri(source)
+ local id = noder.getID(source)
+ return uri, id
+end
+
+local function getField(status, source, mode)
+ if source.type == 'table' then
+ for _, field in ipairs(source) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ m.pushResult(status, mode, field)
+ end
+ end
+ return
+ end
+ if source.type == 'doc.class.name' then
+ local class = source.parent
+ for _, field in ipairs(class.fields) do
+ m.pushResult(status, mode, field.field)
+ end
+ return
+ end
+ local field = source.next
+ if field then
+ if field.type == 'getmethod'
+ or field.type == 'setmethod'
+ or field.type == 'getfield'
+ or field.type == 'setfield'
+ or field.type == 'getindex'
+ or field.type == 'setindex' then
+ m.pushResult(status, mode, field)
+ end
+ return
+ end
+end
+
+local function searchAllGlobalByUri(status, mode, uri, fullID)
+ local ast = files.getState(uri)
+ if not ast then
+ return
+ end
+ local root = ast.ast
+ noder.compileNodes(root)
+ local noders = noder.getNoders(root)
+ if fullID then
+ for id, node in pairs(noders) do
+ if node.sources
+ and id == fullID then
+ for _, source in ipairs(node.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ end
+ else
+ for id, node in pairs(noders) do
+ if node.sources
+ and id:sub(1, 2) == 'g:'
+ and not id:find(noder.SPLIT_CHAR) then
+ for _, source in ipairs(node.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ end
+ end
+end
+
+local function searchAllGlobals(status, mode)
+ for uri in files.eachFile() do
+ searchAllGlobalByUri(status, mode, uri)
+ end
+end
+
+---查找全局变量
+---@param uri uri
+---@param mode guide.searchmode
+---@param name string
+---@return parser.guide.object[]
+function m.findGlobals(uri, mode, name)
+ local status = m.status(mode)
+
+ if name then
+ local fullID = ('g:%q'):format(name)
+ searchAllGlobalByUri(status, mode, uri, fullID)
+ else
+ searchAllGlobalByUri(status, mode, uri)
+ end
+
+ return status.results
+end
+
+---搜索对象的引用
+---@param status guide.status
+---@param source parser.guide.object
+---@param mode guide.searchmode
+function m.searchRefs(status, source, mode)
+ local uri, id = prepareSearch(source)
+ if not id then
+ return
+ end
+
+ if checkCache(status, uri, id, mode) then
+ return
+ end
+
+ if TRACE then
+ log.debug('searchRefs:', id)
+ end
+ m.searchRefsByID(status, uri, id, mode)
+end
+
+---搜索对象的field
+---@param status guide.status
+---@param source parser.guide.object
+---@param mode guide.searchmode
+---@param field string
+function m.searchFields(status, source, mode, field)
+ local uri, id = prepareSearch(source)
+ if not id then
+ return
+ end
+ if TRACE then
+ log.debug('searchFields:', id, field)
+ end
+ if field == '*' then
+ if source.special == '_G' then
+ if checkCache(status, uri, '*', mode) then
+ return
+ end
+ searchAllGlobals(status, mode)
+ else
+ if checkCache(status, uri, id .. '*', mode) then
+ return
+ end
+ local newStatus = m.status('field')
+ m.searchRefsByID(newStatus, uri, id, 'field')
+ for _, def in ipairs(newStatus.results) do
+ getField(status, def, mode)
+ end
+ end
+ else
+ if source.special == '_G' then
+ local fullID = ('g:%q'):format(field)
+ if checkCache(status, uri, fullID, mode) then
+ return
+ end
+ m.searchRefsByID(status, uri, fullID, mode)
+ else
+ local fullID = ('%s%s%q'):format(id, noder.SPLIT_CHAR, field)
+ if checkCache(status, uri, fullID, mode) then
+ return
+ end
+ m.searchRefsByID(status, uri, fullID, mode)
+ end
+ end
+end
+
+---@class guide.status
+---搜索结果
+---@field results parser.guide.object[]
+
+---创建搜索状态
+---@param mode guide.searchmode
+---@return guide.status
+function m.status(mode)
+ local status = {
+ callStack = {},
+ crossed = {},
+ lock = {},
+ results = {},
+ mark = {},
+ footprint = {},
+ count = 0,
+ ftag = {},
+ btag = {},
+ cache = vm.getCache('searcher:' .. mode)
+ }
+ return status
+end
+
+--- 请求对象的引用
+---@param obj parser.guide.object
+---@param field? string
+---@return parser.guide.object[]
+function m.requestReference(obj, field)
+ local status = m.status('ref')
+
+ if field then
+ m.searchFields(status, obj, 'ref', field)
+ else
+ m.searchRefs(status, obj, 'ref')
+ end
+
+ return status.results
+end
+
+--- 请求对象的定义
+---@param obj parser.guide.object
+---@param field? string
+---@return parser.guide.object[]
+function m.requestDefinition(obj, field)
+ local status = m.status('def')
+
+ if field then
+ m.searchFields(status, obj, 'def', field)
+ else
+ m.searchRefs(status, obj, 'def')
+ end
+
+ return status.results
+end
+
+return m
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index f8feaa09..f310e3f1 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -1,9 +1,10 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local await = require 'await'
local define = require 'proto.define'
local vm = require 'vm'
local util = require 'utility'
+local guide = require 'parser.guide'
local Care = {}
Care['setglobal'] = function (source, results)
@@ -212,7 +213,7 @@ local function buildTokens(uri, results)
end
return function (uri, start, finish)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local lines = files.getLines(uri)
local text = files.getText(uri)
if not ast then
diff --git a/script/core/signature.lua b/script/core/signature.lua
index eb740784..8de1c374 100644
--- a/script/core/signature.lua
+++ b/script/core/signature.lua
@@ -1,8 +1,9 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local vm = require 'vm'
local hoverLabel = require 'core.hover.label'
local hoverDesc = require 'core.hover.description'
+local guide = require 'parser.guide'
local function findNearCall(uri, ast, pos)
local text = files.getText(uri)
@@ -96,10 +97,10 @@ local function makeSignatures(call, pos)
index = 1
end
local signs = {}
- local defs = vm.getDefs(node, 0)
+ local defs = vm.getDefs(node)
local mark = {}
for _, src in ipairs(defs) do
- src = guide.getObjectValue(src) or src
+ src = searcher.getObjectValue(src) or src
if src.type == 'function'
or src.type == 'doc.type.function' then
if not mark[src] then
@@ -132,7 +133,7 @@ local function skipSpace(text, offset)
end
return function (uri, pos)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return nil
end
diff --git a/script/core/type-formatting.lua b/script/core/type-formatting.lua
index c2290ef3..b01a1999 100644
--- a/script/core/type-formatting.lua
+++ b/script/core/type-formatting.lua
@@ -1,6 +1,6 @@
local files = require 'files'
local lookBackward = require 'core.look-backward'
-local guide = require 'core.guide'
+local guide = require "parser.guide"
local function insertIndentation(uri, offset, edits)
local lines = files.getLines(uri)
@@ -69,7 +69,7 @@ local function checkSplitOneLine(results, uri, offset, ch)
end
return function (uri, offset, ch)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
local text = files.getOriginText(uri)
if not ast or not text then
return nil
diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua
index ae420d32..18ab1eeb 100644
--- a/script/core/workspace-symbol.lua
+++ b/script/core/workspace-symbol.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local matchKey = require 'core.matchkey'
local define = require 'proto.define'
local await = require 'await'
@@ -47,7 +47,7 @@ local function buildSource(uri, source, key, results)
end
local function searchFile(uri, key, results)
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return
end
diff --git a/script/files.lua b/script/files.lua
index 9cc6b549..a653b364 100644
--- a/script/files.lua
+++ b/script/files.lua
@@ -9,7 +9,7 @@ local await = require 'await'
local timer = require 'timer'
local plugin = require 'plugin'
local util = require 'utility'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local smerger = require 'string-merger'
local progress = require "progress"
@@ -34,9 +34,9 @@ m.assocMatcher = nil
m.globalVersion = 0
m.fileCount = 0
m.astCount = 0
-m.linesMap = setmetatable({}, { __mode = 'v' })
-m.originLinesMap = setmetatable({}, { __mode = 'v' })
-m.astMap = setmetatable({}, { __mode = 'v' })
+m.linesMap = {} --setmetatable({}, { __mode = 'v' })
+m.originLinesMap = {} --setmetatable({}, { __mode = 'v' })
+m.astMap = {} --setmetatable({}, { __mode = 'v' })
local uriMap = {}
local function getUriKey(uri)
@@ -345,6 +345,7 @@ function m.getAllUris()
i = i + 1
files[i] = uri
end
+ table.sort(files)
end
return m._pairsCache
end
@@ -377,7 +378,7 @@ function m.eachDll()
return pairs(map)
end
-function m.compileAst(uri, text)
+function m.compileState(uri, text)
local ws = require 'workspace'
if not m.isOpen(uri) and #text >= config.config.workspace.preloadFileSize * 1000 then
if not m.notifyCache['preloadFileSize'] then
@@ -445,8 +446,8 @@ end
--- 获取文件语法树
---@param uri uri
----@return table ast
-function m.getAst(uri)
+---@return table state
+function m.getState(uri)
uri = getUriKey(uri)
if uri ~= '' and not m.isLua(uri) then
return nil
@@ -457,7 +458,7 @@ function m.getAst(uri)
end
local ast = m.astMap[uri]
if not ast then
- ast = m.compileAst(uri, file.text)
+ ast = m.compileState(uri, file.text)
m.astMap[uri] = ast
--await.delay()
end
diff --git a/script/parser/ast.lua b/script/parser/ast.lua
index 45d77631..40b5788e 100644
--- a/script/parser/ast.lua
+++ b/script/parser/ast.lua
@@ -110,7 +110,7 @@ local function getSelect(vararg, index)
start = vararg.start,
finish = vararg.finish,
vararg = vararg,
- index = index,
+ sindex = index,
}
end
@@ -1460,8 +1460,14 @@ local Defs = {
local values
if func then
local call = createCall(exp, func.finish + 1, exp.finish)
+ if #exp == 0 then
+ exp[1] = getSelect(func, 2)
+ exp[2] = getSelect(func, 3)
+ exp[3] = getSelect(func, 4)
+ end
call.node = func
- call.start = func.start
+ call.start = inA
+ call.finish = doB - 1
func.next = call
func.iterator = true
values = { call }
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index a7e0dc1f..21be406d 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -125,6 +125,7 @@ local vmMap = {
vararg.ref = {}
end
vararg.ref[#vararg.ref+1] = obj
+ obj.node = vararg
end
end
end,
@@ -150,8 +151,8 @@ local vmMap = {
local value = obj.value
local localself = {
type = 'local',
- start = 0,
- finish = 0,
+ start = value.start,
+ finish = value.finish,
method = obj,
effect = obj.finish,
tag = 'self',
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 2369e84f..ad07e90e 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -1,34 +1,8 @@
-local util = require 'utility'
local error = error
local type = type
-local next = next
-local tostring = tostring
-local print = print
-local ipairs = ipairs
-local tableInsert = table.insert
-local tableUnpack = table.unpack
-local tableRemove = table.remove
-local tableMove = table.move
-local tableSort = table.sort
-local tableConcat = table.concat
-local mathType = math.type
-local pairs = pairs
-local setmetatable = setmetatable
-local assert = assert
-local select = select
-local osClock = os.clock
-local tonumber = tonumber
-local tointeger = math.tointeger
-local DEVELOP = _G.DEVELOP
-local log = log
-local _G = _G
---@class parser.guide.object
-local function logWarn(...)
- log.warn(...)
-end
-
---@class guide
---@field debugMode boolean
local m = {}
@@ -91,7 +65,7 @@ m.childMap = {
['doc'] = {'#'},
['doc.class'] = {'class', '#extends', 'comment'},
- ['doc.type'] = {'#types', '#enums', 'name', 'comment'},
+ ['doc.type'] = {'#types', '#enums', '#resumes', 'name', 'comment'},
['doc.alias'] = {'alias', 'extends', 'comment'},
['doc.param'] = {'param', 'extends', 'comment'},
['doc.return'] = {'#returns', 'comment'},
@@ -100,9 +74,9 @@ m.childMap = {
['doc.generic.object'] = {'generic', 'extends', 'comment'},
['doc.vararg'] = {'vararg', 'comment'},
['doc.type.array'] = {'node'},
- ['doc.type.table'] = {'node', 'key', 'value', 'comment'},
+ ['doc.type.table'] = {'tkey', 'tvalue', 'comment'},
['doc.type.function'] = {'#args', '#returns', 'comment'},
- ['doc.type.typeliteral'] = {'node'},
+ ['doc.type.literal'] = {'node'},
['doc.type.arg'] = {'extends'},
['doc.overload'] = {'overload', 'comment'},
['doc.see'] = {'name', 'field'},
@@ -123,19 +97,31 @@ m.actionMap = {
['funcargs'] = {'#'},
}
-local TypeSort = {
- ['boolean'] = 1,
- ['string'] = 2,
- ['integer'] = 3,
- ['number'] = 4,
- ['table'] = 5,
- ['function'] = 6,
- ['true'] = 101,
- ['false'] = 102,
- ['nil'] = 999,
-}
+local inf = 1 / 0
+local nan = 0 / 0
+
+local function isInteger(n)
+ if math.type then
+ return math.type(n) == 'integer'
+ else
+ return type(n) == 'number' and n % 1 == 0
+ end
+end
-local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
+local function formatNumber(n)
+ if n == inf
+ or n == -inf
+ or n == nan
+ or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则
+ return ('%q'):format(n)
+ end
+ if isInteger(n) then
+ return tostring(n)
+ end
+ local str = ('%.10f'):format(n)
+ str = str:gsub('%.?0*$', '')
+ return str
+end
--- 是否是字面量
---@param obj parser.guide.object
@@ -182,23 +168,6 @@ function m.getParentFunction(obj)
return nil
end
---- 寻找父的table类型 doc.type.table
----@param obj parser.guide.object
----@return parser.guide.object
-function m.getParentDocTypeTable(obj)
- for _ = 1, 1000 do
- local parent = obj.parent
- if not parent then
- return nil
- end
- if parent.type == 'doc.type.table' then
- return obj
- end
- obj = parent
- end
- error('guide.getParentDocTypeTable overstack')
-end
-
--- 寻找所在区块
---@param obj parser.guide.object
---@return parser.guide.object
@@ -293,10 +262,19 @@ end
---@param obj parser.guide.object
---@return parser.guide.object
function m.getRoot(obj)
+ local source = obj
+ if source._root then
+ return source._root
+ end
for _ = 1, 1000 do
if obj.type == 'main' then
+ source._root = obj
return obj
end
+ if obj._root then
+ source._root = obj._root
+ return source._root
+ end
local parent = obj.parent
if not parent then
return nil
@@ -501,8 +479,8 @@ function m.addChilds(list, obj, map)
for i = 1, #keys do
local key = keys[i]
if key == '#' then
- for i = 1, #obj do
- list[#list+1] = obj[i]
+ for j = 1, #obj do
+ list[#list+1] = obj[j]
end
elseif obj[key] then
list[#list+1] = obj[key]
@@ -510,8 +488,8 @@ function m.addChilds(list, obj, map)
and key:sub(1, 1) == '#' then
key = key:sub(2)
if obj[key] then
- for i = 1, #obj[key] do
- list[#list+1] = obj[key][i]
+ for j = 1, #obj[key] do
+ list[#list+1] = obj[key][j]
end
end
end
@@ -613,9 +591,16 @@ function m.eachSource(ast, callback)
index = index + 1
if not mark[obj] then
mark[obj] = true
- callback(obj)
+ local res = callback(obj)
+ if res == true then
+ goto CONTINUE
+ end
+ if res == false then
+ return
+ end
m.addChilds(list, obj, m.childMap)
end
+ ::CONTINUE::
end
end
@@ -718,4 +703,294 @@ function m.lineData(lines, row)
return lines[row]
end
+function m.isSet(source)
+ local tp = source.type
+ if tp == 'setglobal'
+ or tp == 'local'
+ or tp == 'setlocal'
+ or tp == 'setfield'
+ or tp == 'setmethod'
+ or tp == 'setindex'
+ or tp == 'tablefield'
+ or tp == 'tableindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(source.node)
+ if special == 'rawset' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isGet(source)
+ local tp = source.type
+ if tp == 'getglobal'
+ or tp == 'getlocal'
+ or tp == 'getfield'
+ or tp == 'getmethod'
+ or tp == 'getindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(source.node)
+ if special == 'rawget' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.getSpecial(source)
+ if not source then
+ return nil
+ end
+ return source.special
+end
+
+function m.getKeyNameOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'field'
+ or tp == 'method' then
+ return obj[1]
+ elseif tp == 'string' then
+ local s = obj[1]
+ if s then
+ return s
+ end
+ elseif tp == 'number' then
+ local n = obj[1]
+ if n then
+ return ('%s'):format(formatNumber(obj[1]))
+ end
+ elseif tp == 'boolean' then
+ local b = obj[1]
+ if b then
+ return tostring(b)
+ end
+ end
+end
+
+function m.getKeyName(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return obj[1]
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return obj[1]
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ if obj.field then
+ return obj.field[1]
+ end
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ if obj.method then
+ return obj.method[1]
+ end
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getKeyNameOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method'
+ or tp == 'doc.see.field' then
+ return obj[1]
+ elseif tp == 'doc.class' then
+ return obj.class[1]
+ elseif tp == 'doc.alias' then
+ return obj.alias[1]
+ elseif tp == 'doc.field' then
+ return obj.field[1]
+ elseif tp == 'doc.field.name' then
+ return obj[1]
+ elseif tp == 'dummy' then
+ return obj[1]
+ end
+ return m.getKeyNameOfLiteral(obj)
+end
+
+function m.getKeyTypeOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'field'
+ or tp == 'method' then
+ return 'string'
+ elseif tp == 'string' then
+ return 'string'
+ elseif tp == 'number' then
+ return 'number'
+ elseif tp == 'boolean' then
+ return 'boolean'
+ end
+end
+
+function m.getKeyType(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return 'string'
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return 'local'
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ return 'string'
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ return 'string'
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getKeyTypeOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method'
+ or tp == 'doc.see.field' then
+ return 'string'
+ elseif tp == 'doc.class' then
+ return 'string'
+ elseif tp == 'doc.alias' then
+ return 'string'
+ elseif tp == 'doc.field' then
+ return 'string'
+ elseif tp == 'dummy' then
+ return 'string'
+ end
+ if tp == 'doc.field.name' then
+ return 'string'
+ end
+ return m.getKeyTypeOfLiteral(obj)
+end
+
+--- 测试 a 到 b 的路径(不经过函数,不考虑 goto),
+--- 每个路径是一个 block 。
+---
+--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>`
+---
+--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>`
+---
+--- 否则返回 `false`
+---
+--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。
+---@param a table
+---@param b table
+---@return string|boolean mode
+---@return table pathA?
+---@return table pathB?
+function m.getPath(a, b, sameFunction)
+ --- 首先测试双方在同一个函数内
+ if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
+ return false
+ end
+ local mode
+ local objA
+ local objB
+ if a.finish < b.start then
+ mode = 'before'
+ objA = a
+ objB = b
+ elseif a.start > b.finish then
+ mode = 'after'
+ objA = b
+ objB = a
+ else
+ return 'equal', {}, {}
+ end
+ local pathA = {}
+ local pathB = {}
+ for _ = 1, 1000 do
+ objA = m.getParentBlock(objA)
+ pathA[#pathA+1] = objA
+ if (not sameFunction and objA.type == 'function') or objA.type == 'main' then
+ break
+ end
+ end
+ for _ = 1, 1000 do
+ objB = m.getParentBlock(objB)
+ pathB[#pathB+1] = objB
+ if (not sameFunction and objA.type == 'function') or objB.type == 'main' then
+ break
+ end
+ end
+ -- pathA: {1, 2, 3, 4, 5}
+ -- pathB: {5, 6, 2, 3}
+ local top = #pathB
+ local start
+ for i = #pathA, 1, -1 do
+ local currentBlock = pathA[i]
+ if currentBlock == pathB[top] then
+ start = i
+ break
+ end
+ end
+ if not start then
+ return nil
+ end
+ -- pathA: { 1, 2, 3}
+ -- pathB: {5, 6, 2, 3}
+ local extra = 0
+ local align = top - start
+ for i = start, 1, -1 do
+ local currentA = pathA[i]
+ local currentB = pathB[i+align]
+ if currentA ~= currentB then
+ extra = i
+ break
+ end
+ end
+ -- pathA: {1}
+ local resultA = {}
+ for i = extra, 1, -1 do
+ resultA[#resultA+1] = pathA[i]
+ end
+ -- pathB: {5, 6}
+ local resultB = {}
+ for i = extra + align, 1, -1 do
+ resultB[#resultB+1] = pathB[i]
+ end
+ return mode, resultA, resultB
+end
+
+---是否是全局变量(包括 _G.XXX 形式)
+---@param source parser.guide.object
+---@return boolean
+function m.isGlobal(source)
+ if source._isGlobal ~= nil then
+ return source._isGlobal
+ end
+ if source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ if source.node and source.node.tag == '_ENV' then
+ source._isGlobal = true
+ return true
+ end
+ end
+ if source.type == 'field' then
+ source = source.parent
+ end
+ if source.special == '_G' then
+ source._isGlobal = true
+ return true
+ end
+ source._isGlobal = false
+ return false
+end
+
return m
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index ae8e3f34..335c8f24 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1,7 +1,7 @@
local m = require 'lpeglabel'
local re = require 'parser.relabel'
local lines = require 'parser.lines'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local grammar = require 'parser.grammar'
local TokenTypes, TokenStarts, TokenFinishs, TokenContents
@@ -194,6 +194,7 @@ local function parseClass(parent)
local result = {
type = 'doc.class',
parent = parent,
+ fields = {},
}
result.class = parseName('doc.class.name', result)
if not result.class then
@@ -300,8 +301,8 @@ local function parseTypeUnitTable(parent, node)
node.parent = result;
result.finish = getFinish()
- result.key = key
- result.value = value
+ result.tkey = key
+ result.tvalue = value
return result
end
@@ -425,9 +426,10 @@ local function parseTypeUnit(parent, content)
return result
end
-local function parseResume()
+local function parseResume(parent)
local result = {
- type = 'doc.resume'
+ type = 'doc.resume',
+ parent = parent,
}
if checkToken('symbol', '>', 1) then
@@ -456,7 +458,6 @@ local function parseResume()
return result
end
-local LastType
function parseType(parent)
local result = {
type = 'doc.type',
@@ -484,13 +485,7 @@ function parseType(parent)
break
end
-- TypeLiteral,指代类型的字面值。比如,对于类 Cat 来说,它的 TypeLiteral 是 "Cat"
- typeLiteral = {
- type = 'doc.type.typeliteral',
- parent = result,
- start = getStart(),
- finish = nil,
- node = nil,
- }
+ typeLiteral = true
end
if tp == 'name' then
@@ -501,10 +496,7 @@ function parseType(parent)
end
if typeLiteral then
nextToken()
- typeLiteral.finish = getFinish()
- typeLiteral.node = typeUnit
- typeUnit.parent = typeLiteral
- typeUnit = typeLiteral
+ typeUnit.literal = true
end
result.types[#result.types+1] = typeUnit
if not result.start then
@@ -566,7 +558,7 @@ function parseType(parent)
row = row + i + 1
local finishPos = nextComm.text:find('#', 3) or #nextComm.text
parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1)
- local resume = parseResume()
+ local resume = parseResume(result)
if resume then
if comments then
resume.comment = table.concat(comments, '\n')
@@ -1122,17 +1114,25 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
end
local src = sources[index]
if src.start < start then
- left = index
+ left = index + 1
else
right = index
end
end
- for i = index - 1, max do
+
+ -- 从前往后进行绑定
+ for i = index, max do
local src = sources[i]
if src then
if src.start > finish then
break
end
+ -- 遇到table后中断,处理以下情况:
+ -- ---@type AAA
+ -- local t = {x = 1, y = 2}
+ if src.type == 'table' then
+ break
+ end
if src.start >= start then
src.bindDocs = binded
bindSources[#bindSources+1] = src
@@ -1152,21 +1152,22 @@ local function bindParamAndReturnIndex(binded)
if not func then
return
end
- if not func.args then
- return
- end
- local paramIndex = 0
- local paramMap = {}
- for _, param in ipairs(func.args) do
- paramIndex = paramIndex + 1
- if param[1] then
- paramMap[param[1]] = paramIndex
+ local paramMap
+ if func.args then
+ local paramIndex = 0
+ paramMap = {}
+ for _, param in ipairs(func.args) do
+ paramIndex = paramIndex + 1
+ if param[1] then
+ paramMap[param[1]] = paramIndex
+ end
end
+ func.docParamMap = paramMap
end
local returnIndex = 0
for _, doc in ipairs(binded) do
if doc.type == 'doc.param' then
- if doc.extends then
+ if paramMap and doc.extends then
doc.extends.paramIndex = paramMap[doc.param[1]]
end
elseif doc.type == 'doc.return' then
@@ -1178,6 +1179,24 @@ local function bindParamAndReturnIndex(binded)
end
end
+local function bindClassAndFields(binded)
+ local class
+ for _, doc in ipairs(binded) do
+ if doc.type == 'doc.class' then
+ -- 多个class连续写在一起,只有最后一个class可以绑定source
+ if class then
+ class.bindSources = nil
+ end
+ class = doc
+ elseif doc.type == 'doc.field' then
+ if class then
+ class.fields[#class.fields+1] = doc
+ doc.class = class
+ end
+ end
+ end
+end
+
local function bindDoc(sources, lns, binded)
if not binded then
return
@@ -1200,6 +1219,7 @@ local function bindDoc(sources, lns, binded)
bindDocsBetween(sources, binded, bindSources, nstart, nfinish)
end
bindParamAndReturnIndex(binded)
+ bindClassAndFields(binded)
end
local function bindDocs(state)
@@ -1214,6 +1234,7 @@ local function bindDocs(state)
or src.type == 'tablefield'
or src.type == 'tableindex'
or src.type == 'function'
+ or src.type == 'table'
or src.type == '...' then
sources[#sources+1] = src
end
diff --git a/script/proto/define.lua b/script/proto/define.lua
index abfaa9b0..f2ee7ab5 100644
--- a/script/proto/define.lua
+++ b/script/proto/define.lua
@@ -103,7 +103,7 @@ m.DiagnosticDefaultNeededFileStatus = {
['unused-local'] = 'Opened',
['unused-function'] = 'Opened',
['undefined-global'] = 'Any',
- ['undefined-field'] = 'Opened',
+ ['undefined-field'] = 'Any',
['global-in-nil-env'] = 'Any',
['unused-label'] = 'Opened',
['unused-vararg'] = 'Opened',
@@ -124,7 +124,7 @@ m.DiagnosticDefaultNeededFileStatus = {
['close-non-object'] = 'Any',
['count-down-loop'] = 'Any',
['no-implicit-any'] = 'None',
- ['deprecated'] = 'None',
+ ['deprecated'] = 'Opened',
['duplicate-doc-class'] = 'Any',
['undefined-doc-class'] = 'Any',
@@ -284,4 +284,19 @@ m.BuiltIn = {
['utf8'] = 'default',
}
+m.BuiltinClass = {
+ ['unknown'] = true,
+ ['any'] = true,
+ ['nil'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+ ['integer'] = true,
+ ['thread'] = true,
+ ['table'] = true,
+ ['string'] = true,
+ ['userdata'] = true,
+ ['lightuserdata'] = true,
+ ['Function'] = true,
+}
+
return m
diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua
index 883ae68c..4a207115 100644
--- a/script/provider/diagnostic.lua
+++ b/script/provider/diagnostic.lua
@@ -190,7 +190,7 @@ function m.doDiagnostic(uri)
await.delay()
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
m.clear(uri)
return
diff --git a/script/service/service.lua b/script/service/service.lua
index 44fd9aa4..247cb5b5 100644
--- a/script/service/service.lua
+++ b/script/service/service.lua
@@ -1,3 +1,4 @@
+---@diagnostic disable: deprecated
local pub = require 'pub'
local thread = require 'bee.thread'
local await = require 'await'
diff --git a/script/utility.lua b/script/utility.lua
index 04597a39..16c5e0c9 100644
--- a/script/utility.lua
+++ b/script/utility.lua
@@ -317,12 +317,12 @@ end
--- 排序后遍历
---@param t table
-function m.sortPairs(t)
+function m.sortPairs(t, sorter)
local keys = {}
for k in pairs(t) do
keys[#keys+1] = k
end
- tableSort(keys)
+ tableSort(keys, sorter)
local i = 0
return function ()
i = i + 1
diff --git a/script/vm/eachDef.lua b/script/vm/eachDef.lua
index d72c8f01..6f7af295 100644
--- a/script/vm/eachDef.lua
+++ b/script/vm/eachDef.lua
@@ -1,49 +1,7 @@
---@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local files = require 'files'
-local util = require 'utility'
-local await = require 'await'
-local config = require 'config'
+local vm = require 'vm.vm'
+local searcher = require 'core.searcher'
-local function getDefs(source, deep)
- local results = {}
- local lock = vm.lock('eachDef', source)
- if not lock then
- return results
- end
-
- await.delay()
-
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- local clock = os.clock()
- local myResults, count = guide.requestDefinition(source, vm.interface, deep)
- if DEVELOP and os.clock() - clock > 0.1 then
- log.warn('requestDefinition', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
- end
- vm.mergeResults(results, myResults)
-
- lock()
-
- return results
-end
-
-function vm.getDefs(source, deep)
- deep = deep or -999
- if guide.isGlobal(source) then
- local key = guide.getKeyName(source)
- if not key then
- return {}
- end
- return vm.getGlobalSets(key)
- else
- local cache = vm.getCache('eachDef')[source]
- if not cache or cache.deep < deep then
- cache = getDefs(source, deep)
- cache.deep = deep
- vm.getCache('eachDef')[source] = cache
- end
- return cache
- end
+function vm.getDefs(source, field)
+ return searcher.requestDefinition(source, field)
end
diff --git a/script/vm/eachField.lua b/script/vm/eachField.lua
deleted file mode 100644
index 59f35f0c..00000000
--- a/script/vm/eachField.lua
+++ /dev/null
@@ -1,109 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local await = require 'await'
-local config = require 'config'
-
-local function getFields(source, deep, filterKey)
- local unlock = vm.lock('eachField', source)
- if not unlock then
- return {}
- end
-
- while source.type == 'paren' do
- source = source.exp
- if not source then
- return {}
- end
- end
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- await.delay()
- local results = guide.requestFields(source, vm.interface, deep, filterKey)
-
- unlock()
- return results
-end
-
-local function getDefFields(source, deep, filterKey)
- local unlock = vm.lock('eachDefField', source)
- if not unlock then
- return {}
- end
-
- while source.type == 'paren' do
- source = source.exp
- if not source then
- return {}
- end
- end
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- await.delay()
- local results = guide.requestDefFields(source, vm.interface, deep, filterKey)
-
- unlock()
- return results
-end
-
-local function getFieldsBySource(source, deep, filterKey)
- deep = deep or -999
- local cache = vm.getCache('eachField')[source]
- if not cache or cache.deep < deep then
- cache = getFields(source, deep, filterKey)
- cache.deep = deep
- if not filterKey then
- vm.getCache('eachField')[source] = cache
- end
- end
- return cache
-end
-
-local function getDefFieldsBySource(source, deep, filterKey)
- deep = deep or -999
- local cache = vm.getCache('eachDefField')[source]
- if not cache or cache.deep < deep then
- cache = getDefFields(source, deep, filterKey)
- cache.deep = deep
- if not filterKey then
- vm.getCache('eachDefField')[source] = cache
- end
- end
- return cache
-end
-
-function vm.getFields(source, deep)
- if source.special == '_G' then
- return vm.getGlobals '*'
- end
- if guide.isGlobal(source) then
- local name = guide.getKeyName(source)
- if not name then
- return {}
- end
- local cache = vm.getCache('eachFieldOfGlobal')[name]
- or getFieldsBySource(source, deep)
- vm.getCache('eachFieldOfGlobal')[name] = cache
- return cache
- else
- return getFieldsBySource(source, deep)
- end
-end
-
-function vm.getDefFields(source, deep)
- if source.special == '_G' then
- return vm.getGlobalSets '*'
- end
- if guide.isGlobal(source) then
- local name = guide.getKeyName(source)
- if not name then
- return {}
- end
- local cache = vm.getCache('eachDefFieldOfGlobal')[name]
- or getDefFieldsBySource(source, deep)
- vm.getCache('eachDefFieldOfGlobal')[name] = cache
- return cache
- else
- return getDefFieldsBySource(source, deep)
- end
-end
diff --git a/script/vm/eachRef.lua b/script/vm/eachRef.lua
index 9d0f061c..5aca198e 100644
--- a/script/vm/eachRef.lua
+++ b/script/vm/eachRef.lua
@@ -1,48 +1,7 @@
---@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local util = require 'utility'
-local await = require 'await'
-local config = require 'config'
+local vm = require 'vm.vm'
+local searcher = require 'core.searcher'
-local function getRefs(source, deep)
- local results = {}
- local lock = vm.lock('eachRef', source)
- if not lock then
- return results
- end
-
- await.delay()
-
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- local clock = os.clock()
- local myResults, count = guide.requestReference(source, vm.interface, deep)
- if DEVELOP and os.clock() - clock > 0.1 then
- log.warn('requestReference', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
- end
- vm.mergeResults(results, myResults)
-
- lock()
-
- return results
-end
-
-function vm.getRefs(source, deep)
- deep = deep or -999
- if guide.isGlobal(source) then
- local key = guide.getKeyName(source)
- if not key then
- return {}
- end
- return vm.getGlobals(key)
- else
- local cache = vm.getCache('eachRef')[source]
- if not cache or cache.deep < deep then
- cache = getRefs(source, deep)
- cache.deep = deep
- vm.getCache('eachRef')[source] = cache
- end
- return cache
- end
+function vm.getRefs(source, field)
+ return searcher.requestReference(source, field)
end
diff --git a/script/vm/getClass.lua b/script/vm/getClass.lua
deleted file mode 100644
index 5c68e0bb..00000000
--- a/script/vm/getClass.lua
+++ /dev/null
@@ -1,64 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-
-local function lookUpDocClass(source)
- local infers = vm.getInfers(source, 0)
- for _, infer in ipairs(infers) do
- if infer.source.type == 'doc.class'
- or infer.source.type == 'doc.type' then
- return guide.viewInferType(infers)
- end
- end
- return nil
-end
-
-local function getClass(source, classes, depth, deep)
- local docClass = lookUpDocClass(source)
- if docClass then
- classes[docClass] = true
- return
- end
- if depth > 3 then
- return
- end
- local value = guide.getObjectValue(source) or source
- if not deep then
- if value and value.type == 'string' then
- classes[value[1]] = true
- end
- else
- for _, src in ipairs(vm.getDefFields(value)) do
- local key = vm.getKeyName(src)
- if not key then
- goto CONTINUE
- end
- local lkey = key:lower()
- if lkey == 'type'
- or lkey == '__name'
- or lkey == 'name'
- or lkey == 'class' then
- local value = guide.getObjectValue(src)
- if value and value.type == 'string' then
- classes[value[1]] = true
- end
- end
- ::CONTINUE::
- end
- end
- if next(classes) then
- return
- end
- vm.eachMeta(source, function (mt)
- getClass(mt, classes, depth + 1, deep)
- end)
-end
-
-function vm.getClass(source, deep)
- local classes = {}
- getClass(source, classes, 1, deep)
- if not next(classes) then
- return nil
- end
- return guide.mergeTypes(classes)
-end
diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua
index cfa9326f..16b82278 100644
--- a/script/vm/getDocs.lua
+++ b/script/vm/getDocs.lua
@@ -1,152 +1,65 @@
-local files = require 'files'
-local util = require 'utility'
-local guide = require 'core.guide'
----@type vm
-local vm = require 'vm.vm'
-local config = require 'config'
+local files = require 'files'
+local guide = require 'parser.guide'
+local vm = require 'vm.vm'
+local config = require 'config'
+local collector = require 'core.collector'
+local define = require 'proto.define'
+local noder = require 'core.noder'
-local function getTypesOfFile(uri)
- local types = {}
- local ast = files.getAst(uri)
- if not ast or not ast.ast.docs then
- return types
+---获取class与alias
+---@param name? string
+---@return parser.guide.object[]
+function vm.getDocDefines(name)
+ local cache = vm.getCache 'getDocDefines'
+ if cache[name] then
+ return cache[name]
end
- guide.eachSource(ast.ast.docs, function (src)
- if src.type == 'doc.type.name'
- or src.type == 'doc.class.name'
- or src.type == 'doc.extends.name'
- or src.type == 'doc.alias.name' then
- if src.type == 'doc.type.name' then
- if guide.getParentDocTypeTable(src) then
- return
- end
- end
- local name = src[1]
- if name then
- if not types[name] then
- types[name] = {}
- end
- types[name][#types[name]+1] = src
- end
- end
- end)
- return types
-end
-
-local function getDocTypes(name)
local results = {}
- if name == 'any'
- or name == 'nil' then
- return results
- end
- for uri in files.eachFile() do
- local cache = files.getCache(uri)
- cache.classes = cache.classes or getTypesOfFile(uri)
- if name == '*' then
- for _, sources in util.sortPairs(cache.classes) do
- for _, source in ipairs(sources) do
- results[#results+1] = source
- end
- end
- else
- if cache.classes[name] then
- for _, source in ipairs(cache.classes[name]) do
+ local id = 'def:dn:' .. (name or '')
+ for node in collector.each(id) do
+ if node.sources then
+ for _, source in ipairs(node.sources) do
+ if source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name' then
results[#results+1] = source
end
end
end
end
+ cache[name] = results
return results
end
-function vm.getDocEnums(doc, mark, results)
- if not doc then
- return nil
- end
- mark = mark or {}
- if mark[doc] then
- return nil
- end
- mark[doc] = true
- results = results or {}
- for _, enum in ipairs(doc.enums) do
- results[#results+1] = enum
- end
- for _, resume in ipairs(doc.resumes) do
- results[#results+1] = resume
+function vm.isDocDefined(name)
+ if define.BuiltinClass[name] then
+ return true
end
- for _, unit in ipairs(doc.types) do
- if unit.type == 'doc.type.name' then
- for _, other in ipairs(vm.getDocTypes(unit[1])) do
- if other.type == 'doc.alias.name' then
- vm.getDocEnums(other.parent.extends, mark, results)
- end
- end
- end
+ local id = 'def:dn:' .. name
+ if collector.has(id) then
+ return true
end
- return results
+ return false
end
-function vm.getDocTypeUnits(doc, mark, results)
+function vm.getDocEnums(doc)
if not doc then
return nil
end
- mark = mark or {}
- if mark[doc] then
- return nil
- end
- mark[doc] = true
- results = results or {}
- for _, enum in ipairs(doc.enums) do
- results[#results+1] = enum
- end
- for _, resume in ipairs(doc.resumes) do
- results[#results+1] = resume
- end
- for _, unit in ipairs(doc.types) do
- if unit.type == 'doc.type.name' then
- for _, other in ipairs(vm.getDocTypes(unit[1])) do
- if other.type == 'doc.alias.name' then
- vm.getDocTypeUnits(other.parent.extends, mark, results)
- elseif other.type == 'doc.class.name' then
- results[#results+1] = other
- end
- end
- else
- results[#results+1] = unit
- end
- end
- return results
-end
-
-function vm.getDocTypes(name)
- local cache = vm.getCache('getDocTypes')[name]
- if cache ~= nil then
- return cache
- end
- cache = getDocTypes(name)
- vm.getCache('getDocTypes')[name] = cache
- return cache
-end
+ local defs = vm.getDefs(doc)
+ local results = {}
-function vm.getDocClass(name)
- local cache = vm.getCache('getDocClass')[name]
- if cache ~= nil then
- return cache
- end
- cache = {}
- local results = getDocTypes(name)
- for _, doc in ipairs(results) do
- if doc.type == 'doc.class.name' then
- cache[#cache+1] = doc
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
+ results[#results+1] = def
end
end
- vm.getCache('getDocClass')[name] = cache
- return cache
+
+ return results
end
function vm.isMetaFile(uri)
- local status = files.getAst(uri)
+ local status = files.getState(uri)
if not status then
return false
end
@@ -224,7 +137,7 @@ end
function vm.isDeprecated(value, deep)
if deep then
- local defs = vm.getDefs(value, 0)
+ local defs = vm.getDefs(value)
if #defs == 0 then
return false
end
@@ -300,7 +213,7 @@ local function makeDiagRange(uri, doc, results)
end
function vm.isDiagDisabledAt(uri, offset, name)
- local status = files.getAst(uri)
+ local status = files.getState(uri)
if not status then
return false
end
diff --git a/script/vm/getGlobals.lua b/script/vm/getGlobals.lua
index 2752ce09..e5bcafc0 100644
--- a/script/vm/getGlobals.lua
+++ b/script/vm/getGlobals.lua
@@ -1,5 +1,6 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local await = require "await"
+local searcher = require "core.searcher"
---@type vm
local vm = require 'vm.vm'
local files = require 'files'
@@ -17,12 +18,8 @@ local function getGlobalsOfFile(uri)
end
local globals = {}
cache.globals = globals
- local ast = files.getAst(uri)
- if not ast then
- return globals
- end
tracy.ZoneBeginN 'getGlobalsOfFile'
- local results = guide.findGlobals(ast.ast)
+ local results = searcher.findGlobals(uri, 'ref')
local subscribe = ws.getCache 'globalSubscribe'
subscribe[uri] = {}
local mark = {}
@@ -34,7 +31,7 @@ local function getGlobalsOfFile(uri)
goto CONTINUE
end
mark[res] = true
- local name = guide.getSimpleName(res)
+ local name = guide.getKeyName(res)
if name then
if not globals[name] then
globals[name] = {}
@@ -59,12 +56,8 @@ local function getGlobalSetsOfFile(uri)
end
local globals = {}
cache.globalSets = globals
- local ast = files.getAst(uri)
- if not ast then
- return globals
- end
tracy.ZoneBeginN 'getGlobalSetsOfFile'
- local results = guide.findGlobals(ast.ast)
+ local results = searcher.findGlobals(uri, 'def')
local subscribe = ws.getCache 'globalSetsSubscribe'
subscribe[uri] = {}
local mark = {}
@@ -76,16 +69,14 @@ local function getGlobalSetsOfFile(uri)
goto CONTINUE
end
mark[res] = true
- if vm.isSet(res) then
- local name = guide.getSimpleName(res)
- if name then
- if not globals[name] then
- globals[name] = {}
- subscribe[uri][#subscribe[uri]+1] = name
- end
- globals[name][#globals[name]+1] = res
- globals['*'][#globals['*']+1] = res
+ local name = guide.getKeyName(res)
+ if name then
+ if not globals[name] then
+ globals[name] = {}
+ subscribe[uri][#subscribe[uri]+1] = name
end
+ globals[name][#globals[name]+1] = res
+ globals['*'][#globals['*']+1] = res
end
::CONTINUE::
end
@@ -265,7 +256,7 @@ files.watch(function (ev, uri)
end
needUpdateGlobals[uri] = true
elseif ev == 'create' then
- getGlobalsOfFile(uri)
- getGlobalSetsOfFile(uri)
+ --getGlobalsOfFile(uri)
+ --getGlobalSetsOfFile(uri)
end
end)
diff --git a/script/vm/getInfer.lua b/script/vm/getInfer.lua
deleted file mode 100644
index 5447ca23..00000000
--- a/script/vm/getInfer.lua
+++ /dev/null
@@ -1,104 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local util = require 'utility'
-local await = require 'await'
-local config = require 'config'
-
-NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
-
---- 是否包含某种类型
-function vm.hasType(source, type, deep)
- local defs = vm.getDefs(source, deep)
- for i = 1, #defs do
- local def = defs[i]
- local value = guide.getObjectValue(def) or def
- if value.type == type then
- return true
- end
- end
- return false
-end
-
---- 是否包含某种类型
-function vm.hasInferType(source, type, deep)
- local infers = vm.getInfers(source, deep)
- for i = 1, #infers do
- local infer = infers[i]
- if infer.type == type then
- return true
- end
- end
- return false
-end
-
-function vm.getInferType(source, deep)
- local infers = vm.getInfers(source, deep)
- return guide.viewInferType(infers)
-end
-
-function vm.getInferLiteral(source, deep)
- local infers = vm.getInfers(source, deep)
- local literals = {}
- local mark = {}
- for _, infer in ipairs(infers) do
- local value = infer.value
- if value and not mark[value] then
- mark[value] = true
- literals[#literals+1] = util.viewLiteral(value)
- end
- end
- if #literals == 0 then
- return nil
- end
- table.sort(literals)
- return table.concat(literals, '|')
-end
-
-local function getInfers(source, deep)
- local results = {}
- local lock = vm.lock('getInfers', source)
- if not lock then
- return results
- end
-
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- await.delay()
-
- local clock = os.clock()
- local myResults, count = guide.requestInfer(source, vm.interface, deep)
- if DEVELOP and os.clock() - clock > 0.1 then
- log.warn('requestInfer', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
- end
- vm.mergeResults(results, myResults)
-
- lock()
-
- return results
-end
-
-local function getInfersBySource(source, deep)
- deep = deep or -999
- local cache = vm.getCache('getInfers')[source]
- if not cache or cache.deep < deep then
- cache = getInfers(source, deep)
- cache.deep = deep
- vm.getCache('getInfers')[source] = cache
- end
- return cache
-end
-
---- 获取对象的值
---- 会尝试穿透函数调用
-function vm.getInfers(source, deep)
- if guide.isGlobal(source) then
- local name = guide.getKeyName(source)
- local cache = vm.getCache('getInfersOfGlobal')[name]
- or getInfersBySource(source, deep)
- vm.getCache('getInfersOfGlobal')[name] = cache
- return cache
- else
- return getInfersBySource(source, deep)
- end
-end
diff --git a/script/vm/getLibrary.lua b/script/vm/getLibrary.lua
index b52f7240..a3c8feb0 100644
--- a/script/vm/getLibrary.lua
+++ b/script/vm/getLibrary.lua
@@ -1,8 +1,11 @@
---@type vm
local vm = require 'vm.vm'
-function vm.getLibraryName(source, deep)
- local defs = vm.getDefs(source, deep)
+function vm.getLibraryName(source)
+ if source.special then
+ return source.special
+ end
+ local defs = vm.getDefs(source)
for _, def in ipairs(defs) do
if def.special then
return def.special
diff --git a/script/vm/getLinks.lua b/script/vm/getLinks.lua
index 91a5f1a0..14b34987 100644
--- a/script/vm/getLinks.lua
+++ b/script/vm/getLinks.lua
@@ -1,12 +1,11 @@
-local guide = require 'core.guide'
----@type vm
+local guide = require 'parser.guide'
local vm = require 'vm.vm'
local files = require 'files'
local function getFileLinks(uri)
local ws = require 'workspace'
local links = {}
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if not ast then
return links
end
@@ -33,11 +32,17 @@ local function getFileLinks(uri)
return links
end
+local function getFileLinksOrCache(uri)
+ local cache = files.getCache(uri)
+ cache.links = cache.links or getFileLinks(uri)
+ return cache.links
+end
+
local function getLinksTo(uri)
uri = files.asKey(uri)
local links = {}
for u in files.eachFile() do
- local ls = vm.getFileLinks(u)
+ local ls = getFileLinksOrCache(u)
if ls[uri] then
for _, l in ipairs(ls[uri]) do
links[#links+1] = l
@@ -47,6 +52,7 @@ local function getLinksTo(uri)
return links
end
+-- 获取所有 require(uri) 的文件
function vm.getLinksTo(uri)
local cache = vm.getCache('getLinksTo')[uri]
if cache ~= nil then
@@ -56,9 +62,3 @@ function vm.getLinksTo(uri)
vm.getCache('getLinksTo')[uri] = cache
return cache
end
-
-function vm.getFileLinks(uri)
- local cache = files.getCache(uri)
- cache.links = cache.links or getFileLinks(uri)
- return cache.links
-end
diff --git a/script/vm/getMeta.lua b/script/vm/getMeta.lua
deleted file mode 100644
index 44d1874a..00000000
--- a/script/vm/getMeta.lua
+++ /dev/null
@@ -1,53 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-
-local function eachMetaOfArg1(source, callback)
- local node, index = vm.getArgInfo(source)
- local special = vm.getSpecial(node)
- if special == 'setmetatable' and index == 1 then
- local mt = node.next.args[2]
- if mt then
- callback(mt)
- end
- end
-end
-
-local function eachMetaOfRecv(source, callback)
- if not source or source.type ~= 'select' then
- return
- end
- if source.index ~= 1 then
- return
- end
- local call = source.vararg
- if not call or call.type ~= 'call' then
- return
- end
- local special = vm.getSpecial(call.node)
- if special ~= 'setmetatable' then
- return
- end
- local mt = call.args[2]
- if mt then
- callback(mt)
- end
-end
-
-function vm.eachMetaValue(source, callback)
- vm.eachMeta(source, function (mt)
- for _, src in ipairs(vm.getDefFields(mt)) do
- if vm.getKeyName(src) == '__index' then
- if src.value then
- for _, valueSrc in ipairs(vm.getDefFields(src.value)) do
- callback(valueSrc)
- end
- end
- end
- end
- end)
-end
-
-function vm.eachMeta(source, callback)
- eachMetaOfArg1(source, callback)
- eachMetaOfRecv(source.value, callback)
-end
diff --git a/script/vm/guideInterface.lua b/script/vm/guideInterface.lua
index ae060481..a07b6644 100644
--- a/script/vm/guideInterface.lua
+++ b/script/vm/guideInterface.lua
@@ -2,7 +2,7 @@
local vm = require 'vm.vm'
local files = require 'files'
local ws = require 'workspace'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local await = require 'await'
local config = require 'config'
@@ -27,11 +27,11 @@ function m.require(args, index)
return nil
end
local results = {}
- local myUri = guide.getUri(args[1])
+ local myUri = searcher.getUri(args[1])
local uris = ws.findUrisByRequirePath(reqName)
for _, uri in ipairs(uris) do
if not files.eq(myUri, uri) then
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if ast then
m.searchFileReturn(results, ast.ast, index)
end
@@ -47,11 +47,11 @@ function m.dofile(args, index)
return
end
local results = {}
- local myUri = guide.getUri(args[1])
+ local myUri = searcher.getUri(args[1])
local uris = ws.findUrisByFilePath(reqName)
for _, uri in ipairs(uris) do
if not files.eq(myUri, uri) then
- local ast = files.getAst(uri)
+ local ast = files.getState(uri)
if ast then
m.searchFileReturn(results, ast.ast, index)
end
@@ -87,9 +87,9 @@ function vm.interface.global(name, onlyDef)
end
end
-function vm.interface.docType(name)
+function vm.interface.doc(name, type)
await.delay()
- return vm.getDocTypes(name)
+ return vm.getDocNames(name, type)
end
function vm.interface.link(uri)
diff --git a/script/vm/init.lua b/script/vm/init.lua
index b9e8e147..c38f01d5 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -2,10 +2,6 @@ local vm = require 'vm.vm'
require 'vm.getGlobals'
require 'vm.getDocs'
require 'vm.getLibrary'
-require 'vm.getInfer'
-require 'vm.getClass'
-require 'vm.getMeta'
-require 'vm.eachField'
require 'vm.eachDef'
require 'vm.eachRef'
require 'vm.getLinks'
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
index 0248ad8c..0e7f3176 100644
--- a/script/vm/vm.lua
+++ b/script/vm/vm.lua
@@ -1,18 +1,14 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local util = require 'utility'
local files = require 'files'
local timer = require 'timer'
local setmetatable = setmetatable
-local assert = assert
-local require = require
-local type = type
local running = coroutine.running
local ipairs = ipairs
local log = log
local xpcall = xpcall
local mathHuge = math.huge
-local collectgarbage = collectgarbage
_ENV = nil
@@ -64,7 +60,10 @@ function m.getArgInfo(source)
end
function m.getSpecial(source)
- return guide.getSpecial(source)
+ if not source then
+ return nil
+ end
+ return source.special
end
function m.getKeyName(source)