summaryrefslogtreecommitdiff
path: root/script/core
diff options
context:
space:
mode:
Diffstat (limited to 'script/core')
-rw-r--r--script/core/code-action.lua13
-rw-r--r--script/core/command/removeSpace.lua15
-rw-r--r--script/core/command/solve.lua9
-rw-r--r--script/core/completion.lua132
-rw-r--r--script/core/definition.lua18
-rw-r--r--script/core/diagnostics/ambiguity-1.lua2
-rw-r--r--script/core/diagnostics/circle-doc-class.lua12
-rw-r--r--script/core/diagnostics/close-non-object.lua7
-rw-r--r--script/core/diagnostics/code-after-break.lua8
-rw-r--r--script/core/diagnostics/count-down-loop.lua6
-rw-r--r--script/core/diagnostics/deprecated.lua20
-rw-r--r--script/core/diagnostics/duplicate-doc-class.lua14
-rw-r--r--script/core/diagnostics/duplicate-index.lua10
-rw-r--r--script/core/diagnostics/duplicate-set-field.lua12
-rw-r--r--script/core/diagnostics/empty-block.lua2
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua2
-rw-r--r--script/core/diagnostics/init.lua3
-rw-r--r--script/core/diagnostics/lowercase-global.lua2
-rw-r--r--script/core/diagnostics/newfield-call.lua2
-rw-r--r--script/core/diagnostics/newline-call.lua2
-rw-r--r--script/core/diagnostics/no-implicit-any.lua4
-rw-r--r--script/core/diagnostics/redefined-local.lua5
-rw-r--r--script/core/diagnostics/redundant-parameter.lua4
-rw-r--r--script/core/diagnostics/trailing-space.lua2
-rw-r--r--script/core/diagnostics/unbalanced-assignments.lua2
-rw-r--r--script/core/diagnostics/undefined-doc-class.lua4
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua4
-rw-r--r--script/core/diagnostics/undefined-doc-param.lua2
-rw-r--r--script/core/diagnostics/undefined-env-child.lua10
-rw-r--r--script/core/diagnostics/undefined-field.lua69
-rw-r--r--script/core/diagnostics/undefined-global.lua19
-rw-r--r--script/core/diagnostics/unused-function.lua2
-rw-r--r--script/core/diagnostics/unused-label.lua2
-rw-r--r--script/core/diagnostics/unused-local.lua5
-rw-r--r--script/core/diagnostics/unused-vararg.lua2
-rw-r--r--script/core/document-symbol.lua10
-rw-r--r--script/core/find-source.lua2
-rw-r--r--script/core/folding.lua4
-rw-r--r--script/core/generic.lua234
-rw-r--r--script/core/guide2.lua (renamed from script/core/guide.lua)5
-rw-r--r--script/core/highlight.lua59
-rw-r--r--script/core/hint.lua20
-rw-r--r--script/core/hover/arg.lua13
-rw-r--r--script/core/hover/description.lua17
-rw-r--r--script/core/hover/init.lua11
-rw-r--r--script/core/hover/label.lua59
-rw-r--r--script/core/hover/name.lua6
-rw-r--r--script/core/hover/return.lua28
-rw-r--r--script/core/hover/table.lua282
-rw-r--r--script/core/infer.lua634
-rw-r--r--script/core/keyword.lua2
-rw-r--r--script/core/noder.lua926
-rw-r--r--script/core/reference.lua23
-rw-r--r--script/core/rename.lua21
-rw-r--r--script/core/searcher.lua728
-rw-r--r--script/core/semantic-tokens.lua4
-rw-r--r--script/core/signature.lua7
-rw-r--r--script/core/type-formatting.lua2
-rw-r--r--script/core/workspace-symbol.lua4
59 files changed, 2963 insertions, 565 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
index bae3df81..9ed000e9 100644
--- a/script/core/code-action.lua
+++ b/script/core/code-action.lua
@@ -1,10 +1,9 @@
-local files = require 'files'
-local lang = require 'language'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local util = require 'utility'
-local sp = require 'bee.subprocess'
-local vm = require 'vm'
+local files = require 'files'
+local lang = require 'language'
+local util = require 'utility'
+local sp = require 'bee.subprocess'
+local vm = require 'vm'
+local guide = require "parser.guide"
local function checkDisableByLuaDocExits(uri, row, mode, code)
local lines = files.getLines(uri)
diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua
index 527af8d5..ba1ee8eb 100644
--- a/script/core/command/removeSpace.lua
+++ b/script/core/command/removeSpace.lua
@@ -1,11 +1,10 @@
-local files = require 'files'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local proto = require 'proto'
-local lang = require 'language'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local proto = require 'proto'
+local lang = require 'language'
local function isInString(ast, offset)
- return guide.eachSourceContain(ast.ast, offset, function (source)
+ return searcher.eachSourceContain(ast.ast, offset, function (source)
if source.type == 'string' then
return true
end
@@ -23,10 +22,10 @@ return function (data)
local textEdit = {}
for i = 1, #lines do
- local line = guide.lineContent(lines, text, i, true)
+ local line = searcher.lineContent(lines, text, i, true)
local pos = line:find '[ \t]+$'
if pos then
- local start, finish = guide.lineRange(lines, i, true)
+ local start, finish = searcher.lineRange(lines, i, true)
start = start + pos - 1
if isInString(ast, start) then
goto NEXT_LINE
diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua
index 995a2109..dc23e7af 100644
--- a/script/core/command/solve.lua
+++ b/script/core/command/solve.lua
@@ -1,8 +1,7 @@
-local files = require 'files'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local proto = require 'proto'
-local lang = require 'language'
+local files = require 'files'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
local opMap = {
['+'] = true,
diff --git a/script/core/completion.lua b/script/core/completion.lua
index e3980eca..acaaa276 100644
--- a/script/core/completion.lua
+++ b/script/core/completion.lua
@@ -1,19 +1,14 @@
local define = require 'proto.define'
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local matchKey = require 'core.matchkey'
local vm = require 'vm'
-local getLabel = require 'core.hover.label'
local getName = require 'core.hover.name'
local getArg = require 'core.hover.arg'
-local getReturn = require 'core.hover.return'
-local getDesc = require 'core.hover.description'
local getHover = require 'core.hover'
local config = require 'config'
local util = require 'utility'
local markdown = require 'provider.markdown'
-local findSource = require 'core.find-source'
-local await = require 'await'
local parser = require 'parser'
local keyWordMap = require 'core.keyword'
local workspace = require 'workspace'
@@ -21,6 +16,8 @@ local furi = require 'file-uri'
local rpath = require 'workspace.require-path'
local lang = require 'language'
local lookBackward = require 'core.look-backward'
+local guide = require 'parser.guide'
+local infer = require 'core.infer'
local DiagnosticModes = {
'disable-next-line',
@@ -135,8 +132,8 @@ local function buildFunctionSnip(source, value, oop)
end
local function buildDetail(source)
- local types = vm.getInferType(source, 0)
- local literals = vm.getInferLiteral(source, 0)
+ local types = infer.searchAndViewInfers(source)
+ local literals = infer.searchAndViewLiterals(source)
if literals then
return types .. ' = ' .. literals
else
@@ -149,9 +146,9 @@ local function getSnip(source)
if context <= 0 then
return nil
end
- local defs = vm.getRefs(source, 0)
+ local defs = vm.getRefs(source)
for _, def in ipairs(defs) do
- def = guide.getObjectValue(def) or def
+ def = searcher.getObjectValue(def) or def
if def ~= source and def.type == 'function' then
local uri = guide.getUri(def)
local text = files.getText(uri)
@@ -273,8 +270,8 @@ local function checkLocal(ast, word, offset, results)
if not matchKey(word, name) then
goto CONTINUE
end
- if vm.hasType(source, 'function') then
- for _, def in ipairs(vm.getDefs(source, 0)) do
+ if infer.hasType(source, 'function') then
+ for _, def in ipairs(vm.getDefs(source)) do
if def.type == 'function'
or def.type == 'doc.type.function' then
local funcLabel = name .. getParams(def, false)
@@ -417,7 +414,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset)
end
local function checkFieldThen(name, src, word, start, offset, parent, oop, results)
- local value = guide.getObjectValue(src) or src
+ local value = searcher.getObjectValue(src) or src
local kind = define.CompletionItemKind.Field
if value.type == 'function'
or value.type == 'doc.type.function' then
@@ -492,7 +489,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res
end
local funcLabel
if config.config.completion.showParams then
- local value = guide.getObjectValue(src) or src
+ local value = searcher.getObjectValue(src) or src
if value.type == 'function'
or value.type == 'doc.type.function' then
funcLabel = name .. getParams(value, oop)
@@ -539,16 +536,16 @@ end
local function checkGlobal(ast, word, start, offset, parent, oop, results)
local locals = guide.getVisibleLocals(ast.ast, offset)
- local refs = vm.getGlobalSets '*'
- checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global')
+ local globals = vm.getGlobalSets '*'
+ checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results, locals, 'global')
end
local function checkField(ast, word, start, offset, parent, oop, results)
if parent.tag == '_ENV' or parent.special == '_G' then
- local refs = vm.getGlobalSets '*'
- checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
+ local globals = vm.getGlobalSets '*'
+ checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results)
else
- local refs = vm.getFields(parent, 0)
+ local refs = vm.getRefs(parent, '*')
checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
end
end
@@ -1043,14 +1040,14 @@ local function mergeEnums(a, b, source)
end
end
-local function checkTypingEnum(ast, text, offset, infers, str, results)
+local function checkTypingEnum(ast, text, offset, defs, str, results)
local enums = {}
- for _, infer in ipairs(infers) do
- if infer.source.type == 'doc.type.enum'
- or infer.source.type == 'doc.resume' then
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
enums[#enums+1] = {
- label = infer.source[1],
- description = infer.source.comment and infer.source.comment.text,
+ label = def[1],
+ description = def.comment and def.comment.text,
kind = define.CompletionItemKind.EnumMember,
}
end
@@ -1074,8 +1071,8 @@ local function checkEqualEnumLeft(ast, text, offset, source, results)
return src
end
end)
- local infers = vm.getInfers(source, 0)
- checkTypingEnum(ast, text, offset, infers, str, results)
+ local defs = vm.getDefs(source)
+ checkTypingEnum(ast, text, offset, defs, str, results)
end
local function checkEqualEnum(ast, text, offset, results)
@@ -1247,7 +1244,30 @@ function (%s)\
end"):format(table.concat(args, ', '))
end
-local function getCallEnums(source, index)
+local function pushCallEnumsAndFuncs(defs)
+ local results = {}
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
+ results[#results+1] = {
+ label = def[1],
+ description = def.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ end
+ if def.type == 'doc.type.function' then
+ results[#results+1] = {
+ label = infer.viewDocFunction(def),
+ description = def.comment,
+ kind = define.CompletionItemKind.Function,
+ insertText = buildInsertDocFunction(def),
+ }
+ end
+ end
+ return results
+end
+
+local function getCallEnumsAndFuncs(source, index)
if source.type == 'function' and source.bindDocs then
if not source.args then
return
@@ -1266,37 +1286,10 @@ local function getCallEnums(source, index)
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.param'
and doc.param[1] == arg[1] then
- local enums = {}
- for _, enum in ipairs(vm.getDocEnums(doc.extends) or {}) do
- enums[#enums+1] = {
- label = enum[1],
- description = enum.comment,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
- for _, unit in ipairs(vm.getDocTypeUnits(doc.extends) or {}) do
- if unit.type == 'doc.type.function' then
- local text = files.getText(guide.getUri(unit))
- enums[#enums+1] = {
- label = text:sub(unit.start, unit.finish),
- description = doc.comment,
- kind = define.CompletionItemKind.Function,
- insertText = buildInsertDocFunction(unit),
- }
- end
- end
- return enums
+ return pushCallEnumsAndFuncs(vm.getDefs(doc.extends))
elseif doc.type == 'doc.vararg'
and arg.type == '...' then
- local enums = {}
- for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do
- enums[#enums+1] = {
- label = enum[1],
- description = enum.comment,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
- return enums
+ return pushCallEnumsAndFuncs(vm.getDefs(doc.vararg))
end
end
end
@@ -1403,12 +1396,12 @@ local function checkTableLiteralFieldByCall(ast, text, offset, call, defs, index
return
end
for _, def in ipairs(defs) do
- local func = guide.getObjectValue(def) or def
+ local func = searcher.getObjectValue(def) or def
local param = getFuncParamByCallIndex(func, index)
if not param then
goto CONTINUE
end
- local defs = vm.getDefFields(param, 0)
+ local defs = vm.getDefs(param, '*')
for _, field in ipairs(defs) do
local name = guide.getKeyName(field)
if name and not mark[name] then
@@ -1431,10 +1424,10 @@ local function tryCallArg(ast, text, offset, results)
if arg and arg.type == 'function' then
return
end
- local defs = vm.getDefs(call.node, 0)
+ local defs = vm.getDefs(call.node)
for _, def in ipairs(defs) do
- def = guide.getObjectValue(def) or def
- local enums = getCallEnums(def, argIndex)
+ def = searcher.getObjectValue(def) or def
+ local enums = getCallEnumsAndFuncs(def, argIndex)
if enums then
mergeEnums(myResults, enums, arg)
end
@@ -1461,7 +1454,8 @@ local function tryTable(ast, text, offset, results)
if source.type ~= 'table' then
tbl = source.parent
end
- local defs = vm.getDefFields(tbl, 0)
+ local parent = tbl.parent
+ local defs = vm.getDefs(parent, '*')
for _, field in ipairs(defs) do
local name = guide.getKeyName(field)
if name and not mark[name] then
@@ -1560,7 +1554,7 @@ end
local function tryLuaDocBySource(ast, offset, source, results)
if source.type == 'doc.extends.name' then
if source.parent.type == 'doc.class' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if doc.type == 'doc.class.name'
and doc.parent ~= source.parent
and matchKey(source[1], doc[1]) then
@@ -1578,7 +1572,7 @@ local function tryLuaDocBySource(ast, offset, source, results)
end
return true
elseif source.type == 'doc.type.name' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name')
and doc.parent ~= source.parent
and matchKey(source[1], doc[1]) then
@@ -1652,7 +1646,7 @@ end
local function tryLuaDocByErr(ast, offset, err, docState, results)
if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if doc.type == 'doc.class.name'
and doc.parent ~= docState then
results[#results+1] = {
@@ -1662,7 +1656,7 @@ local function tryLuaDocByErr(ast, offset, err, docState, results)
end
end
elseif err.type == 'LUADOC_MISS_TYPE_NAME' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then
results[#results+1] = {
label = doc[1],
@@ -1735,14 +1729,14 @@ local function buildLuaDocOfFunction(func)
local returns = {}
if func.args then
for _, arg in ipairs(func.args) do
- args[#args+1] = vm.getInferType(arg)
+ args[#args+1] = infer.searchAndViewInfers(arg)
end
end
if func.returns then
for _, rtns in ipairs(func.returns) do
for n = 1, #rtns do
if not returns[n] then
- returns[n] = vm.getInferType(rtns[n])
+ returns[n] = infer.searchAndViewInfers(rtns[n])
end
end
end
diff --git a/script/core/definition.lua b/script/core/definition.lua
index b26bb922..3ced05a2 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -1,14 +1,15 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local workspace = require 'workspace'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
+local guide = require 'parser.guide'
local function sortResults(results)
-- 先按照顺序排序
table.sort(results, function (a, b)
- local u1 = guide.getUri(a.target)
- local u2 = guide.getUri(b.target)
+ local u1 = searcher.getUri(a.target)
+ local u2 = searcher.getUri(b.target)
if u1 == u2 then
return a.target.start < b.target.start
else
@@ -20,7 +21,7 @@ local function sortResults(results)
for i = #results, 1, -1 do
local res = results[i].target
local f = res.finish
- local uri = guide.getUri(res)
+ local uri = searcher.getUri(res)
if lf and f > lf and uri == lu then
table.remove(results, i)
else
@@ -127,11 +128,11 @@ return function (uri, offset)
end
end
- local defs = vm.getDefs(source, 0)
+ local defs = vm.getDefs(source)
local values = {}
for _, src in ipairs(defs) do
- local value = guide.getObjectValue(src)
- if value and value ~= src then
+ local value = searcher.getObjectValue(src)
+ if value and value ~= src and guide.isLiteral(value) then
values[value] = true
end
end
@@ -148,9 +149,6 @@ return function (uri, offset)
goto CONTINUE
end
src = src.field or src.method or src.index or src
- if src.type == 'table' and src.parent.type ~= 'return' then
- goto CONTINUE
- end
if src.type == 'doc.class.name'
and source.type ~= 'doc.type.name'
and source.type ~= 'doc.extends.name'
diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua
index 19bb4f97..37815fb5 100644
--- a/script/core/diagnostics/ambiguity-1.lua
+++ b/script/core/diagnostics/ambiguity-1.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local opMap = {
diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua
index 702cd904..d2e26378 100644
--- a/script/core/diagnostics/circle-doc-class.lua
+++ b/script/core/diagnostics/circle-doc-class.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local vm = require 'vm'
+local guide = require 'parser.guide'
return function (uri, callback)
local state = files.getAst(uri)
@@ -40,7 +40,7 @@ return function (uri, callback)
end
if not mark[newName] then
mark[newName] = true
- local docs = vm.getDocTypes(newName)
+ local docs = vm.getDocDefines(newName)
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name' then
list[#list+1] = otherDoc.parent
diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua
index d1983c42..7828efe9 100644
--- a/script/core/diagnostics/close-non-object.lua
+++ b/script/core/diagnostics/close-non-object.lua
@@ -1,7 +1,6 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
return function (uri, callback)
local state = files.getAst(uri)
diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua
index f23755ea..f300a61a 100644
--- a/script/core/diagnostics/code-after-break.lua
+++ b/script/core/diagnostics/code-after-break.lua
@@ -1,7 +1,7 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
return function (uri, callback)
local state = files.getAst(uri)
diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua
index 65099af8..ee245781 100644
--- a/script/core/diagnostics/count-down-loop.lua
+++ b/script/core/diagnostics/count-down-loop.lua
@@ -1,6 +1,6 @@
-local files = require "files"
-local guide = require "core.guide"
-local lang = require 'language'
+local files = require "files"
+local guide = require "parser.guide"
+local lang = require 'language'
return function (uri, callback)
local state = files.getAst(uri)
diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua
index 60d60946..a6f8a47e 100644
--- a/script/core/diagnostics/deprecated.lua
+++ b/script/core/diagnostics/deprecated.lua
@@ -1,10 +1,10 @@
-local files = require 'files'
-local vm = require 'vm'
-local lang = require 'language'
-local guide = require 'core.guide'
-local config = require 'config'
-local define = require 'proto.define'
-local await = require 'await'
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local config = require 'config'
+local define = require 'proto.define'
+local await = require 'await'
return function (uri, callback)
local ast = files.getAst(uri)
@@ -20,7 +20,7 @@ return function (uri, callback)
return
end
if src.type == 'getglobal' then
- local key = guide.getKeyName(src)
+ local key = src[1]
if not key then
return
end
@@ -34,11 +34,11 @@ return function (uri, callback)
await.delay()
- if not vm.isDeprecated(src, 0) then
+ if not vm.isDeprecated(src, true) then
return
end
- local defs = vm.getDefs(src, 0)
+ local defs = vm.getDefs(src)
local validVersions
for _, def in ipairs(defs) do
if def.bindDocs then
diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua
index 8c6696a9..daecb836 100644
--- a/script/core/diagnostics/duplicate-doc-class.lua
+++ b/script/core/diagnostics/duplicate-doc-class.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local vm = require 'vm'
+local guide = require 'parser.guide'
return function (uri, callback)
local state = files.getAst(uri)
@@ -20,7 +20,7 @@ return function (uri, callback)
or doc.type == 'doc.alias' then
local name = guide.getKeyName(doc)
if not cache[name] then
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
cache[name] = {}
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name'
@@ -28,7 +28,7 @@ return function (uri, callback)
cache[name][#cache[name]+1] = {
start = otherDoc.start,
finish = otherDoc.finish,
- uri = guide.getUri(otherDoc),
+ uri = searcher.getUri(otherDoc),
}
end
end
diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua
index 5e63d39e..d1ba9261 100644
--- a/script/core/diagnostics/duplicate-index.lua
+++ b/script/core/diagnostics/duplicate-index.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
return function (uri, callback)
local ast = files.getAst(uri)
diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua
index c1e2285a..e1883fe5 100644
--- a/script/core/diagnostics/duplicate-set-field.lua
+++ b/script/core/diagnostics/duplicate-set-field.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local define = require 'proto.define'
+local guide = require "parser.guide"
return function (uri, callback)
local ast = files.getAst(uri)
@@ -30,7 +30,7 @@ return function (uri, callback)
if not name then
goto CONTINUE
end
- local value = guide.getObjectValue(nxt)
+ local value = searcher.getObjectValue(nxt)
if not value or value.type ~= 'function' then
goto CONTINUE
end
diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua
index 690a4ca2..2024f4e3 100644
--- a/script/core/diagnostics/empty-block.lua
+++ b/script/core/diagnostics/empty-block.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
index de23bc76..9a0d4f35 100644
--- a/script/core/diagnostics/global-in-nil-env.lua
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
-- TODO: 检查路径是否可达
diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua
index a2b831f7..1d1ab9af 100644
--- a/script/core/diagnostics/init.lua
+++ b/script/core/diagnostics/init.lua
@@ -62,14 +62,11 @@ local function check(uri, name, results)
end
return function (uri, response)
- local vm = require 'vm'
local ast = files.getAst(uri)
if not ast then
return nil
end
- local isOpen = files.isOpen(uri)
-
for _, name in ipairs(diagList) do
await.delay()
local results = {}
diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua
index 9c094701..8c7ae793 100644
--- a/script/core/diagnostics/lowercase-global.lua
+++ b/script/core/diagnostics/lowercase-global.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local config = require 'config'
local vm = require 'vm'
diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua
index 0727c2fd..75681cbc 100644
--- a/script/core/diagnostics/newfield-call.lua
+++ b/script/core/diagnostics/newfield-call.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua
index 807f76a2..159a60c9 100644
--- a/script/core/diagnostics/newline-call.lua
+++ b/script/core/diagnostics/newline-call.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua
index ffaab821..23af570a 100644
--- a/script/core/diagnostics/no-implicit-any.lua
+++ b/script/core/diagnostics/no-implicit-any.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
@@ -10,7 +10,7 @@ return function (uri, callback)
return
end
- guide.eachSource(ast.ast, function (source)
+ searcher.eachSource(ast.ast, function (source)
if source.type ~= 'local'
and source.type ~= 'setlocal'
and source.type ~= 'setglobal'
diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua
index 857d80d2..48093417 100644
--- a/script/core/diagnostics/redefined-local.lua
+++ b/script/core/diagnostics/redefined-local.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
@@ -13,6 +13,9 @@ return function (uri, callback)
or name == ast.ENVMode then
return
end
+ if source.tag == 'self' then
+ return
+ end
local exist = guide.getLocal(source, name, source.start-1)
if exist then
callback {
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
index c5bcd5a5..eca7fc91 100644
--- a/script/core/diagnostics/redundant-parameter.lua
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local define = require 'proto.define'
@@ -84,7 +84,7 @@ return function (uri, callback)
local funcArgs = cache[func]
if funcArgs == nil then
funcArgs = getFuncArgs(func) or false
- local refs = vm.getRefs(func, 0)
+ local refs = vm.getRefs(func)
for _, ref in ipairs(refs) do
cache[ref] = funcArgs
end
diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua
index 0a4b1d57..e54a6e60 100644
--- a/script/core/diagnostics/trailing-space.lua
+++ b/script/core/diagnostics/trailing-space.lua
@@ -1,6 +1,6 @@
local files = require 'files'
local lang = require 'language'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local function isInString(ast, offset)
local result = false
diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua
index b2b2800c..35aebb45 100644
--- a/script/core/diagnostics/unbalanced-assignments.lua
+++ b/script/core/diagnostics/unbalanced-assignments.lua
@@ -1,7 +1,7 @@
local files = require 'files'
local define = require 'proto.define'
local lang = require 'language'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
return function (uri, callback, code)
local ast = files.getAst(uri)
diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua
index a91cfa7f..d79f7ea4 100644
--- a/script/core/diagnostics/undefined-doc-class.lua
+++ b/script/core/diagnostics/undefined-doc-class.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
@@ -25,7 +25,7 @@ return function (uri, callback)
end
for _, ext in ipairs(doc.extends) do
local name = ext[1]
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
if cache[name] == nil then
cache[name] = false
for _, otherDoc in ipairs(docs) do
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
index d8a4363b..871f16e1 100644
--- a/script/core/diagnostics/undefined-doc-name.lua
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
@@ -22,7 +22,7 @@ return function (uri, callback)
if classCache[name] ~= nil then
return classCache[name]
end
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name'
or otherDoc.type == 'doc.alias.name' then
diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua
index 0bf371e5..4a97947d 100644
--- a/script/core/diagnostics/undefined-doc-param.lua
+++ b/script/core/diagnostics/undefined-doc-param.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua
index 89efb8c7..c97c3fe8 100644
--- a/script/core/diagnostics/undefined-env-child.lua
+++ b/script/core/diagnostics/undefined-env-child.lua
@@ -1,7 +1,7 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local vm = require 'vm'
-local lang = require 'language'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local guide = require 'parser.guide'
+local lang = require 'language'
return function (uri, callback)
local ast = files.getAst(uri)
@@ -13,7 +13,7 @@ return function (uri, callback)
if source.node.tag == '_ENV' then
return
end
- local defs = guide.requestDefinition(source)
+ local defs = searcher.requestDefinition(source)
if #defs > 0 then
return
end
diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua
index b10c9ab0..2d357d5b 100644
--- a/script/core/diagnostics/undefined-field.lua
+++ b/script/core/diagnostics/undefined-field.lua
@@ -2,9 +2,16 @@ local files = require 'files'
local vm = require 'vm'
local lang = require 'language'
local config = require 'config'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
+local SkipCheckClass = {
+ ['unknown'] = true,
+ ['any'] = true,
+ ['table'] = true,
+ ['nil'] = true,
+}
+
return function (uri, callback)
local ast = files.getAst(uri)
if not ast then
@@ -18,7 +25,7 @@ return function (uri, callback)
if cache[src] == nil then
tracy.ZoneBeginN('undefined-field getInfers')
infers = vm.getInfers(src, 0) or false
- local refs = vm.getRefs(src, 0)
+ local refs = vm.getRefs(src)
for _, ref in ipairs(refs) do
cache[ref] = infers
end
@@ -47,7 +54,7 @@ return function (uri, callback)
elseif inferSource.type == 'doc.class.name' then
addTo(allDocClass, inferSource.parent)
elseif inferSource.type == 'doc.type.name' then
- local docTypes = vm.getDocTypes(inferSource[1])
+ local docTypes = vm.getDocDefines(inferSource[1])
for _, docType in ipairs(docTypes) do
if docType.type == 'doc.class.name' then
addTo(allDocClass, docType.parent)
@@ -65,7 +72,7 @@ return function (uri, callback)
local empty = true
for _, docClass in ipairs(allDocClass) do
tracy.ZoneBeginN('undefined-field getDefFields')
- local refs = vm.getDefFields(docClass)
+ local refs = vm.getDefs(docClass, '*')
tracy.ZoneEnd()
for _, ref in ipairs(refs) do
@@ -87,35 +94,37 @@ return function (uri, callback)
end
local function checkUndefinedField(src)
- local fieldName = guide.getKeyName(src)
-
- local allDocClass = getAllDocClassFromInfer(src.node)
- if (not allDocClass) or (#allDocClass == 0) then
- return
- end
-
- local fields = getAllFieldsFromAllDocClass(allDocClass)
-
- -- 没找到任何 field,跳过检查
- if not fields then
+ if #vm.getDefs(src) > 0 then
return
end
-
- if not fields[fieldName] then
- local message = lang.script('DIAG_UNDEF_FIELD', fieldName)
- if src.type == 'getfield' and src.field then
- callback {
- start = src.field.start,
- finish = src.field.finish,
- message = message,
- }
- elseif src.type == 'getmethod' and src.method then
- callback {
- start = src.method.start,
- finish = src.method.finish,
- message = message,
- }
+ local node = src.node
+ if node then
+ local defs = vm.getDefs(node)
+ local ok
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class.name'
+ and not SkipCheckClass[def[1]] then
+ ok = true
+ break
+ end
end
+ if not ok then
+ return
+ end
+ end
+ local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src))
+ if src.type == 'getfield' and src.field then
+ callback {
+ start = src.field.start,
+ finish = src.field.finish,
+ message = message,
+ }
+ elseif src.type == 'getmethod' and src.method then
+ callback {
+ start = src.method.start,
+ finish = src.method.finish,
+ message = message,
+ }
end
end
guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField);
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
index 161d8856..825b14f1 100644
--- a/script/core/diagnostics/undefined-global.lua
+++ b/script/core/diagnostics/undefined-global.lua
@@ -1,9 +1,8 @@
-local files = require 'files'
-local vm = require 'vm'
-local lang = require 'language'
-local config = require 'config'
-local guide = require 'core.guide'
-local define = require 'proto.define'
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local config = require 'config'
+local guide = require 'parser.guide'
local requireLike = {
['include'] = true,
@@ -20,7 +19,7 @@ return function (uri, callback)
-- 遍历全局变量,检查所有没有 set 模式的全局变量
guide.eachSourceType(ast.ast, 'getglobal', function (src)
- local key = guide.getKeyName(src)
+ local key = src[1]
if not key then
return
end
@@ -30,7 +29,11 @@ return function (uri, callback)
if config.config.runtime.special[key] then
return
end
- if #vm.getGlobalSets(key) == 0 then
+ local node = src.node
+ if node.tag ~= '_ENV' then
+ return
+ end
+ if #vm.getDefs(src) == 0 then
local message = lang.script('DIAG_UNDEF_GLOBAL', key)
if requireLike[key:lower()] then
message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua
index b6f92e60..41c239f9 100644
--- a/script/core/diagnostics/unused-function.lua
+++ b/script/core/diagnostics/unused-function.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local vm = require 'vm'
local define = require 'proto.define'
local lang = require 'language'
diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua
index e2d5e49a..e6d998ba 100644
--- a/script/core/diagnostics/unused-label.lua
+++ b/script/core/diagnostics/unused-label.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua
index fde90cb8..1a77a45f 100644
--- a/script/core/diagnostics/unused-local.lua
+++ b/script/core/diagnostics/unused-local.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
@@ -87,6 +87,9 @@ return function (uri, callback)
or name == ast.ENVMode then
return
end
+ if source.tag == 'self' then
+ return
+ end
if isToBeClosed(source) then
return
end
diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua
index ec0a05fb..74cc08e7 100644
--- a/script/core/diagnostics/unused-vararg.lua
+++ b/script/core/diagnostics/unused-vararg.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua
index cc87e3ca..e36ba29b 100644
--- a/script/core/document-symbol.lua
+++ b/script/core/document-symbol.lua
@@ -1,8 +1,8 @@
-local await = require 'await'
-local files = require 'files'
-local guide = require 'core.guide'
-local define = require 'proto.define'
-local util = require 'utility'
+local await = require 'await'
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local util = require 'utility'
local function buildName(source, text)
if source.type == 'setmethod'
diff --git a/script/core/find-source.lua b/script/core/find-source.lua
index b36306b6..edbb1e2c 100644
--- a/script/core/find-source.lua
+++ b/script/core/find-source.lua
@@ -1,4 +1,4 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local function isValidFunctionPos(source, offset)
for i = 1, #source.keyword // 2 do
diff --git a/script/core/folding.lua b/script/core/folding.lua
index 15678995..1bbae944 100644
--- a/script/core/folding.lua
+++ b/script/core/folding.lua
@@ -1,5 +1,5 @@
local files = require "files"
-local guide = require "core.guide"
+local searcher = require "core.searcher"
local util = require 'utility'
local Care = {
@@ -153,7 +153,7 @@ return function (uri)
local regions = {}
local status = {}
- guide.eachSource(ast.ast, function (source)
+ searcher.eachSource(ast.ast, function (source)
local tp = source.type
if Care[tp] then
Care[tp](source, text, regions)
diff --git a/script/core/generic.lua b/script/core/generic.lua
new file mode 100644
index 00000000..15950974
--- /dev/null
+++ b/script/core/generic.lua
@@ -0,0 +1,234 @@
+local guide = require 'parser.guide'
+local noder = require "core.noder"
+
+---@class generic.value
+---@field type string
+---@field closure generic.closure
+---@field proto parser.guide.object
+---@field parent parser.guide.object
+
+---@class generic.closure
+---@field type string
+---@field proto parser.guide.object
+---@field upvalues table<string, generic.value[]>
+---@field params generic.value[]
+---@field returns generic.value[]
+
+local m = {}
+
+---@param closure generic.closure
+---@param proto parser.guide.object
+local function instantValue(closure, proto)
+ ---@type generic.value
+ local value = {
+ type = 'generic.value',
+ closure = closure,
+ proto = proto,
+ parent = proto.parent,
+ }
+ closure.values[#closure.values+1] = value
+ return value
+end
+
+---递归实例化对象
+---@param proto parser.guide.object
+---@return generic.value
+local function createValue(closure, proto, callback, road)
+ if callback then
+ road = road or {}
+ end
+ if proto.type == 'doc.type' then
+ local types = {}
+ local hasGeneric
+ for i, tp in ipairs(proto.types) do
+ local genericValue = createValue(closure, tp, callback, road)
+ if genericValue then
+ hasGeneric = true
+ types[i] = genericValue
+ else
+ types[i] = tp
+ end
+ end
+ if not hasGeneric then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.types = types
+ noder.compileNode(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.name' then
+ if not proto.typeGeneric then
+ return nil
+ end
+ local key = proto[1]
+ local value = instantValue(closure, proto)
+ if callback then
+ callback(road, key, proto)
+ end
+ noder.compileNode(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.function' then
+ local hasGeneric
+ local args = {}
+ local returns = {}
+ for i, arg in ipairs(proto.args) do
+ local value = createValue(closure, arg, callback, road)
+ if value then
+ hasGeneric = true
+ end
+ args[i] = value or arg
+ end
+ for i, rtn in ipairs(proto.returns) do
+ local value = createValue(closure, rtn, callback, road)
+ if value then
+ hasGeneric = true
+ end
+ returns[i] = value or rtn
+ end
+ if not hasGeneric then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.args = args
+ value.returns = returns
+ value.isGeneric = true
+ noder.pushSource(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.array' then
+ if road then
+ road[#road+1] = noder.ANY_FIELD
+ end
+ local node = createValue(closure, proto.node, callback, road)
+ if road then
+ road[#road] = nil
+ end
+ if not node then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.node = node
+ return value
+ end
+ if proto.type == 'doc.type.table' then
+ road[#road+1] = noder.TABLE_KEY
+ local tkey = createValue(closure, proto.tkey, callback, road)
+ road[#road] = nil
+
+ road[#road+1] = noder.ANY_FIELD
+ local tvalue = createValue(closure, proto.tvalue, callback, road)
+ road[#road] = nil
+
+ if not tkey and not tvalue then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.tkey = tkey or proto.tkey
+ value.tvalue = tvalue or proto.tvalue
+ return value
+ end
+end
+
+local function buildValue(road, key, proto, param, upvalues)
+ local paramID
+ if proto.literal then
+ local str = param.type == 'string' and param[1]
+ if not str then
+ return
+ end
+ paramID = 'dn:' .. str
+ else
+ paramID = noder.getID(param)
+ end
+ if not paramID then
+ return
+ end
+ local myUri = guide.getUri(param)
+ local myHead = noder.URI_CHAR .. myUri .. noder.URI_CHAR
+ paramID = myHead .. paramID
+ if not upvalues[key] then
+ upvalues[key] = {}
+ end
+ upvalues[key][#upvalues[key]+1] = paramID .. table.concat(road)
+end
+
+-- 为所有的 param 与 return 创建副本
+---@param closure generic.closure
+local function buildValues(closure)
+ local protoFunction = closure.proto
+ local upvalues = closure.upvalues
+ local params = closure.call.args
+
+ if protoFunction.type == 'function' then
+ for _, doc in ipairs(protoFunction.bindDocs) do
+ if doc.type == 'doc.param' then
+ local extends = doc.extends
+ local index = extends.paramIndex
+ if index then
+ local param = params and params[index]
+ closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ buildValue(road, key, proto, param, upvalues)
+ end) or extends
+ end
+ end
+ end
+ for _, doc in ipairs(protoFunction.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ closure.returns[rtn.returnIndex] = createValue(closure, rtn) or rtn
+ end
+ end
+ end
+ end
+ if protoFunction.type == 'doc.type.function' then
+ for index, arg in ipairs(protoFunction.args) do
+ local extends = arg.extends
+ local param = params and params[index]
+ closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ buildValue(road, key, proto, param, upvalues)
+ end) or extends
+ end
+ for index, rtn in ipairs(protoFunction.returns) do
+ closure.returns[index] = createValue(closure, rtn) or rtn
+ end
+ end
+end
+
+---创建一个闭包
+---@param proto parser.guide.object|generic.value # 原型函数|泛型值
+---@return generic.closure
+function m.createClosure(proto, call)
+ local protoFunction, parentClosure
+ if proto.type == 'function' then
+ protoFunction = proto
+ elseif proto.type == 'doc.type.function' then
+ protoFunction = proto
+ elseif proto.type == 'generic.value' then
+ protoFunction = proto.proto
+ parentClosure = proto.closure
+ end
+ ---@type generic.closure
+ local closure = {
+ type = 'generic.closure',
+ parent = protoFunction.parent,
+ proto = protoFunction,
+ call = call,
+ upvalues = parentClosure and parentClosure.upvalues or {},
+ params = {},
+ returns = {},
+ values = {},
+ }
+ buildValues(closure)
+
+ if #closure.returns == 0 then
+ return nil
+ end
+
+ noder.compileNode(noder.getNoders(proto), closure)
+
+ return closure
+end
+
+return m
diff --git a/script/core/guide.lua b/script/core/guide2.lua
index e4871060..576c0c20 100644
--- a/script/core/guide.lua
+++ b/script/core/guide2.lua
@@ -292,8 +292,13 @@ end
---@param obj parser.guide.object
---@return parser.guide.object
function m.getRoot(obj)
+ local source = obj
+ if source._root then
+ return source._root
+ end
for _ = 1, 1000 do
if obj.type == 'main' then
+ source._root = obj
return obj
end
local parent = obj.parent
diff --git a/script/core/highlight.lua b/script/core/highlight.lua
index 12ec114f..45001134 100644
--- a/script/core/highlight.lua
+++ b/script/core/highlight.lua
@@ -1,31 +1,18 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local files = require 'files'
local vm = require 'vm'
local define = require 'proto.define'
local findSource = require 'core.find-source'
local util = require 'utility'
+local guide = require 'parser.guide'
local function eachRef(source, callback)
- local results = guide.requestReference(source)
+ local results = searcher.requestReference(source)
for i = 1, #results do
callback(results[i])
end
end
-local function eachField(source, callback)
- if not source then
- return
- end
- local isGlobal = guide.isGlobal(source)
- local results = guide.requestReference(source)
- for i = 1, #results do
- local res = results[i]
- if isGlobal == guide.isGlobal(res) then
- callback(res)
- end
- end
-end
-
local function eachLocal(source, callback)
callback(source)
if source.ref then
@@ -43,21 +30,21 @@ local function find(source, uri, callback)
eachLocal(source.node, callback)
elseif source.type == 'field'
or source.type == 'method' then
- eachField(source.parent, callback)
+ eachRef(source.parent, callback)
elseif source.type == 'getindex'
or source.type == 'setindex'
or source.type == 'tableindex' then
- eachField(source, callback)
+ eachRef(source, callback)
elseif source.type == 'setglobal'
or source.type == 'getglobal' then
- eachField(source, callback)
+ eachRef(source, callback)
elseif source.type == 'goto'
or source.type == 'label' then
eachRef(source, callback)
elseif source.type == 'string'
and source.parent
and source.parent.index == source then
- eachField(source.parent, callback)
+ eachRef(source.parent, callback)
elseif source.type == 'string'
or source.type == 'boolean'
or source.type == 'number'
@@ -238,6 +225,16 @@ local accept = {
['nil'] = true,
}
+local function isLiteralValue(source)
+ if not guide.isLiteral(source) then
+ return false
+ end
+ if source.parent.index == source then
+ return false
+ end
+ return true
+end
+
return function (uri, offset)
local ast = files.getAst(uri)
if not ast then
@@ -249,10 +246,25 @@ return function (uri, offset)
local source = findSource(ast, offset, accept)
if source then
+ local isGlobal = guide.isGlobal(source)
+ local isLiteral = isLiteralValue(source)
find(source, uri, function (target)
+ if not target then
+ return
+ end
if target.dummy then
return
end
+ if mark[target] then
+ return
+ end
+ mark[target] = true
+ if isGlobal ~= guide.isGlobal(target) then
+ return
+ end
+ if isLiteral ~= isLiteralValue(target) then
+ return
+ end
local kind
if target.type == 'getfield' then
target = target.field
@@ -315,13 +327,6 @@ return function (uri, offset)
else
return
end
- if not target then
- return
- end
- if mark[target] then
- return
- end
- mark[target] = true
results[#results+1] = {
start = target.start,
finish = target.finish,
diff --git a/script/core/hint.lua b/script/core/hint.lua
index 13d01dc7..43b8726e 100644
--- a/script/core/hint.lua
+++ b/script/core/hint.lua
@@ -1,7 +1,7 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local vm = require 'vm'
-local config = require 'config'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local vm = require 'vm'
+local config = require 'config'
local function typeHint(uri, edits, start, finish)
local ast = files.getAst(uri)
@@ -9,7 +9,7 @@ local function typeHint(uri, edits, start, finish)
return
end
local mark = {}
- guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ searcher.eachSourceBetween(ast.ast, start, finish, function (source)
if source.type ~= 'local'
and source.type ~= 'setglobal'
and source.type ~= 'tablefield'
@@ -21,7 +21,7 @@ local function typeHint(uri, edits, start, finish)
if source[1] == '_' then
return
end
- if source.value and guide.isLiteral(source.value) then
+ if source.value and searcher.isLiteral(source.value) then
return
end
if source.parent.type == 'funcargs' then
@@ -84,7 +84,7 @@ local function hasLiteralArgInCall(call)
return false
end
for _, arg in ipairs(call.args) do
- if guide.isLiteral(arg) then
+ if searcher.isLiteral(arg) then
return true
end
end
@@ -100,14 +100,14 @@ local function paramName(uri, edits, start, finish)
return
end
local mark = {}
- guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ searcher.eachSourceBetween(ast.ast, start, finish, function (source)
if source.type ~= 'call' then
return
end
if not hasLiteralArgInCall(source) then
return
end
- local defs = vm.getDefs(source.node, 0)
+ local defs = vm.getDefs(source.node)
if not defs then
return
end
@@ -130,7 +130,7 @@ local function paramName(uri, edits, start, finish)
table.remove(args, 1)
end
for i, arg in ipairs(source.args) do
- if not mark[arg] and guide.isLiteral(arg) then
+ if not mark[arg] and searcher.isLiteral(arg) then
mark[arg] = true
if args[i] and args[i] ~= '' then
edits[#edits+1] = {
diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua
index 324d28af..822be2b6 100644
--- a/script/core/hover/arg.lua
+++ b/script/core/hover/arg.lua
@@ -1,4 +1,5 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
+local infer = require 'core.infer'
local vm = require 'vm'
local function optionalArg(arg)
@@ -21,7 +22,7 @@ local function asFunction(source, oop)
methodDef = true
end
if methodDef then
- args[#args+1] = ('self: %s'):format(vm.getInferType(parent.node))
+ args[#args+1] = ('self: %s'):format(infer.searchAndViewInfers(parent.node))
end
if source.args then
for i = 1, #source.args do
@@ -34,10 +35,12 @@ local function asFunction(source, oop)
args[#args+1] = ('%s%s: %s'):format(
name,
optionalArg(arg) and '?' or '',
- vm.getInferType(arg)
+ infer.searchAndViewInfers(arg)
)
+ elseif arg.type == '...' then
+ args[#args+1] = '...'
else
- args[#args+1] = ('%s'):format(vm.getInferType(arg))
+ args[#args+1] = ('%s'):format(infer.searchAndViewInfers(arg))
end
::CONTINUE::
end
@@ -61,7 +64,7 @@ local function asDocFunction(source)
args[i] = ('%s%s: %s'):format(
name,
arg.optional and '?' or '',
- vm.getInferType(arg.extends)
+ infer.searchAndViewInfers(arg.extends)
)
else
args[i] = ('%s%s'):format(
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index 401ca5a7..bcc3065a 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -2,11 +2,13 @@ local vm = require 'vm'
local ws = require 'workspace'
local furi = require 'file-uri'
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local markdown = require 'provider.markdown'
local config = require 'config'
local lang = require 'language'
local util = require 'utility'
+local guide = require 'parser.guide'
+local noder = require 'core.noder'
local function asStringInRequire(source, literal)
local rootPath = ws.path or ''
@@ -124,10 +126,10 @@ local function getBindComment(source, docGroup, base)
end
local function tryDocClassComment(source)
- for _, def in ipairs(vm.getDefs(source, 0)) do
+ for _, def in ipairs(vm.getDefs(source)) do
if def.type == 'doc.class.name'
or def.type == 'doc.alias.name' then
- local class = guide.getDocState(def)
+ local class = noder.getDocState(def)
local comment = getBindComment(class, class.bindGroup, class)
if comment then
return comment
@@ -180,7 +182,7 @@ local function isFunction(source)
if source.type == 'function' then
return true
end
- local value = guide.getObjectValue(source)
+ local value = searcher.getObjectValue(source)
if not value then
return false
end
@@ -223,13 +225,14 @@ local function getBindEnums(source, docGroup)
end
local function tryDocFieldUpComment(source)
- if source.type ~= 'doc.field' then
+ if source.type ~= 'doc.field.name' then
return
end
- if not source.bindGroup then
+ local docField = source.parent
+ if not docField.bindGroup then
return
end
- local comment = getBindComment(source, source.bindGroup, source)
+ local comment = getBindComment(docField, docField.bindGroup, docField)
return comment
end
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
index 0c8644ed..5dd00c43 100644
--- a/script/core/hover/init.lua
+++ b/script/core/hover/init.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local vm = require 'vm'
local getLabel = require 'core.hover.label'
local getDesc = require 'core.hover.description'
@@ -7,6 +7,7 @@ local util = require 'utility'
local findSource = require 'core.find-source'
local lang = require 'language'
local markdown = require 'provider.markdown'
+local infer = require 'core.infer'
local function eachFunctionAndOverload(value, callback)
callback(value)
@@ -24,7 +25,7 @@ local function getHoverAsValue(source)
local label = getLabel(source)
local desc = getDesc(source)
if not desc then
- local values = vm.getDefs(source, 0)
+ local values = vm.getDefs(source)
for _, def in ipairs(values) do
desc = getDesc(def)
if desc then
@@ -40,7 +41,7 @@ local function getHoverAsValue(source)
end
local function getHoverAsFunction(source)
- local values = vm.getDefs(source, 0)
+ local values = vm.getDefs(source)
local desc = getDesc(source)
local labels = {}
local defs = 0
@@ -48,7 +49,7 @@ local function getHoverAsFunction(source)
local other = 0
local mark = {}
for _, def in ipairs(values) do
- def = guide.getObjectValue(def) or def
+ def = searcher.getObjectValue(def) or def
if def.type == 'function'
or def.type == 'doc.type.function' then
eachFunctionAndOverload(def, function (value)
@@ -123,7 +124,7 @@ local function getHover(source)
if source.type == 'doc.type.name' then
return getHoverAsDocName(source)
end
- local isFunction = vm.hasInferType(source, 'function', 0)
+ local isFunction = infer.hasType(source, 'function')
if isFunction then
return getHoverAsFunction(source)
else
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
index d93b14e3..032f19c0 100644
--- a/script/core/hover/label.lua
+++ b/script/core/hover/label.lua
@@ -2,9 +2,10 @@ local buildName = require 'core.hover.name'
local buildArg = require 'core.hover.arg'
local buildReturn = require 'core.hover.return'
local buildTable = require 'core.hover.table'
+local infer = require 'core.infer'
local vm = require 'vm'
local util = require 'utility'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local config = require 'config'
local files = require 'files'
@@ -31,29 +32,28 @@ local function asDocFunction(source)
end
local function asDocTypeName(source)
- for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ local defs = searcher.requestDefinition(source)
+ for _, doc in ipairs(defs) do
if doc.type == 'doc.class.name' then
- return 'class ' .. source[1]
+ return 'class ' .. doc[1]
end
if doc.type == 'doc.alias.name' then
local extends = doc.parent.extends
- return lang.script('HOVER_EXTENDS', vm.getInferType(extends))
+ return lang.script('HOVER_EXTENDS', infer.searchAndViewInfers(extends))
end
end
end
local function asValue(source, title)
local name = buildName(source)
- local infers = vm.getInfers(source, 0)
- local type = vm.getInferType(source, 0)
- local class = vm.getClass(source, 0)
- local literal = vm.getInferLiteral(source, 0)
+ local type = infer.searchAndViewInfers(source)
+ local literal = infer.searchAndViewLiterals(source)
local cont
- if not vm.hasInferType(source, 'string', 0)
+ if not infer.hasType(source, 'string')
and not type:find('%[%]$')
and not type:find('%w%<') then
- if #vm.getFields(source, 0) > 0
- or vm.hasInferType(source, 'table', 0) then
+ if #vm.getRefs(source, '*') > 0
+ or infer.hasType(source, 'table') then
cont = buildTable(source)
end
end
@@ -66,11 +66,7 @@ local function asValue(source, title)
or type == 'nil') then
type = nil
end
- if class then
- pack[#pack+1] = class
- else
- pack[#pack+1] = type
- end
+ pack[#pack+1] = type
if literal then
pack[#pack+1] = '='
pack[#pack+1] = literal
@@ -123,30 +119,21 @@ local function asField(source)
return asValue(source, 'field')
end
-local function asDocField(source)
- local name = source.field[1]
+local function asDocFieldName(source)
+ local name = source[1]
+ local docField = source.parent
local class
- for _, doc in ipairs(source.bindGroup) do
+ for _, doc in ipairs(docField.bindGroup) do
if doc.type == 'doc.class' then
class = doc
break
end
end
- local infers = {}
- for _, infer in ipairs(vm.getInfers(source.extends) or {}) do
- infers[#infers+1] = infer
- end
+ local view = infer.searchAndViewInfers(docField.extends)
if not class then
- return ('field ?.%s: %s'):format(
- name,
- guide.viewInferType(infers)
- )
- end
- return ('field %s.%s: %s'):format(
- class.class[1],
- name,
- guide.viewInferType(infers)
- )
+ return ('field ?.%s: %s'):format(name, view)
+ end
+ return ('field %s.%s: %s'):format(class.class[1], name, view)
end
local function asString(source)
@@ -177,7 +164,7 @@ local function asNumber(source)
if type(num) ~= 'number' then
return nil
end
- local uri = guide.getUri(source)
+ local uri = searcher.getUri(source)
local text = files.getText(uri)
if not text then
return nil
@@ -215,7 +202,7 @@ return function (source, oop)
return asDocFunction(source)
elseif source.type == 'doc.type.name' then
return asDocTypeName(source)
- elseif source.type == 'doc.field' then
- return asDocField(source)
+ elseif source.type == 'doc.field.name' then
+ return asDocFieldName(source)
end
end
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
index d583f1e1..d2b9d30b 100644
--- a/script/core/hover/name.lua
+++ b/script/core/hover/name.lua
@@ -1,4 +1,6 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
+local infer = require 'core.infer'
+local guide = require 'parser.guide'
local vm = require 'vm'
local buildName
@@ -19,7 +21,7 @@ end
local function asField(source, oop)
local class
if source.node.type ~= 'getglobal' then
- class = vm.getClass(source.node, 0)
+ class = infer.getClass(source.node)
end
local node = class or guide.getKeyName(source.node) or '?'
local method = guide.getKeyName(source)
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
index c3e9656d..0f0d85e0 100644
--- a/script/core/hover/return.lua
+++ b/script/core/hover/return.lua
@@ -1,12 +1,4 @@
-local guide = require 'core.guide'
-local vm = require 'vm'
-
-local function mergeTypes(returns)
- if type(returns) == 'string' then
- return returns
- end
- return guide.mergeTypes(returns)
-end
+local infer = require 'core.infer'
local function getReturnDualByDoc(source)
local docs = source.bindDocs
@@ -55,24 +47,20 @@ local function asFunction(source)
local returns = {}
for i, rtn in ipairs(dual) do
local line = {}
- local types = {}
+ local infers = {}
if i == 1 then
line[#line+1] = ' -> '
else
line[#line+1] = ('% 3d. '):format(i)
end
for n = 1, #rtn do
- local values = vm.getInfers(rtn[n])
- for _, value in ipairs(values) do
- if value.type then
- for tp in value.type:gmatch '[^|]+' do
- types[tp] = true
- end
- end
+ local values = infer.searchInfers(rtn[n])
+ for tp in pairs(values) do
+ infers[tp] = true
end
end
- if next(types) or rtn[1] then
- local tp = mergeTypes(types) or 'any'
+ if next(infers) or rtn[1] then
+ local tp = infer.viewInfers(infers)
if rtn[1].name then
line[#line+1] = ('%s%s: %s'):format(
rtn[1].name[1],
@@ -103,7 +91,7 @@ local function asDocFunction(source)
local returns = {}
for i, rtn in ipairs(source.returns) do
local rtnText = ('%s%s'):format(
- vm.getInferType(rtn),
+ infer.searchAndViewInfers(rtn),
rtn.optional and '?' or ''
)
if i == 1 then
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
index edb7751b..159453e6 100644
--- a/script/core/hover/table.lua
+++ b/script/core/hover/table.lua
@@ -1,26 +1,12 @@
local vm = require 'vm'
local util = require 'utility'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local config = require 'config'
local lang = require 'language'
+local infer = require 'core.infer'
-local function getKey(src)
- local key = vm.getKeyName(src)
- if not key or #key <= 0 then
- if not src.index then
- return '[any]'
- end
- local class = vm.getClass(src.index)
- if class then
- return ('[%s]'):format(class)
- end
- local tp = vm.getInferType(src.index)
- if tp then
- return ('[%s]'):format(tp)
- end
- return '[any]'
- end
- if guide.getKeyType(src) == 'string' then
+local function formatKey(key)
+ if type(key) == 'string' then
if key:match '^[%a_][%w_]*$' then
return key
else
@@ -30,104 +16,16 @@ local function getKey(src)
return ('[%s]'):format(key)
end
-local function getFieldFull(src)
- local value = guide.getObjectValue(src) or src
- local tp = vm.getInferType(value, 0)
- --local class = vm.getClass(src)
- local literal = vm.getInferLiteral(value)
- if type(literal) == 'string' and #literal >= 50 then
- literal = literal:sub(1, 47) .. '...'
- end
- return tp, literal
-end
-
-local function getFieldFast(src)
- if src.bindDocs then
- return getFieldFull(src)
- end
- local value = guide.getObjectValue(src) or src
- if not value then
- return 'any'
- end
- if value.type == 'boolean' then
- return value.type, util.viewLiteral(value[1])
- end
- if value.type == 'number'
- or value.type == 'integer' then
- if math.tointeger(value[1]) then
- if config.config.runtime.version == 'Lua 5.3'
- or config.config.runtime.version == 'Lua 5.4' then
- return 'integer', util.viewLiteral(value[1])
- end
- end
- return value.type, util.viewLiteral(value[1])
- end
- if value.type == 'table'
- or value.type == 'function' then
- return value.type
- end
- if value.type == 'string' then
- local literal = value[1]
- if type(literal) == 'string' and #literal >= 50 then
- literal = literal:sub(1, 47) .. '...'
- end
- return value.type, util.viewLiteral(literal)
- end
- if value.type == 'doc.field' then
- return vm.getInferType(value)
- end
-end
-
-local function getField(src, timeUp, mark, key)
- if src.type == 'table'
- or src.type == 'function' then
- return nil
- end
- if src.parent then
- if src.type == 'string'
- or src.type == 'boolean'
- or src.type == 'number'
- or src.type == 'integer' then
- if src.parent.type == 'tableindex'
- or src.parent.type == 'setindex'
- or src.parent.type == 'getindex' then
- if src.parent.index == src then
- src = src.parent
- end
- end
- end
- end
- local tp, literal
- tp, literal = getFieldFast(src)
- if tp then
- return tp, literal
- end
- if timeUp or mark[key] then
- return nil
- end
- mark[key] = true
- tp, literal = getFieldFull(src)
- if tp then
- return tp, literal
- end
- return nil
-end
-
-local function buildAsHash(classes, literals, reachMax)
- local keys = {}
- for k in pairs(classes) do
- keys[#keys+1] = k
- end
- table.sort(keys)
+local function buildAsHash(keys, inferMap, literalMap, reachMax)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local class = classes[key]
- local literal = literals[key]
- if literal then
- lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ local inferView = inferMap[key]
+ local literalView = literalMap[key]
+ if literalView then
+ lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView)
else
- lines[#lines+1] = (' %s: %s,'):format(key, class)
+ lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView)
end
end
if reachMax then
@@ -137,23 +35,19 @@ local function buildAsHash(classes, literals, reachMax)
return table.concat(lines, '\n')
end
-local function buildAsConst(classes, literals, reachMax)
- local keys = {}
- for k in pairs(classes) do
- keys[#keys+1] = k
- end
+local function buildAsConst(keys, inferMap, literalMap, reachMax)
table.sort(keys, function (a, b)
- return tonumber(literals[a]) < tonumber(literals[b])
+ return tonumber(literalMap[a]) < tonumber(literalMap[b])
end)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local class = classes[key]
- local literal = literals[key]
- if literal then
- lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ local inferView = inferMap[key]
+ local literalView = literalMap[key]
+ if literalView then
+ lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView)
else
- lines[#lines+1] = (' %s: %s,'):format(key, class)
+ lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView)
end
end
if reachMax then
@@ -163,111 +57,79 @@ local function buildAsConst(classes, literals, reachMax)
return table.concat(lines, '\n')
end
-local function mergeLiteral(literals)
- local results = {}
+local typeSorter = {
+ ['string'] = 1,
+ ['number'] = 2,
+ ['boolean'] = 3,
+}
+
+local function getKeyMap(fields)
+ local keys = {}
local mark = {}
- for _, value in ipairs(literals) do
- if not mark[value] then
- mark[value] = true
- results[#results+1] = value
+ for _, field in ipairs(fields) do
+ local key = vm.getKeyName(field)
+ local tp = vm.getKeyType(field)
+ if tp == 'number' then
+ key = tonumber(key)
+ elseif tp == 'boolean' then
+ key = key == 'true'
end
- end
- if #results == 0 then
- return nil
- end
- table.sort(results)
- return table.concat(results, '|')
-end
-
-local function mergeTypes(types)
- local results = {}
- local mark = {
- -- 讲道理table的keyvalue不会是nil
- ['nil'] = true,
- }
- for _, tv in ipairs(types) do
- for tp in tv:gmatch '[^|]+' do
- if not mark[tp] then
- mark[tp] = true
- results[tp] = true
- end
+ if key and not mark[key] then
+ mark[key] = true
+ keys[#keys+1] = key
end
end
- return guide.mergeTypes(results)
-end
-
-local function clearClasses(classes)
- classes['[nil]'] = nil
- classes['[any]'] = nil
- classes['[string]'] = nil
+ table.sort(keys, function (a, b)
+ local ta = typeSorter[type(a)]
+ local tb = typeSorter[type(b)]
+ if ta == tb then
+ return tostring(a) < tostring(b)
+ else
+ return ta < tb
+ end
+ end)
+ return keys
end
return function (source)
- if config.config.hover.previewFields <= 0 then
+ local maxFields = config.config.hover.previewFields
+ if maxFields <= 0 then
return 'table'
end
- local literals = {}
- local classes = {}
- local clock = os.clock()
- local timeUp
- local mark = {}
- local fields = vm.getFields(source, 0)
- local keyCount = 0
- local reachMax
- for _, src in ipairs(fields) do
- local key = getKey(src)
- if not key then
- goto CONTINUE
- end
- if not classes[key] then
- classes[key] = {}
- keyCount = keyCount + 1
- end
- if not literals[key] then
- literals[key] = {}
- end
- if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then
- timeUp = true
- end
- local class, literal = getField(src, timeUp, mark, key)
- if literal == 'nil' then
- literal = nil
- end
- classes[key][#classes[key]+1] = class
- literals[key][#literals[key]+1] = literal
- if keyCount >= config.config.hover.previewFields then
- reachMax = true
- break
- end
- ::CONTINUE::
- end
-
- clearClasses(classes)
- for key, class in pairs(classes) do
- literals[key] = mergeLiteral(literals[key])
- classes[key] = mergeTypes(class)
- end
+ local fields = vm.getRefs(source, '*')
+ local keys = getKeyMap(fields)
- if not next(classes) then
+ if #keys == 0 then
return '{}'
end
- local intValue = true
- for key, class in pairs(classes) do
- if class ~= 'integer' or not tonumber(literals[key]) then
- intValue = false
- break
+ local inferMap = {}
+ local literalMap = {}
+
+ local reachMax = maxFields < #keys
+
+ local isConsts = true
+ for i = 1, math.min(maxFields, #keys) do
+ local key = keys[i]
+ inferMap[key] = infer.searchAndViewInfers(source, key)
+ literalMap[key] = infer.searchAndViewLiterals(source, key)
+ if not tonumber(literalMap[key]) then
+ isConsts = false
end
end
+
local result
- if intValue then
- result = buildAsConst(classes, literals, reachMax)
+
+ if isConsts then
+ result = buildAsConst(keys, inferMap, literalMap, reachMax)
else
- result = buildAsHash(classes, literals, reachMax)
- end
- if timeUp then
- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ result = buildAsHash(keys, inferMap, literalMap, reachMax)
end
+
+ --if timeUp then
+ -- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ --end
+
return result
end
diff --git a/script/core/infer.lua b/script/core/infer.lua
new file mode 100644
index 00000000..77236811
--- /dev/null
+++ b/script/core/infer.lua
@@ -0,0 +1,634 @@
+local searcher = require 'core.searcher'
+local config = require 'config'
+local noder = require 'core.noder'
+local util = require 'utility'
+
+local STRING_OR_TABLE = {'STRING_OR_TABLE'}
+local BE_RETURN = {'BE_RETURN'}
+local BE_CONNACT = {'BE_CONNACT'}
+local CLASS = {'CLASS'}
+local TABLE = {'TABLE'}
+
+local TypeSort = {
+ ['boolean'] = 1,
+ ['string'] = 2,
+ ['integer'] = 3,
+ ['number'] = 4,
+ ['table'] = 5,
+ ['function'] = 6,
+ ['true'] = 101,
+ ['false'] = 102,
+ ['nil'] = 999,
+}
+
+local m = {}
+
+local function mergeTable(a, b)
+ if not b then
+ return
+ end
+ for v in pairs(b) do
+ a[v] = true
+ end
+end
+
+local function searchInferOfUnary(value, infers)
+ local op = value.op.type
+ if op == 'not' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '#' then
+ infers['integer'] = true
+ return
+ end
+ if op == '-' then
+ if m.hasType(value[1], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+end
+
+local function searchInferOfBinary(value, infers)
+ local op = value.op.type
+ if op == 'and' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[2]))
+ else
+ mergeTable(infers, m.searchInfers(value[1]))
+ end
+ return
+ end
+ if op == 'or' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[1]))
+ else
+ mergeTable(infers, m.searchInfers(value[2]))
+ end
+ return
+ end
+ if op == '=='
+ or op == '~='
+ or op == '<'
+ or op == '>'
+ or op == '<='
+ or op == '>=' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '&'
+ or op == '|' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers['string'] = true
+ return
+ end
+ if op == '^'
+ or op == '/' then
+ infers['number'] = true
+ return
+ end
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '%'
+ or op == '//' then
+ if m.hasType(value[1], 'integer')
+ and m.hasType(value[2], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+end
+
+local function searchInferOfValue(value, infers)
+ if value.type == 'string' then
+ infers['string'] = true
+ return true
+ end
+ if value.type == 'boolean' then
+ infers['boolean'] = true
+ return true
+ end
+ if value.type == 'table' then
+ if value.array then
+ local node = m.searchAndViewInfers(value.array)
+ local infer = node .. '[]'
+ infers[infer] = true
+ else
+ infers['table'] = true
+ end
+ return true
+ end
+ if value.type == 'number' then
+ if math.type(value[1]) == 'integer' then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return true
+ end
+ if value.type == 'nil' then
+ infers['nil'] = true
+ return true
+ end
+ if value.type == 'function' then
+ infers['function'] = true
+ return true
+ end
+ if value.type == 'unary' then
+ searchInferOfUnary(value, infers)
+ return true
+ end
+ if value.type == 'binary' then
+ searchInferOfBinary(value, infers)
+ return true
+ end
+ return false
+end
+
+local function searchLiteralOfValue(value, literals)
+ if value.type == 'string'
+ or value.type == 'boolean'
+ or value.type == 'number'
+ or value.type == 'integer' then
+ local v = value[1]
+ if v ~= nil then
+ literals[v] = true
+ end
+ return
+ end
+ if value.type == 'unary' then
+ local op = value.op.type
+ if op == '-' then
+ local subLiterals = m.searchLiterals(value[1])
+ if subLiterals then
+ for subLiteral in pairs(subLiterals) do
+ local num = tonumber(subLiteral)
+ if num then
+ literals[-num] = true
+ end
+ end
+ end
+ end
+ if op == '~' then
+ local subLiterals = m.searchLiterals(value[1])
+ if subLiterals then
+ for subLiteral in pairs(subLiterals) do
+ local num = math.tointeger(subLiteral)
+ if num then
+ literals[~num] = true
+ end
+ end
+ end
+ end
+ end
+ return
+end
+
+local function bindClassOrType(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.type' then
+ return true
+ end
+ end
+ return false
+end
+
+local function cleanInfers(infers)
+ local version = config.config.runtime.version
+ local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4'
+ infers['unknown'] = nil
+ if infers['any'] and infers['nil'] then
+ infers['nil'] = nil
+ end
+ if infers['number'] then
+ enableInteger = false
+ end
+ if not enableInteger and infers['integer'] then
+ infers['integer'] = nil
+ infers['number'] = true
+ end
+ -- stringlib 就是 string
+ if infers['stringlib'] and infers['string'] then
+ infers['stringlib'] = nil
+ end
+ -- 如果是通过 .. 来推测的,且结果里没有 number 与 integer,则推测为string
+ if infers[BE_CONNACT] then
+ infers[BE_CONNACT] = nil
+ if not infers['number'] and not infers['integer'] then
+ infers['string'] = true
+ end
+ end
+ -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型
+ if infers[STRING_OR_TABLE] then
+ infers[STRING_OR_TABLE] = nil
+ if not infers['table'] and not infers['string'] then
+ infers['table'] = true
+ infers['string'] = true
+ end
+ end
+ -- 如果有doc标记,则先移除table类型
+ if infers[CLASS] then
+ infers[CLASS] = nil
+ infers['table'] = nil
+ end
+ -- 用doc标记的table,加入table类型
+ if infers[TABLE] then
+ infers[TABLE] = nil
+ infers['table'] = true
+ end
+ if infers[BE_RETURN] then
+ infers[BE_RETURN] = nil
+ infers['nil'] = nil
+ end
+ infers['any'] = nil
+end
+
+---合并对象的推断类型
+---@param infers string[]
+---@return string
+function m.viewInfers(infers)
+ if infers[0] then
+ return infers[0]
+ end
+ -- 如果有显性的 any ,则直接显示为 any
+ if infers['any'] then
+ infers[0] = 'any'
+ return 'any'
+ end
+ local result = {}
+ local count = 0
+ for infer in pairs(infers) do
+ count = count + 1
+ result[count] = infer
+ end
+ -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any
+ if count == 0 then
+ infers[0] = 'any'
+ return 'any'
+ end
+ table.sort(result, function (a, b)
+ local sa = TypeSort[a] or 100
+ local sb = TypeSort[b] or 100
+ if sa == sb then
+ return a < b
+ else
+ return sa < sb
+ end
+ end)
+ infers[0] = table.concat(result, '|')
+ return infers[0]
+end
+
+---合并对象的值
+---@param literals string[]
+---@return string
+function m.viewLiterals(literals)
+ local result = {}
+ local count = 0
+ for infer in pairs(literals) do
+ count = count + 1
+ result[count] = util.viewLiteral(infer)
+ end
+ if count == 0 then
+ return nil
+ end
+ table.sort(result)
+ local view = table.concat(result, '|')
+ return view
+end
+
+function m.viewDocName(doc)
+ if not doc then
+ return nil
+ end
+ if doc.type == 'doc.type' then
+ local list = {}
+ for _, tp in ipairs(doc.types) do
+ list[#list+1] = m.getDocName(tp)
+ end
+ for _, enum in ipairs(doc.enums) do
+ list[#list+1] = m.getDocName(enum)
+ end
+ return table.concat(list, '|')
+ end
+ return m.getDocName(doc)
+end
+
+function m.getDocName(doc)
+ if not doc then
+ return nil
+ end
+ if doc.type == 'doc.class.name'
+ or doc.type == 'doc.type.name' then
+ local name = doc[1] or '?'
+ if doc.typeGeneric then
+ return '<' .. name .. '>'
+ else
+ return name
+ end
+ end
+ if doc.type == 'doc.type.array' then
+ local nodeName = m.viewDocName(doc.node) or '?'
+ return nodeName .. '[]'
+ end
+ if doc.type == 'doc.type.table' then
+ local key = m.viewDocName(doc.tkey) or '?'
+ local value = m.viewDocName(doc.tvalue) or '?'
+ return ('table<%s, %s>'):format(key, value)
+ end
+ if doc.type == 'doc.type.function' then
+ return 'function'
+ end
+ if doc.type == 'doc.type.enum'
+ or doc.type == 'doc.resume' then
+ local value = doc[1] or '?'
+ return value
+ end
+end
+
+function m.viewDocFunction(doc)
+ if doc.type ~= 'doc.type.function' then
+ return ''
+ end
+ local args = {}
+ for i, arg in ipairs(doc.args) do
+ args[i] = ('%s: %s'):format(arg.name[1], m.viewDocName(arg.extends))
+ end
+ local label = ('fun(%s)'):format(table.concat(args, ', '))
+ if #doc.returns > 0 then
+ local returns = {}
+ for i, rtn in ipairs(doc.returns) do
+ returns[i] = m.viewDocName(rtn)
+ end
+ label = ('%s:%s'):format(label, table.concat(returns))
+ end
+ return label
+end
+
+---显示对象的推断类型
+---@param source parser.guide.object
+---@return string
+local function searchInfer(source, infers)
+ if bindClassOrType(source) then
+ return
+ end
+ if searchInferOfValue(source, infers) then
+ return
+ end
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchInferOfValue(value, infers)
+ return
+ end
+ -- check LuaDoc
+ local docName = m.getDocName(source)
+ if docName then
+ infers[docName] = true
+ if docName ~= 'unknown' then
+ infers[CLASS] = true
+ end
+ if docName == 'table' then
+ infers[TABLE] = true
+ end
+ end
+ if source.parent.type == 'unary' then
+ local op = source.parent.op.type
+ -- # XX -> string | table
+ if op == '#' then
+ infers[STRING_OR_TABLE] = true
+ return
+ end
+ if op == '-' then
+ infers['number'] = true
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+ return
+ end
+ if source.parent.type == 'binary' then
+ local op = source.parent.op.type
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '/'
+ or op == '//'
+ or op == '^'
+ or op == '%' then
+ infers['number'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '|'
+ or op == '&' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers[BE_CONNACT] = true
+ return
+ end
+ end
+ -- X.a -> table
+ if source.next and source.next.node == source then
+ if source.next.type == 'setfield'
+ or source.next.type == 'setindex'
+ or source.next.type == 'setmethod'
+ or source.next.type == 'getfield'
+ or source.next.type == 'getindex' then
+ infers['table'] = true
+ end
+ if source.next.type == 'getmethod' then
+ infers[STRING_OR_TABLE] = true
+ end
+ end
+ -- return XX
+ if source.parent.type == 'return' then
+ infers[BE_RETURN] = true
+ end
+end
+
+local function searchLiteral(source, literals)
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchLiteralOfValue(value, literals)
+ return
+ end
+end
+
+---搜索对象的推断类型
+---@param source parser.guide.object
+---@param field? string
+---@return string[]
+function m.searchInfers(source, field)
+ if not source then
+ return nil
+ end
+ local defs = searcher.requestDefinition(source, field)
+ local infers = {}
+ local mark = {}
+ if not field then
+ mark[source] = true
+ searchInfer(source, infers)
+ local id = noder.getID(source)
+ if id then
+ local node = noder.getNodeByID(source, id)
+ if node and node.sources then
+ for _, src in ipairs(node.sources) do
+ if not mark[src] then
+ mark[src] = true
+ searchInfer(src, infers)
+ end
+ end
+ end
+ end
+ end
+ if source.type == 'field' or source.type == 'method' then
+ mark[source.parent] = true
+ searchInfer(source.parent, infers)
+ end
+ for _, def in ipairs(defs) do
+ if not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers)
+ end
+ end
+ if source.docParam then
+ local docType = source.docParam.extends
+ if docType.type == 'doc.type' then
+ for _, def in ipairs(docType.types) do
+ if def.typeGeneric and not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers)
+ end
+ end
+ end
+ end
+ if source.type == 'doc.type' then
+ if source.type == 'doc.type' then
+ for _, def in ipairs(source.types) do
+ if def.typeGeneric and not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers)
+ end
+ end
+ end
+ end
+ cleanInfers(infers)
+ return infers
+end
+
+---搜索对象的字面量值
+---@param source parser.guide.object
+---@param field? string
+---@return table
+function m.searchLiterals(source, field)
+ local defs = searcher.requestDefinition(source, field)
+ local literals = {}
+ local mark = {}
+ if not field then
+ mark[source] = true
+ searchLiteral(source, literals)
+ end
+ for _, def in ipairs(defs) do
+ if not mark[def] then
+ mark[def] = true
+ searchLiteral(def, literals)
+ end
+ end
+ return literals
+end
+
+---搜索并显示推断值
+---@param source parser.guide.object
+---@param field? string
+---@return string
+function m.searchAndViewLiterals(source, field)
+ if not source then
+ return nil
+ end
+ local literals = m.searchLiterals(source, field)
+ local view = m.viewLiterals(literals)
+ return view
+end
+
+---判断对象的推断值是否是 true
+---@param source parser.guide.object
+function m.isTrue(source)
+ if not source then
+ return false
+ end
+ local literals = m.searchLiterals(source)
+ for literal in pairs(literals) do
+ if literal ~= false then
+ return true
+ end
+ end
+ return false
+end
+
+---判断对象的推断类型是否包含某个类型
+function m.hasType(source, tp)
+ local infers = m.searchInfers(source)
+ return infers[tp] or false
+end
+
+---搜索并显示推断类型
+---@param source parser.guide.object
+---@param field? string
+---@return string
+function m.searchAndViewInfers(source, field)
+ if not source then
+ return 'any'
+ end
+ local infers = m.searchInfers(source, field)
+ local view = m.viewInfers(infers)
+ return view
+end
+
+---搜索并显示推断的class
+---@param source parser.guide.object
+---@return string?
+function m.getClass(source)
+ if not source then
+ return nil
+ end
+ local infers = {}
+ local defs = searcher.requestDefinition(source)
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class.name' then
+ infers[def[1]] = true
+ end
+ end
+ local view = m.viewInfers(infers)
+ if view == 'any' then
+ return nil
+ end
+ return view
+end
+
+return m
diff --git a/script/core/keyword.lua b/script/core/keyword.lua
index 71ea4969..73892f18 100644
--- a/script/core/keyword.lua
+++ b/script/core/keyword.lua
@@ -1,6 +1,6 @@
local define = require 'proto.define'
-local guide = require 'core.guide'
local files = require 'files'
+local guide = require 'parser.guide'
local keyWordMap = {
{'do', function (info, results)
diff --git a/script/core/noder.lua b/script/core/noder.lua
new file mode 100644
index 00000000..c3679612
--- /dev/null
+++ b/script/core/noder.lua
@@ -0,0 +1,926 @@
+local util = require 'utility'
+local guide = require 'parser.guide'
+
+local LastIDCache = {}
+local FirstIDCache = {}
+local SPLIT_CHAR = '\x1F'
+local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$'
+local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*'
+local ANY_FIELD_CHAR = '*'
+local RETURN_INDEX = SPLIT_CHAR .. '#'
+local PARAM_INDEX = SPLIT_CHAR .. '&'
+local TABLE_KEY = SPLIT_CHAR .. '<'
+local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR
+local URI_CHAR = '@'
+local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(.*)'
+
+---@class node
+-- 当前节点的id
+---@field id string
+-- 使用该ID的单元
+---@field sources parser.guide.object[]
+-- 前进的关联ID
+---@field forward string[]
+-- 后退的关联ID
+---@field backward string[]
+-- 函数调用参数信息(用于泛型)
+---@field call parser.guide.object
+
+---@alias noders table<string, node[]>
+
+---创建source的链接信息
+---@param noders noders
+---@param id string
+---@return node
+local function getNode(noders, id)
+ if not noders[id] then
+ noders[id] = {
+ id = id,
+ }
+ end
+ return noders[id]
+end
+
+---获取语法树单元的key
+---@param source parser.guide.object
+---@return string? key
+---@return parser.guide.object? node
+local function getKey(source)
+ if source.type == 'local' then
+ return tostring(source.start), nil
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ return tostring(source.node.start), nil
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ local node = source.node
+ if node.tag == '_ENV' then
+ return ('%q'):format(source[1] or ''), nil
+ else
+ return ('%q'):format(source[1] or ''), node
+ end
+ elseif source.type == 'getfield'
+ or source.type == 'setfield' then
+ return ('%q'):format(source.field and source.field[1] or ''), source.node
+ elseif source.type == 'tablefield' then
+ return ('%q'):format(source.field and source.field[1] or ''), source.parent
+ elseif source.type == 'getmethod'
+ or source.type == 'setmethod' then
+ return ('%q'):format(source.method and source.method[1] or ''), source.node
+ elseif source.type == 'setindex'
+ or source.type == 'getindex' then
+ local index = source.index
+ if not index then
+ return ANY_FIELD_CHAR, source.node
+ end
+ if index.type == 'string'
+ or index.type == 'boolean'
+ or index.type == 'number' then
+ return ('%q'):format(index[1] or ''), source.node
+ elseif index.type ~= 'function'
+ and index.type ~= 'table' then
+ return ANY_FIELD_CHAR, source.node
+ end
+ elseif source.type == 'tableindex' then
+ local index = source.index
+ if not index then
+ return ANY_FIELD_CHAR, source.parent
+ end
+ if index.type == 'string'
+ or index.type == 'boolean'
+ or index.type == 'number' then
+ return ('%q'):format(index[1] or ''), source.parent
+ elseif index.type ~= 'function'
+ and index.type ~= 'table' then
+ return ANY_FIELD_CHAR, source.parent
+ end
+ elseif source.type == 'table' then
+ return source.start, nil
+ elseif source.type == 'label' then
+ return source.start, nil
+ elseif source.type == 'goto' then
+ if source.node then
+ return source.node.start, nil
+ end
+ return nil, nil
+ elseif source.type == 'function' then
+ return source.start, nil
+ elseif source.type == 'string' then
+ return '', nil
+ elseif source.type == 'integer'
+ or source.type == 'number'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return source.start, nil
+ elseif source.type == '...' then
+ return source.start, nil
+ elseif source.type == 'varargs' then
+ if source.node then
+ return source.node.start, nil
+ end
+ elseif source.type == 'select' then
+ return ('%d%s%d'):format(source.start, RETURN_INDEX, source.sindex)
+ elseif source.type == 'call' then
+ local node = source.node
+ if node.special == 'rawget'
+ or node.special == 'rawset' then
+ if not source.args then
+ return nil, nil
+ end
+ local tbl, key = source.args[1], source.args[2]
+ if not tbl or not key then
+ return nil, nil
+ end
+ if key.type == 'string' then
+ return ('%q'):format(key[1] or ''), tbl
+ else
+ return '', tbl
+ end
+ end
+ return source.finish, nil
+ elseif source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name'
+ or source.type == 'doc.see.name' then
+ local name = source[1]
+ return name, nil
+ elseif source.type == 'doc.type.name' then
+ local name = source[1]
+ if source.typeGeneric then
+ return source.typeGeneric[name][1].start, nil
+ else
+ return name, nil
+ end
+ elseif source.type == 'doc.class'
+ or source.type == 'doc.type'
+ or source.type == 'doc.param'
+ or source.type == 'doc.vararg'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.function' then
+ return source.start, nil
+ elseif source.type == 'doc.see.field' then
+ return ('%q'):format(source[1]), source.parent.name
+ elseif source.type == 'generic.closure' then
+ return source.call.start, nil
+ elseif source.type == 'generic.value' then
+ return ('%s|%s'):format(
+ source.closure.call.start,
+ getKey(source.proto)
+ )
+ end
+ return nil, nil
+end
+
+local function checkMode(source)
+ if source.type == 'table' then
+ return 't:'
+ end
+ if source.type == 'select' then
+ return 's:'
+ end
+ if source.type == 'function' then
+ return 'f:'
+ end
+ if source.type == 'string' then
+ return 'str:'
+ end
+ if source.type == 'number'
+ or source.type == 'integer'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return 'i:'
+ end
+ if source.type == 'call' then
+ return 'c:'
+ end
+ if source.type == '...'
+ or source.type == 'varargs' then
+ return 'va:'
+ end
+ if source.type == 'doc.class.name'
+ or source.type == 'doc.type.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name' then
+ if source.typeGeneric then
+ return 'dg:'
+ end
+ return 'dn:'
+ end
+ if source.type == 'doc.field.name' then
+ return 'dfn:'
+ end
+ if source.type == 'doc.see.name' then
+ return 'dsn:'
+ end
+ if source.type == 'doc.class' then
+ return 'dc:'
+ end
+ if source.type == 'doc.type' then
+ return 'dt:'
+ end
+ if source.type == 'doc.param' then
+ return 'dp:'
+ end
+ if source.type == 'doc.type.function' then
+ return 'dfun:'
+ end
+ if source.type == 'doc.type.table' then
+ return 'dtable:'
+ end
+ if source.type == 'doc.type.array' then
+ return 'darray:'
+ end
+ if source.type == 'doc.vararg' then
+ return 'dv:'
+ end
+ if source.type == 'doc.type.enum'
+ or source.type == 'doc.resume' then
+ return 'de:'
+ end
+ if source.type == 'generic.closure' then
+ return 'gc:'
+ end
+ if source.type == 'generic.value' then
+ local id = 'gv:'
+ if guide.getUri(source.closure.call) ~= guide.getUri(source.proto) then
+ id = id .. URI_CHAR .. guide.getUri(source.closure.call)
+ end
+ return id
+ end
+ if guide.isGlobal(source) then
+ return 'g:'
+ end
+ if source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ source = source.node
+ end
+ if source.parent.type == 'funcargs' then
+ return 'p:'
+ end
+ return 'l:'
+end
+
+local IDList = {}
+---获取语法树单元的字符串ID
+---@param source parser.guide.object
+---@return string? id
+local function getID(source)
+ if not source then
+ return nil
+ end
+ if source._id ~= nil then
+ return source._id or nil
+ end
+ if source.type == 'field'
+ or source.type == 'method' then
+ source._id = false
+ return nil
+ end
+ local current = source
+ local index = 0
+ while true do
+ if current.type == 'paren' then
+ current = current.exp
+ goto CONTINUE
+ end
+ local id, node = getKey(current)
+ if not id then
+ break
+ end
+ index = index + 1
+ IDList[index] = id
+ if not node then
+ break
+ end
+ current = node
+ if current.special == '_G' then
+ for i = index, 2, -1 do
+ if IDList[i] == '"_G"' then
+ IDList[i] = nil
+ end
+ end
+ break
+ end
+ ::CONTINUE::
+ end
+ if index == 0 then
+ source._id = false
+ return nil
+ end
+ for i = index + 1, #IDList do
+ IDList[i] = nil
+ end
+ local mode = checkMode(current)
+ if not mode then
+ source._id = false
+ return nil
+ end
+ util.revertTable(IDList)
+ local id = mode .. table.concat(IDList, SPLIT_CHAR)
+ source._id = id
+ return id
+end
+
+---添加关联的前进ID
+---@param noders noders
+---@param id string
+---@param forwardID string
+local function pushForward(noders, id, forwardID, tag)
+ if not id
+ or not forwardID
+ or forwardID == ''
+ or id == forwardID then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.forward then
+ node.forward = {}
+ end
+ if node.forward[forwardID] ~= nil then
+ return
+ end
+ node.forward[forwardID] = tag or false
+ node.forward[#node.forward+1] = forwardID
+end
+
+---添加关联的后退ID
+---@param noders noders
+---@param id string
+---@param backwardID string
+local function pushBackward(noders, id, backwardID, tag)
+ if not id
+ or not backwardID
+ or backwardID == ''
+ or id == backwardID then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.backward then
+ node.backward = {}
+ end
+ if node.backward[backwardID] ~= nil then
+ return
+ end
+ node.backward[backwardID] = tag or false
+ node.backward[#node.backward+1] = backwardID
+end
+
+local m = {}
+
+m.SPLIT_CHAR = SPLIT_CHAR
+m.RETURN_INDEX = RETURN_INDEX
+m.PARAM_INDEX = PARAM_INDEX
+m.TABLE_KEY = TABLE_KEY
+m.ANY_FIELD = ANY_FIELD
+m.URI_CHAR = URI_CHAR
+
+--- 寻找doc的主体
+---@param obj parser.guide.object
+---@return parser.guide.object
+local function getDocStateWithoutCrossFunction(obj)
+ for _ = 1, 1000 do
+ local parent = obj.parent
+ if not parent then
+ return obj
+ end
+ if parent.type == 'doc' then
+ return obj
+ end
+ if parent.type == 'doc.type.function' then
+ return nil
+ end
+ obj = parent
+ end
+ error('guide.getDocState overstack')
+end
+
+---添加关联单元
+---@param noders noders
+---@param source parser.guide.object
+function m.pushSource(noders, source)
+ local id = m.getID(source)
+ if not id then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.sources then
+ node.sources = {}
+ end
+ node.sources[#node.sources+1] = source
+end
+
+---@param noders noders
+---@param source parser.guide.object
+---@return parser.guide.object[]
+function m.compileNode(noders, source)
+ local id = getID(source)
+ local value = source.value
+ if value then
+ local valueID = getID(value)
+ if valueID then
+ -- x = y : x -> y
+ pushForward(noders, id, valueID, 'set')
+ -- 参数禁止反向查找赋值
+ if valueID:sub(1, 2) ~= 'p:' then
+ pushBackward(noders, valueID, id, 'set')
+ end
+ end
+ end
+ -- self -> mt:xx
+ if source.type == 'local' and source[1] == 'self' then
+ local func = guide.getParentFunction(source)
+ if func.isGeneric then
+ return
+ end
+ local setmethod = func.parent
+ -- guess `self`
+ if setmethod and ( setmethod.type == 'setmethod'
+ or setmethod.type == 'setfield'
+ or setmethod.type == 'setindex') then
+ pushForward(noders, id, getID(setmethod.node), 'method')
+ pushBackward(noders, getID(setmethod.node), id, 'method')
+ end
+ end
+ -- 分解 @type
+ if source.type == 'doc.type' then
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
+ end
+ end
+ for _, enumUnit in ipairs(source.enums) do
+ pushForward(noders, id, getID(enumUnit))
+ end
+ for _, resumeUnit in ipairs(source.resumes) do
+ pushForward(noders, id, getID(resumeUnit))
+ end
+ for _, typeUnit in ipairs(source.types) do
+ local unitID = getID(typeUnit)
+ pushForward(noders, id, unitID)
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushBackward(noders, unitID, getID(src))
+ end
+ end
+ end
+ end
+ -- 分解 @alias
+ if source.type == 'doc.alias' then
+ pushForward(noders, getID(source.alias), getID(source.extends))
+ end
+ -- 分解 @class
+ if source.type == 'doc.class' then
+ pushForward(noders, id, getID(source.class))
+ pushForward(noders, getID(source.class), id)
+ if source.extends then
+ for _, ext in ipairs(source.extends) do
+ pushBackward(noders, id, getID(ext))
+ end
+ end
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
+ end
+ end
+ for _, field in ipairs(source.fields) do
+ local key = field.field[1]
+ if key then
+ local keyID = ('%s%s%q'):format(
+ id,
+ SPLIT_CHAR,
+ key
+ )
+ pushForward(noders, keyID, getID(field.field))
+ pushForward(noders, getID(field.field), keyID)
+ pushForward(noders, keyID, getID(field.extends))
+ pushBackward(noders, getID(field.extends), keyID)
+ end
+ end
+ end
+ if source.type == 'doc.param' then
+ pushForward(noders, id, getID(source.extends))
+ for _, src in ipairs(source.bindSources) do
+ if src.type == 'local' and src.parent.type == 'in' then
+ pushForward(noders, getID(src), id)
+ end
+ end
+ end
+ if source.type == 'doc.vararg' then
+ pushForward(noders, getID(source), getID(source.vararg))
+ end
+ if source.type == 'doc.see' then
+ local nameID = getID(source.name)
+ local classID = nameID:gsub('^dsn:', 'dn:')
+ pushForward(noders, nameID, classID)
+ if source.field then
+ local fieldID = getID(source.field)
+ local fieldClassID = fieldID:gsub('^dsn:', 'dn:')
+ pushForward(noders, fieldID, fieldClassID)
+ end
+ end
+ if source.type == 'call' then
+ local node = source.node
+ local nodeID = getID(node)
+ if not nodeID then
+ return
+ end
+ getNode(noders, id).call = source
+ -- 将 call 映射到 node#1 上
+ local callID = ('%s%s%s'):format(
+ nodeID,
+ RETURN_INDEX,
+ 1
+ )
+ pushForward(noders, id, callID)
+ -- 将setmetatable映射到 param1 以及 param2.__index 上
+ if node.special == 'setmetatable' then
+ local tblID = getID(source.args and source.args[1])
+ local metaID = getID(source.args and source.args[2])
+ local indexID
+ if metaID then
+ indexID = ('%s%s%q'):format(
+ metaID,
+ SPLIT_CHAR,
+ '__index'
+ )
+ end
+ pushForward(noders, id, callID)
+ pushBackward(noders, callID, id)
+ pushForward(noders, callID, tblID)
+ pushForward(noders, callID, indexID)
+ pushBackward(noders, tblID, callID)
+ --pushBackward(noders, indexID, callID)
+ end
+ if node.special == 'require' then
+ local arg1 = source.args and source.args[1]
+ if arg1 and arg1.type == 'string' then
+ getNode(noders, callID).require = arg1[1]
+ end
+ end
+ end
+ if source.type == 'select' then
+ if source.vararg.type == 'call' then
+ local call = source.vararg
+ local node = call.node
+ local nodeID = getID(node)
+ if not nodeID then
+ return
+ end
+ -- 将call的返回值接收映射到函数返回值上
+ local callXID = ('%s%s%s'):format(
+ nodeID,
+ RETURN_INDEX,
+ source.sindex
+ )
+ pushForward(noders, id, callXID)
+ pushBackward(noders, callXID, id)
+ getNode(noders, id).call = call
+ if node.special == 'pcall'
+ or node.special == 'xpcall' then
+ local index = source.sindex - 1
+ if index <= 0 then
+ return
+ end
+ local funcID = call.args and getID(call.args[1])
+ if not funcID then
+ return
+ end
+ local funcXID = ('%s%s%s'):format(
+ funcID,
+ RETURN_INDEX,
+ index
+ )
+ pushForward(noders, id, funcXID)
+ pushBackward(noders, funcXID, id)
+ end
+ end
+ if source.vararg.type == 'varargs' then
+ pushForward(noders, id, getID(source.vararg))
+ end
+ end
+ if source.type == 'doc.type.function' then
+ if source.returns then
+ for index, rtn in ipairs(source.returns) do
+ local returnID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ index
+ )
+ pushForward(noders, returnID, getID(rtn))
+ end
+ end
+ -- @type fun(x: T):T 的情况
+ local docType = getDocStateWithoutCrossFunction(source)
+ if docType and docType.type == 'doc.type' then
+ guide.eachSourceType(source, 'doc.type.name', function (typeName)
+ if typeName.typeGeneric then
+ source.isGeneric = true
+ return false
+ end
+ end)
+ end
+ end
+ if source.type == 'doc.type.table' then
+ if source.tkey then
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, getID(source.tkey))
+ end
+ if source.tvalue then
+ local valueID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, valueID, getID(source.tvalue))
+ end
+ end
+ if source.type == 'doc.type.array' then
+ if source.node then
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source.node))
+ end
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, 'dn:integer')
+ end
+ -- 将函数的返回值映射到具体的返回值上
+ if source.type == 'function' then
+ local hasDocReturn = {}
+ -- 检查 luadoc
+ if source.bindDocs then
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ local fullID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ rtn.returnIndex
+ )
+ pushForward(noders, fullID, getID(rtn))
+ hasDocReturn[rtn.returnIndex] = true
+ end
+ end
+ if doc.type == 'doc.param' then
+ local paramName = doc.param[1]
+ if source.docParamMap then
+ local paramIndex = source.docParamMap[paramName]
+ local param = source.args[paramIndex]
+ if param then
+ pushForward(noders, getID(param), getID(doc))
+ param.docParam = doc
+ end
+ end
+ end
+ if doc.type == 'doc.vararg' then
+ for _, param in ipairs(source.args) do
+ if param.type == '...' then
+ pushForward(noders, getID(param), getID(doc))
+ end
+ end
+ end
+ if doc.type == 'doc.generic' then
+ source.isGeneric = true
+ end
+ if doc.type == 'doc.overload' then
+ pushForward(noders, id, getID(doc.overload))
+ end
+ end
+ end
+ -- 检查实体返回值
+ if source.returns then
+ local returns = {}
+ for _, rtn in ipairs(source.returns) do
+ for index, rtnObj in ipairs(rtn) do
+ if not hasDocReturn[index] then
+ if not returns[index] then
+ returns[index] = {}
+ end
+ returns[index][#returns[index]+1] = rtnObj
+ end
+ end
+ end
+ for index, rtnObjs in ipairs(returns) do
+ local returnID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ index
+ )
+ for _, rtnObj in ipairs(rtnObjs) do
+ pushForward(noders, returnID, getID(rtnObj))
+ if rtnObj.type == 'function'
+ or rtnObj.type == 'call' then
+ pushBackward(noders, getID(rtnObj), returnID)
+ end
+ end
+ end
+ end
+ end
+ if source.type == 'table' then
+ if #source == 1 and source[1].type == 'varargs' then
+ source.array = source[1]
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source[1]))
+ end
+ end
+ if source.type == 'main' then
+ if source.returns then
+ for _, rtn in ipairs(source.returns) do
+ local rtnObj = rtn[1]
+ if rtnObj then
+ pushForward(noders, 'mainreturn', getID(rtnObj))
+ pushBackward(noders, getID(rtnObj), 'mainreturn')
+ end
+ end
+ end
+ end
+ if source.type == 'generic.closure' then
+ for i, rtn in ipairs(source.returns) do
+ local closureID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ i
+ )
+ local returnID = getID(rtn)
+ pushForward(noders, closureID, returnID)
+ end
+ end
+ if source.type == 'generic.value' then
+ local proto = source.proto
+ local closure = source.closure
+ local upvalues = closure.upvalues
+ if proto.type == 'doc.type.name' then
+ local key = proto[1]
+ if upvalues[key] then
+ for _, paramID in ipairs(upvalues[key]) do
+ pushForward(noders, id, paramID)
+ pushBackward(noders, paramID, id)
+ end
+ end
+ end
+ if proto.type == 'doc.type' then
+ for _, tp in ipairs(source.types) do
+ pushForward(noders, id, getID(tp))
+ pushBackward(noders, getID(tp), id)
+ end
+ end
+ if proto.type == 'doc.type.array' then
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source.node))
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, 'dn:integer')
+ end
+ if proto.type == 'doc.type.table' then
+ if source.tkey then
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, getID(source.tkey))
+ end
+ if source.tvalue then
+ local valueID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, valueID, getID(source.tvalue))
+ end
+ end
+ end
+end
+
+---根据ID来获取所有的node
+---@param root parser.guide.object
+---@param id string
+---@return node?
+function m.getNodeByID(root, id)
+ root = guide.getRoot(root)
+ local noders = root._noders
+ if not noders then
+ return nil
+ end
+ return noders[id]
+end
+
+---根据ID来获取第一个节点的ID
+---@param id string
+---@return string
+function m.getFirstID(id)
+ if FirstIDCache[id] then
+ return FirstIDCache[id] or nil
+ end
+ local firstID, count = id:match(FIRST_REGEX)
+ if count == 0 then
+ FirstIDCache[id] = false
+ return nil
+ end
+ FirstIDCache[id] = firstID
+ return firstID
+end
+
+---根据ID来获取上个节点的ID
+---@param id string
+---@return string
+function m.getLastID(id)
+ if LastIDCache[id] then
+ return LastIDCache[id] or nil
+ end
+ local lastID, count = id:gsub(LAST_REGEX, '')
+ if count == 0 then
+ LastIDCache[id] = false
+ return nil
+ end
+ LastIDCache[id] = lastID
+ return lastID
+end
+
+---把形如 `@file:\\\XXXXX@gv:1|1`拆分成uri与id
+---@param id string
+---@return uri? string
+---@return string id
+function m.getUriAndID(id)
+ local uri, newID = id:match(URI_REGEX)
+ return uri, newID
+end
+
+---获取source的ID
+---@param source parser.guide.object
+---@return string
+function m.getID(source)
+ return getID(source)
+end
+
+---获取source的key
+---@param source parser.guide.object
+---@return string
+function m.getKey(source)
+ return getKey(source)
+end
+
+---清除临时id(用于泛型的临时对象)
+---@param root parser.guide.object
+---@param id string
+function m.removeID(root, id)
+ root = guide.getRoot(root)
+ local noders = root._noders
+ noders[id] = nil
+end
+
+---寻找doc的主体
+---@param doc parser.guide.object
+function m.getDocState(doc)
+ return getDocStateWithoutCrossFunction(doc)
+end
+
+---获取对象的noders
+---@param source parser.guide.object
+---@return noders
+function m.getNoders(source)
+ local root = guide.getRoot(source)
+ if not root._noders then
+ root._noders = {}
+ end
+ return root._noders
+end
+
+---编译整个文件的node
+---@param source parser.guide.object
+---@return table
+function m.compileNodes(source)
+ local root = guide.getRoot(source)
+ local noders = m.getNoders(source)
+ if next(noders) then
+ return noders
+ end
+ guide.eachSource(root, function (src)
+ m.pushSource(noders, src)
+ m.compileNode(noders, src)
+ end)
+ -- Special rule: ('').XX -> stringlib.XX
+ pushBackward(noders, 'str:', 'dn:stringlib')
+ pushBackward(noders, 'dn:string', 'dn:stringlib')
+ return noders
+end
+
+return m
diff --git a/script/core/reference.lua b/script/core/reference.lua
index 7620b09e..c3f3b349 100644
--- a/script/core/reference.lua
+++ b/script/core/reference.lua
@@ -1,4 +1,5 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
+local guide = require 'parser.guide'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
@@ -6,8 +7,8 @@ local findSource = require 'core.find-source'
local function sortResults(results)
-- 先按照顺序排序
table.sort(results, function (a, b)
- local u1 = guide.getUri(a.target)
- local u2 = guide.getUri(b.target)
+ local u1 = searcher.getUri(a.target)
+ local u2 = searcher.getUri(b.target)
if u1 == u2 then
return a.target.start < b.target.start
else
@@ -19,7 +20,7 @@ local function sortResults(results)
for i = #results, 1, -1 do
local res = results[i].target
local f = res.finish
- local uri = guide.getUri(res)
+ local uri = searcher.getUri(res)
if lf and f > lf and uri == lu then
table.remove(results, i)
else
@@ -64,11 +65,23 @@ return function (uri, offset)
local metaSource = vm.isMetaFile(uri)
+ local refs = vm.getRefs(source)
+ local values = {}
+ for _, src in ipairs(refs) do
+ local value = searcher.getObjectValue(src)
+ if value and value ~= src and guide.isLiteral(value) then
+ values[value] = true
+ end
+ end
+
local results = {}
- for _, src in ipairs(vm.getRefs(source, 5)) do
+ for _, src in ipairs(refs) do
if src.dummy then
goto CONTINUE
end
+ if values[src] then
+ goto CONTINUE
+ end
local root = guide.getRoot(src)
if not root then
goto CONTINUE
diff --git a/script/core/rename.lua b/script/core/rename.lua
index da82b0a6..6b67d4be 100644
--- a/script/core/rename.lua
+++ b/script/core/rename.lua
@@ -1,11 +1,11 @@
local files = require 'files'
local vm = require 'vm'
-local guide = require 'core.guide'
local proto = require 'proto'
local define = require 'proto.define'
local util = require 'utility'
local findSource = require 'core.find-source'
-local ws = require 'workspace'
+local guide = require 'parser.guide'
+local noder = require 'core.noder'
local Forcing
@@ -185,7 +185,7 @@ local function renameField(source, newname, callback)
end
callback(source, source.start, source.finish, newname)
elseif parent.type == 'setmethod' then
- local uri = guide.getUri(source)
+ local uri = guide.getUri(source)
local text = files.getText(uri)
local func = parent.value
-- function mt:name () end --> mt['newname'] = function (self) end
@@ -292,14 +292,14 @@ local function ofField(source, newname, callback)
else
node = source.node
end
- for _, src in ipairs(vm.getFields(node, 5)) do
+ for _, src in ipairs(vm.getRefs(node, '*')) do
ofFieldThen(key, src, newname, callback)
end
end
local function ofGlobal(source, newname, callback)
local key = guide.getKeyName(source)
- for _, src in ipairs(vm.getRefs(source, 0)) do
+ for _, src in ipairs(vm.getRefs(source)) do
ofFieldThen(key, src, newname, callback)
end
end
@@ -308,24 +308,27 @@ local function ofLabel(source, newname, callback)
if not isValidName(newname) and not askForcing(newname)then
return false
end
- for _, src in ipairs(vm.getRefs(source, 0)) do
+ for _, src in ipairs(vm.getRefs(source)) do
callback(src, src.start, src.finish, newname)
end
end
local function ofDocTypeName(source, newname, callback)
- for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ local oldname = source[1]
+ for _, doc in ipairs(vm.getRefs(source)) do
if doc.type == 'doc.class.name'
or doc.type == 'doc.type.name'
or doc.type == 'doc.alias.name' then
- callback(doc, doc.start, doc.finish, newname)
+ if oldname == doc[1] then
+ callback(doc, doc.start, doc.finish, newname)
+ end
end
end
end
local function ofDocParamName(source, newname, callback)
callback(source, source.start, source.finish, newname)
- local doc = guide.getDocState(source)
+ local doc = noder.getDocState(source)
if doc.bindSources then
for _, src in ipairs(doc.bindSources) do
if src.type == 'local'
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
new file mode 100644
index 00000000..11e00378
--- /dev/null
+++ b/script/core/searcher.lua
@@ -0,0 +1,728 @@
+local noder = require 'core.noder'
+local guide = require 'parser.guide'
+local files = require 'files'
+local generic = require 'core.generic'
+local ws = require 'workspace'
+local vm = require 'vm.vm'
+
+local NONE = {'NONE'}
+local LAST = {'LAST'}
+
+local ignoredIDs = {
+ ['dn:unknown'] = true,
+ ['dn:nil'] = true,
+ ['dn:any'] = true,
+ ['dn:boolean'] = true,
+ ['dn:string'] = true,
+ ['dn:table'] = true,
+ ['dn:number'] = true,
+ ['dn:integer'] = true,
+ ['dn:userdata'] = true,
+ ['dn:lightuserdata'] = true,
+ ['dn:function'] = true,
+ ['dn:thread'] = true,
+}
+
+local m = {}
+
+---@alias guide.searchmode '"ref"'|'"def"'
+
+---添加结果
+---@param status guide.status
+---@param mode guide.searchmode
+---@param source parser.guide.object
+---@param force boolean
+function m.pushResult(status, mode, source, force)
+ if not source then
+ return
+ end
+ local results = status.results
+ if results[source] then
+ return
+ end
+ results[source] = true
+ if force then
+ results[#results+1] = source
+ return
+ end
+ local parent = source.parent
+ if mode == 'def' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal'
+ or source.type == 'label'
+ or source.type == 'setfield'
+ or source.type == 'setmethod'
+ or source.type == 'setindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset' then
+ results[#results+1] = source
+ end
+ end
+ if parent.type == 'return' then
+ if noder.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
+ elseif mode == 'ref' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal'
+ or source.type == 'label'
+ or source.type == 'goto'
+ or source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod'
+ or source.type == 'setindex'
+ or source.type == 'getindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number'
+ or source.type == 'nil'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.type.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset'
+ or source.node.special == 'rawget' then
+ results[#results+1] = source
+ end
+ end
+ if parent.type == 'return' then
+ if noder.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
+ end
+end
+
+---获取uri
+---@param obj parser.guide.object
+---@return uri
+function m.getUri(obj)
+ if obj.uri then
+ return obj.uri
+ end
+ local root = guide.getRoot(obj)
+ if root then
+ return root.uri
+ end
+ return ''
+end
+
+---@param obj parser.guide.object
+---@return parser.guide.object?
+function m.getObjectValue(obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return nil
+ end
+ end
+ if obj.type == 'boolean'
+ or obj.type == 'number'
+ or obj.type == 'integer'
+ or obj.type == 'string' then
+ return obj
+ end
+ if obj.value then
+ return obj.value
+ end
+ if obj.type == 'field'
+ or obj.type == 'method' then
+ return obj.parent and obj.parent.value
+ end
+ if obj.type == 'call' then
+ if obj.node.special == 'rawset' then
+ return obj.args and obj.args[3]
+ else
+ return obj
+ end
+ end
+ if obj.type == 'select' then
+ return obj
+ end
+ return nil
+end
+
+local function crossSearch(status, uri, expect, mode)
+ m.searchRefsByID(status, uri, expect, mode)
+end
+
+local function getLock(status, uri, expect, mode)
+ local slock = status.lock
+ local ulock = slock[uri]
+ if not ulock then
+ ulock = {}
+ slock[uri] = ulock
+ end
+ local mlock = ulock[mode]
+ if not mlock then
+ mlock = {}
+ ulock[mode] = mlock
+ end
+ if mlock[expect] then
+ return false
+ end
+ mlock[expect] = true
+ return true
+end
+
+function m.searchRefsByID(status, uri, expect, mode)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ if not getLock(status, uri, expect, mode) then
+ return
+ end
+ local root = ast.ast
+ local searchStep
+ noder.compileNodes(root)
+
+ status.id = expect
+
+ local callStack = status.callStack
+
+ local mark = {}
+
+ local function search(id, field)
+ local firstID = noder.getFirstID(id)
+ if ignoredIDs[firstID] and (field or firstID ~= id) then
+ return
+ end
+ local cmark = mark[id]
+ if not cmark then
+ cmark = {}
+ mark[id] = cmark
+ end
+ log.debug('search:', id, field)
+ if field then
+ if cmark[field] then
+ return
+ end
+ cmark[field] = true
+ searchStep(id, field)
+ cmark[field] = nil
+ else
+ if cmark[NONE] then
+ return
+ end
+ cmark[NONE] = true
+ searchStep(id, nil)
+ cmark[NONE] = nil
+ end
+ log.debug('pop:', id, field)
+ end
+
+ local function checkLastID(id, field)
+ local cmark = mark[id]
+ if not cmark then
+ cmark = {}
+ mark[id] = cmark
+ end
+ if cmark[LAST] then
+ return
+ end
+ local lastID = noder.getLastID(id)
+ if not lastID then
+ return
+ end
+ local newField = id:sub(#lastID + 1)
+ if field then
+ newField = newField .. field
+ end
+ cmark[LAST] = true
+ search(lastID, newField)
+ cmark[LAST] = nil
+ return lastID
+ end
+
+ local function searchID(id, field)
+ if not id then
+ return
+ end
+ if field then
+ id = id .. field
+ end
+ search(id, nil)
+ end
+
+ local function isCallID(field)
+ if not field then
+ return false
+ end
+ if field:sub(1, 2) == noder.RETURN_INDEX then
+ return true
+ end
+ return false
+ end
+
+ local function findLastCall()
+ for i = #callStack, 1, -1 do
+ local call = callStack[i]
+ if call then
+ -- 标记此处的call失效,等待在堆栈平衡时弹出
+ callStack[i] = false
+ return call
+ end
+ end
+ return nil
+ end
+
+ local genericCallArgs = {}
+ local closureCache = {}
+ local function checkGeneric(source, field)
+ if not source.isGeneric then
+ return
+ end
+ if not isCallID(field) then
+ return
+ end
+ local call = findLastCall()
+ if not call then
+ return
+ end
+
+ if call.args then
+ for _, arg in ipairs(call.args) do
+ genericCallArgs[arg] = true
+ end
+ end
+
+ local cacheID = noder.getID(source) .. noder.getID(call)
+ local closure = closureCache[cacheID]
+ if closure == false then
+ return
+ end
+ if not closure then
+ closure = generic.createClosure(source, call)
+ closureCache[cacheID] = closure or false
+ if not closure then
+ return
+ end
+ end
+ local id = noder.getID(closure)
+ searchID(id, field)
+ end
+
+ local function checkENV(source, field)
+ if not field then
+ return
+ end
+ if source.special ~= '_G' then
+ return
+ end
+ local newID = 'g:' .. field:sub(2)
+ searchID(newID)
+ end
+
+ local forwardTag = {}
+ local backwardTag = {}
+ local function checkForward(id, node, field)
+ for _, forwardID in ipairs(node.forward) do
+ local tag = node.forward[forwardID]
+ if tag then
+ if backwardTag[tag] and backwardTag[tag] > 0 then
+ goto CONTINUE
+ end
+ forwardTag[tag] = (forwardTag[tag] or 0) + 1
+ end
+ local targetUri, targetID = noder.getUriAndID(forwardID)
+ if targetUri and not files.eq(targetUri, uri) then
+ crossSearch(status, targetUri, targetID .. (field or ''), mode)
+ else
+ searchID(targetID or forwardID, field)
+ end
+ if tag then
+ forwardTag[tag] = forwardTag[tag] - 1
+ end
+ ::CONTINUE::
+ end
+ end
+
+ local function checkBackward(id, node, field)
+ if mode ~= 'ref' and not field then
+ return
+ end
+ for _, backwardID in ipairs(node.backward) do
+ local tag = node.backward[backwardID]
+ if tag then
+ if forwardTag[tag] and forwardTag[tag] > 0 then
+ goto CONTINUE
+ end
+ backwardTag[tag] = (backwardTag[tag] or 0) + 1
+ end
+ local targetUri, targetID = noder.getUriAndID(backwardID)
+ if targetUri and not files.eq(targetUri, uri) then
+ crossSearch(status, targetUri, targetID .. (field or ''), mode)
+ else
+ searchID(targetID or backwardID, field)
+ end
+ if tag then
+ backwardTag[tag] = backwardTag[tag] - 1
+ end
+ ::CONTINUE::
+ end
+ end
+
+ local function checkRequire(requireName, field)
+ local tid = 'mainreturn' .. (field or '')
+ local uris = ws.findUrisByRequirePath(requireName)
+ for _, ruri in ipairs(uris) do
+ if not files.eq(uri, ruri) then
+ crossSearch(status, ruri, tid, mode)
+ end
+ end
+ end
+
+ local function checkGlobal(id, node, field)
+ if id:sub(1, 2) ~= 'g:' then
+ return
+ end
+ local firstID = noder.getFirstID(id)
+ if status.crossed[firstID] then
+ return
+ end
+ status.crossed[firstID] = true
+ local tid = id .. (field or '')
+ for guri in files.eachFile() do
+ if not files.eq(uri, guri) then
+ crossSearch(status, guri, tid, mode)
+ end
+ end
+ end
+
+ local function checkClass(id, node, field)
+ if id:sub(1, 3) ~= 'dn:' then
+ return
+ end
+ local firstID = noder.getFirstID(id)
+ if status.crossed[firstID] then
+ return
+ end
+ status.crossed[firstID] = true
+ local tid = id .. (field or '')
+ for guri in files.eachFile() do
+ if not files.eq(uri, guri) then
+ crossSearch(status, guri, tid, mode)
+ end
+ end
+ end
+
+ local function checkMainReturn(id, node, field)
+ if id ~= 'mainreturn' then
+ return
+ end
+ if mode ~= 'ref' and not field then
+ return
+ end
+ local calls = vm.getLinksTo(uri)
+ for _, call in ipairs(calls) do
+ local turi = guide.getUri(call)
+ if not files.eq(turi, uri) then
+ local tid = noder.getID(call) .. (field or '')
+ crossSearch(status, turi, tid, mode)
+ end
+ end
+ end
+
+ local function searchNode(id, node, field)
+ if node.call then
+ callStack[#callStack+1] = node.call
+ end
+ if field == nil and node.sources then
+ for _, source in ipairs(node.sources) do
+ local force = genericCallArgs[source]
+ m.pushResult(status, mode, source, force)
+ end
+ end
+ if node.forward then
+ checkForward(id, node, field)
+ end
+ if node.backward then
+ checkBackward(id, node, field)
+ end
+
+ if node.sources then
+ checkGeneric(node.sources[1], field)
+ checkENV(node.sources[1], field)
+ end
+
+ if node.require then
+ checkRequire(node.require, field)
+ end
+
+ checkMainReturn(id, node, field)
+
+ if node.call then
+ callStack[#callStack] = nil
+ end
+ end
+
+ local function checkCrossUri(id, field)
+ local targetUri, newID = noder.getUriAndID(id)
+ if not targetUri then
+ return false
+ end
+ crossSearch(status, targetUri, newID .. (field or ''), mode)
+ return true
+ end
+
+ local stepCount = 0
+ function searchStep(id, field)
+ stepCount = stepCount + 1
+ if stepCount > 1000 then
+ error('too large')
+ end
+ local node = noder.getNodeByID(root, id)
+ if node then
+ searchNode(id, node, field)
+ end
+ checkGlobal(id, node, field)
+ checkClass(id, node, field)
+ local lastID = checkLastID(id, field)
+ if not lastID then
+ return
+ end
+ local originField = id:sub(#lastID + 1)
+ if originField == noder.TABLE_KEY then
+ return
+ end
+ local anyFieldID = lastID .. noder.ANY_FIELD
+ local anyFieldNode = noder.getNodeByID(root, anyFieldID)
+ if anyFieldNode then
+ searchNode(anyFieldID, anyFieldNode, field)
+ end
+ end
+
+ search(expect)
+
+ --清除来自泛型的临时对象
+ for _, closure in pairs(closureCache) do
+ noder.removeID(root, noder.getID(closure))
+ if closure then
+ for _, value in ipairs(closure.values) do
+ noder.removeID(root, noder.getID(value))
+ end
+ end
+ end
+end
+
+local function prepareSearch(source)
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ local root = guide.getRoot(source)
+ noder.compileNodes(root)
+ local uri = guide.getUri(source)
+ local id = noder.getID(source)
+ return uri, id
+end
+
+local function getField(status, source, mode)
+ if source.type == 'table' then
+ for _, field in ipairs(source) do
+ m.pushResult(status, mode, field)
+ end
+ return
+ end
+ if source.type == 'doc.class.name' then
+ local class = source.parent
+ for _, field in ipairs(class.fields) do
+ m.pushResult(status, mode, field.field)
+ end
+ return
+ end
+ local field = source.next
+ if field then
+ if field.type == 'getmethod'
+ or field.type == 'setmethod'
+ or field.type == 'getfield'
+ or field.type == 'setfield'
+ or field.type == 'getindex'
+ or field.type == 'setindex' then
+ m.pushResult(status, mode, field)
+ end
+ return
+ end
+end
+
+local function searchAllGlobalByUri(status, mode, uri, fullID)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local root = ast.ast
+ noder.compileNodes(root)
+ local noders = noder.getNoders(root)
+ if fullID then
+ for id, node in pairs(noders) do
+ if node.sources
+ and id == fullID then
+ for _, source in ipairs(node.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ end
+ else
+ for id, node in pairs(noders) do
+ if node.sources
+ and id:sub(1, 2) == 'g:'
+ and not id:find(noder.SPLIT_CHAR) then
+ for _, source in ipairs(node.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ end
+ end
+end
+
+local function searchAllGlobals(status, mode, fullID)
+ for uri in files.eachFile() do
+ searchAllGlobalByUri(status, mode, uri, fullID)
+ end
+end
+
+---搜索对象的引用
+---@param status guide.status
+---@param source parser.guide.object
+---@param mode guide.searchmode
+function m.searchRefs(status, source, mode)
+ local uri, id = prepareSearch(source)
+ if not id then
+ return
+ end
+ log.debug('searchRefs:', id)
+ m.searchRefsByID(status, uri, id, mode)
+end
+
+function m.findGlobals(uri, mode, name)
+ local status = m.status()
+
+ if name then
+ local fullID = ('g:%q'):format(name)
+ searchAllGlobalByUri(status, mode, uri, fullID)
+ else
+ searchAllGlobalByUri(status, mode, uri)
+ end
+
+ return status.results
+end
+
+---搜索对象的field
+---@param status guide.status
+---@param source parser.guide.object
+---@param mode guide.searchmode
+---@param field string
+function m.searchFields(status, source, mode, field)
+ local uri, id = prepareSearch(source)
+ if not id then
+ return
+ end
+ log.debug('searchFields:', id, field)
+ if field == '*' then
+ if source.special == '_G' then
+ searchAllGlobals(status, mode)
+ else
+ local newStatus = m.status(status)
+ m.searchRefsByID(newStatus, uri, id, mode)
+ for _, def in ipairs(newStatus.results) do
+ getField(status, def, mode)
+ end
+ end
+ else
+ if source.special == '_G' then
+ local fullID = ('g:%q'):format(field)
+ m.searchRefsByID(status, uri, fullID, mode)
+ else
+ local fullID = ('%s%s%q'):format(id, noder.SPLIT_CHAR, field)
+ m.searchRefsByID(status, uri, fullID, mode)
+ end
+ end
+end
+
+---@class guide.status
+---搜索结果
+---@field results parser.guide.object[]
+
+---创建搜索状态
+---@param parentStatus guide.status
+---@return guide.status
+function m.status(parentStatus)
+ local status = {
+ --mark = parentStatus and parentStatus.mark or {},
+ callStack = {},
+ crossed = {},
+ lock = {},
+ results = {},
+ }
+ return status
+end
+
+--- 请求对象的引用
+---@param obj parser.guide.object
+---@param field? string
+---@return parser.guide.object[]
+function m.requestReference(obj, field)
+ local status = m.status()
+
+ if field then
+ m.searchFields(status, obj, 'ref', field)
+ else
+ m.searchRefs(status, obj, 'ref')
+ end
+
+ return status.results
+end
+
+--- 请求对象的定义
+---@param obj parser.guide.object
+---@param field? string
+---@return parser.guide.object[]
+function m.requestDefinition(obj, field)
+ local status = m.status()
+
+ if field then
+ m.searchFields(status, obj, 'def', field)
+ else
+ m.searchRefs(status, obj, 'def')
+ end
+
+ return status.results
+end
+
+return m
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index f8feaa09..5e9ee9b1 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local await = require 'await'
local define = require 'proto.define'
local vm = require 'vm'
@@ -221,7 +221,7 @@ return function (uri, start, finish)
local results = {}
local count = 0
- guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ searcher.eachSourceBetween(ast.ast, start, finish, function (source)
local method = Care[source.type]
if not method then
return
diff --git a/script/core/signature.lua b/script/core/signature.lua
index eb740784..915310c0 100644
--- a/script/core/signature.lua
+++ b/script/core/signature.lua
@@ -1,8 +1,9 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local vm = require 'vm'
local hoverLabel = require 'core.hover.label'
local hoverDesc = require 'core.hover.description'
+local guide = require 'parser.guide'
local function findNearCall(uri, ast, pos)
local text = files.getText(uri)
@@ -96,10 +97,10 @@ local function makeSignatures(call, pos)
index = 1
end
local signs = {}
- local defs = vm.getDefs(node, 0)
+ local defs = vm.getDefs(node)
local mark = {}
for _, src in ipairs(defs) do
- src = guide.getObjectValue(src) or src
+ src = searcher.getObjectValue(src) or src
if src.type == 'function'
or src.type == 'doc.type.function' then
if not mark[src] then
diff --git a/script/core/type-formatting.lua b/script/core/type-formatting.lua
index c2290ef3..49a721e5 100644
--- a/script/core/type-formatting.lua
+++ b/script/core/type-formatting.lua
@@ -1,6 +1,6 @@
local files = require 'files'
local lookBackward = require 'core.look-backward'
-local guide = require 'core.guide'
+local guide = require "parser.guide"
local function insertIndentation(uri, offset, edits)
local lines = files.getLines(uri)
diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua
index ae420d32..2df23a4d 100644
--- a/script/core/workspace-symbol.lua
+++ b/script/core/workspace-symbol.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local matchKey = require 'core.matchkey'
local define = require 'proto.define'
local await = require 'await'
@@ -52,7 +52,7 @@ local function searchFile(uri, key, results)
return
end
- guide.eachSource(ast.ast, function (source)
+ searcher.eachSource(ast.ast, function (source)
buildSource(uri, source, key, results)
end)
end