summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--meta/template/basic.lua3
-rw-r--r--meta/template/builtin.lua3
-rw-r--r--meta/template/package.lua6
-rw-r--r--script/core/code-action.lua13
-rw-r--r--script/core/command/removeSpace.lua15
-rw-r--r--script/core/command/solve.lua9
-rw-r--r--script/core/completion.lua132
-rw-r--r--script/core/definition.lua18
-rw-r--r--script/core/diagnostics/ambiguity-1.lua2
-rw-r--r--script/core/diagnostics/circle-doc-class.lua12
-rw-r--r--script/core/diagnostics/close-non-object.lua7
-rw-r--r--script/core/diagnostics/code-after-break.lua8
-rw-r--r--script/core/diagnostics/count-down-loop.lua6
-rw-r--r--script/core/diagnostics/deprecated.lua20
-rw-r--r--script/core/diagnostics/duplicate-doc-class.lua14
-rw-r--r--script/core/diagnostics/duplicate-index.lua10
-rw-r--r--script/core/diagnostics/duplicate-set-field.lua12
-rw-r--r--script/core/diagnostics/empty-block.lua2
-rw-r--r--script/core/diagnostics/global-in-nil-env.lua2
-rw-r--r--script/core/diagnostics/init.lua3
-rw-r--r--script/core/diagnostics/lowercase-global.lua2
-rw-r--r--script/core/diagnostics/newfield-call.lua2
-rw-r--r--script/core/diagnostics/newline-call.lua2
-rw-r--r--script/core/diagnostics/no-implicit-any.lua4
-rw-r--r--script/core/diagnostics/redefined-local.lua5
-rw-r--r--script/core/diagnostics/redundant-parameter.lua4
-rw-r--r--script/core/diagnostics/trailing-space.lua2
-rw-r--r--script/core/diagnostics/unbalanced-assignments.lua2
-rw-r--r--script/core/diagnostics/undefined-doc-class.lua4
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua4
-rw-r--r--script/core/diagnostics/undefined-doc-param.lua2
-rw-r--r--script/core/diagnostics/undefined-env-child.lua10
-rw-r--r--script/core/diagnostics/undefined-field.lua69
-rw-r--r--script/core/diagnostics/undefined-global.lua19
-rw-r--r--script/core/diagnostics/unused-function.lua2
-rw-r--r--script/core/diagnostics/unused-label.lua2
-rw-r--r--script/core/diagnostics/unused-local.lua5
-rw-r--r--script/core/diagnostics/unused-vararg.lua2
-rw-r--r--script/core/document-symbol.lua10
-rw-r--r--script/core/find-source.lua2
-rw-r--r--script/core/folding.lua4
-rw-r--r--script/core/generic.lua234
-rw-r--r--script/core/guide2.lua (renamed from script/core/guide.lua)5
-rw-r--r--script/core/highlight.lua59
-rw-r--r--script/core/hint.lua20
-rw-r--r--script/core/hover/arg.lua13
-rw-r--r--script/core/hover/description.lua17
-rw-r--r--script/core/hover/init.lua11
-rw-r--r--script/core/hover/label.lua59
-rw-r--r--script/core/hover/name.lua6
-rw-r--r--script/core/hover/return.lua28
-rw-r--r--script/core/hover/table.lua282
-rw-r--r--script/core/infer.lua634
-rw-r--r--script/core/keyword.lua2
-rw-r--r--script/core/noder.lua926
-rw-r--r--script/core/reference.lua23
-rw-r--r--script/core/rename.lua21
-rw-r--r--script/core/searcher.lua728
-rw-r--r--script/core/semantic-tokens.lua4
-rw-r--r--script/core/signature.lua7
-rw-r--r--script/core/type-formatting.lua2
-rw-r--r--script/core/workspace-symbol.lua4
-rw-r--r--script/files.lua3
-rw-r--r--script/parser/ast.lua10
-rw-r--r--script/parser/compile.lua5
-rw-r--r--script/parser/guide.lua395
-rw-r--r--script/parser/luadoc.lua81
-rw-r--r--script/vm/eachDef.lua50
-rw-r--r--script/vm/eachField.lua109
-rw-r--r--script/vm/eachRef.lua49
-rw-r--r--script/vm/getClass.lua64
-rw-r--r--script/vm/getDocs.lua155
-rw-r--r--script/vm/getGlobals.lua37
-rw-r--r--script/vm/getInfer.lua104
-rw-r--r--script/vm/getLibrary.lua7
-rw-r--r--script/vm/getLinks.lua17
-rw-r--r--script/vm/getMeta.lua53
-rw-r--r--script/vm/guideInterface.lua10
-rw-r--r--script/vm/init.lua4
-rw-r--r--script/vm/vm.lua10
-rw-r--r--test.lua8
-rw-r--r--test/basic/init.lua221
-rw-r--r--test/basic/linker.txt141
-rw-r--r--test/basic/noder.lua146
-rw-r--r--test/basic/textmerger.lua219
-rw-r--r--test/crossfile/hover.lua14
-rw-r--r--test/definition/init.lua9
-rw-r--r--test/definition/luadoc.lua294
-rw-r--r--test/diagnostics/init.lua5
-rw-r--r--test/full/example.lua13
-rw-r--r--test/hover/init.lua239
-rw-r--r--test/references/all.lua213
-rw-r--r--test/references/init.lua168
-rw-r--r--test/rename/init.lua4
-rw-r--r--test/type_inference/init.lua117
95 files changed, 4719 insertions, 1795 deletions
diff --git a/meta/template/basic.lua b/meta/template/basic.lua
index 785819a4..7a42ab74 100644
--- a/meta/template/basic.lua
+++ b/meta/template/basic.lua
@@ -135,9 +135,8 @@ function next(table, index) end
---#DES 'paris'
---@generic T: table, K, V
---@param t T
----@return fun(table: table<K, V>, index: K):K, V
+---@return fun(table: table<K, V>, index?: K):K, V
---@return T
----@return nil
function pairs(t) end
---#DES 'pcall'
diff --git a/meta/template/builtin.lua b/meta/template/builtin.lua
index 2b547d1d..45ac24af 100644
--- a/meta/template/builtin.lua
+++ b/meta/template/builtin.lua
@@ -1,6 +1,7 @@
---@meta
----@class any
+---@class unknown
+---@class any: unknown
---@class nil: any
---@class boolean: any
---@class number: any
diff --git a/meta/template/package.lua b/meta/template/package.lua
index ae7def31..8c18e10b 100644
--- a/meta/template/package.lua
+++ b/meta/template/package.lua
@@ -3,13 +3,13 @@
---#if VERSION >=5.4 then
---#DES 'require>5.4'
---@param modname string
----@return any
----@return any loaderdata
+---@return unknown
+---@return unknown loaderdata
function require(modname) end
---#else
---#DES 'require<5.3'
---@param modname string
----@return any
+---@return unknown
function require(modname) end
---#end
diff --git a/script/core/code-action.lua b/script/core/code-action.lua
index bae3df81..9ed000e9 100644
--- a/script/core/code-action.lua
+++ b/script/core/code-action.lua
@@ -1,10 +1,9 @@
-local files = require 'files'
-local lang = require 'language'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local util = require 'utility'
-local sp = require 'bee.subprocess'
-local vm = require 'vm'
+local files = require 'files'
+local lang = require 'language'
+local util = require 'utility'
+local sp = require 'bee.subprocess'
+local vm = require 'vm'
+local guide = require "parser.guide"
local function checkDisableByLuaDocExits(uri, row, mode, code)
local lines = files.getLines(uri)
diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua
index 527af8d5..ba1ee8eb 100644
--- a/script/core/command/removeSpace.lua
+++ b/script/core/command/removeSpace.lua
@@ -1,11 +1,10 @@
-local files = require 'files'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local proto = require 'proto'
-local lang = require 'language'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local proto = require 'proto'
+local lang = require 'language'
local function isInString(ast, offset)
- return guide.eachSourceContain(ast.ast, offset, function (source)
+ return searcher.eachSourceContain(ast.ast, offset, function (source)
if source.type == 'string' then
return true
end
@@ -23,10 +22,10 @@ return function (data)
local textEdit = {}
for i = 1, #lines do
- local line = guide.lineContent(lines, text, i, true)
+ local line = searcher.lineContent(lines, text, i, true)
local pos = line:find '[ \t]+$'
if pos then
- local start, finish = guide.lineRange(lines, i, true)
+ local start, finish = searcher.lineRange(lines, i, true)
start = start + pos - 1
if isInString(ast, start) then
goto NEXT_LINE
diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua
index 995a2109..dc23e7af 100644
--- a/script/core/command/solve.lua
+++ b/script/core/command/solve.lua
@@ -1,8 +1,7 @@
-local files = require 'files'
-local define = require 'proto.define'
-local guide = require 'core.guide'
-local proto = require 'proto'
-local lang = require 'language'
+local files = require 'files'
+local guide = require 'parser.guide'
+local proto = require 'proto'
+local lang = require 'language'
local opMap = {
['+'] = true,
diff --git a/script/core/completion.lua b/script/core/completion.lua
index e3980eca..acaaa276 100644
--- a/script/core/completion.lua
+++ b/script/core/completion.lua
@@ -1,19 +1,14 @@
local define = require 'proto.define'
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local matchKey = require 'core.matchkey'
local vm = require 'vm'
-local getLabel = require 'core.hover.label'
local getName = require 'core.hover.name'
local getArg = require 'core.hover.arg'
-local getReturn = require 'core.hover.return'
-local getDesc = require 'core.hover.description'
local getHover = require 'core.hover'
local config = require 'config'
local util = require 'utility'
local markdown = require 'provider.markdown'
-local findSource = require 'core.find-source'
-local await = require 'await'
local parser = require 'parser'
local keyWordMap = require 'core.keyword'
local workspace = require 'workspace'
@@ -21,6 +16,8 @@ local furi = require 'file-uri'
local rpath = require 'workspace.require-path'
local lang = require 'language'
local lookBackward = require 'core.look-backward'
+local guide = require 'parser.guide'
+local infer = require 'core.infer'
local DiagnosticModes = {
'disable-next-line',
@@ -135,8 +132,8 @@ local function buildFunctionSnip(source, value, oop)
end
local function buildDetail(source)
- local types = vm.getInferType(source, 0)
- local literals = vm.getInferLiteral(source, 0)
+ local types = infer.searchAndViewInfers(source)
+ local literals = infer.searchAndViewLiterals(source)
if literals then
return types .. ' = ' .. literals
else
@@ -149,9 +146,9 @@ local function getSnip(source)
if context <= 0 then
return nil
end
- local defs = vm.getRefs(source, 0)
+ local defs = vm.getRefs(source)
for _, def in ipairs(defs) do
- def = guide.getObjectValue(def) or def
+ def = searcher.getObjectValue(def) or def
if def ~= source and def.type == 'function' then
local uri = guide.getUri(def)
local text = files.getText(uri)
@@ -273,8 +270,8 @@ local function checkLocal(ast, word, offset, results)
if not matchKey(word, name) then
goto CONTINUE
end
- if vm.hasType(source, 'function') then
- for _, def in ipairs(vm.getDefs(source, 0)) do
+ if infer.hasType(source, 'function') then
+ for _, def in ipairs(vm.getDefs(source)) do
if def.type == 'function'
or def.type == 'doc.type.function' then
local funcLabel = name .. getParams(def, false)
@@ -417,7 +414,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset)
end
local function checkFieldThen(name, src, word, start, offset, parent, oop, results)
- local value = guide.getObjectValue(src) or src
+ local value = searcher.getObjectValue(src) or src
local kind = define.CompletionItemKind.Field
if value.type == 'function'
or value.type == 'doc.type.function' then
@@ -492,7 +489,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res
end
local funcLabel
if config.config.completion.showParams then
- local value = guide.getObjectValue(src) or src
+ local value = searcher.getObjectValue(src) or src
if value.type == 'function'
or value.type == 'doc.type.function' then
funcLabel = name .. getParams(value, oop)
@@ -539,16 +536,16 @@ end
local function checkGlobal(ast, word, start, offset, parent, oop, results)
local locals = guide.getVisibleLocals(ast.ast, offset)
- local refs = vm.getGlobalSets '*'
- checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global')
+ local globals = vm.getGlobalSets '*'
+ checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results, locals, 'global')
end
local function checkField(ast, word, start, offset, parent, oop, results)
if parent.tag == '_ENV' or parent.special == '_G' then
- local refs = vm.getGlobalSets '*'
- checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
+ local globals = vm.getGlobalSets '*'
+ checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results)
else
- local refs = vm.getFields(parent, 0)
+ local refs = vm.getRefs(parent, '*')
checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results)
end
end
@@ -1043,14 +1040,14 @@ local function mergeEnums(a, b, source)
end
end
-local function checkTypingEnum(ast, text, offset, infers, str, results)
+local function checkTypingEnum(ast, text, offset, defs, str, results)
local enums = {}
- for _, infer in ipairs(infers) do
- if infer.source.type == 'doc.type.enum'
- or infer.source.type == 'doc.resume' then
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
enums[#enums+1] = {
- label = infer.source[1],
- description = infer.source.comment and infer.source.comment.text,
+ label = def[1],
+ description = def.comment and def.comment.text,
kind = define.CompletionItemKind.EnumMember,
}
end
@@ -1074,8 +1071,8 @@ local function checkEqualEnumLeft(ast, text, offset, source, results)
return src
end
end)
- local infers = vm.getInfers(source, 0)
- checkTypingEnum(ast, text, offset, infers, str, results)
+ local defs = vm.getDefs(source)
+ checkTypingEnum(ast, text, offset, defs, str, results)
end
local function checkEqualEnum(ast, text, offset, results)
@@ -1247,7 +1244,30 @@ function (%s)\
end"):format(table.concat(args, ', '))
end
-local function getCallEnums(source, index)
+local function pushCallEnumsAndFuncs(defs)
+ local results = {}
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
+ results[#results+1] = {
+ label = def[1],
+ description = def.comment,
+ kind = define.CompletionItemKind.EnumMember,
+ }
+ end
+ if def.type == 'doc.type.function' then
+ results[#results+1] = {
+ label = infer.viewDocFunction(def),
+ description = def.comment,
+ kind = define.CompletionItemKind.Function,
+ insertText = buildInsertDocFunction(def),
+ }
+ end
+ end
+ return results
+end
+
+local function getCallEnumsAndFuncs(source, index)
if source.type == 'function' and source.bindDocs then
if not source.args then
return
@@ -1266,37 +1286,10 @@ local function getCallEnums(source, index)
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.param'
and doc.param[1] == arg[1] then
- local enums = {}
- for _, enum in ipairs(vm.getDocEnums(doc.extends) or {}) do
- enums[#enums+1] = {
- label = enum[1],
- description = enum.comment,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
- for _, unit in ipairs(vm.getDocTypeUnits(doc.extends) or {}) do
- if unit.type == 'doc.type.function' then
- local text = files.getText(guide.getUri(unit))
- enums[#enums+1] = {
- label = text:sub(unit.start, unit.finish),
- description = doc.comment,
- kind = define.CompletionItemKind.Function,
- insertText = buildInsertDocFunction(unit),
- }
- end
- end
- return enums
+ return pushCallEnumsAndFuncs(vm.getDefs(doc.extends))
elseif doc.type == 'doc.vararg'
and arg.type == '...' then
- local enums = {}
- for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do
- enums[#enums+1] = {
- label = enum[1],
- description = enum.comment,
- kind = define.CompletionItemKind.EnumMember,
- }
- end
- return enums
+ return pushCallEnumsAndFuncs(vm.getDefs(doc.vararg))
end
end
end
@@ -1403,12 +1396,12 @@ local function checkTableLiteralFieldByCall(ast, text, offset, call, defs, index
return
end
for _, def in ipairs(defs) do
- local func = guide.getObjectValue(def) or def
+ local func = searcher.getObjectValue(def) or def
local param = getFuncParamByCallIndex(func, index)
if not param then
goto CONTINUE
end
- local defs = vm.getDefFields(param, 0)
+ local defs = vm.getDefs(param, '*')
for _, field in ipairs(defs) do
local name = guide.getKeyName(field)
if name and not mark[name] then
@@ -1431,10 +1424,10 @@ local function tryCallArg(ast, text, offset, results)
if arg and arg.type == 'function' then
return
end
- local defs = vm.getDefs(call.node, 0)
+ local defs = vm.getDefs(call.node)
for _, def in ipairs(defs) do
- def = guide.getObjectValue(def) or def
- local enums = getCallEnums(def, argIndex)
+ def = searcher.getObjectValue(def) or def
+ local enums = getCallEnumsAndFuncs(def, argIndex)
if enums then
mergeEnums(myResults, enums, arg)
end
@@ -1461,7 +1454,8 @@ local function tryTable(ast, text, offset, results)
if source.type ~= 'table' then
tbl = source.parent
end
- local defs = vm.getDefFields(tbl, 0)
+ local parent = tbl.parent
+ local defs = vm.getDefs(parent, '*')
for _, field in ipairs(defs) do
local name = guide.getKeyName(field)
if name and not mark[name] then
@@ -1560,7 +1554,7 @@ end
local function tryLuaDocBySource(ast, offset, source, results)
if source.type == 'doc.extends.name' then
if source.parent.type == 'doc.class' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if doc.type == 'doc.class.name'
and doc.parent ~= source.parent
and matchKey(source[1], doc[1]) then
@@ -1578,7 +1572,7 @@ local function tryLuaDocBySource(ast, offset, source, results)
end
return true
elseif source.type == 'doc.type.name' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name')
and doc.parent ~= source.parent
and matchKey(source[1], doc[1]) then
@@ -1652,7 +1646,7 @@ end
local function tryLuaDocByErr(ast, offset, err, docState, results)
if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if doc.type == 'doc.class.name'
and doc.parent ~= docState then
results[#results+1] = {
@@ -1662,7 +1656,7 @@ local function tryLuaDocByErr(ast, offset, err, docState, results)
end
end
elseif err.type == 'LUADOC_MISS_TYPE_NAME' then
- for _, doc in ipairs(vm.getDocTypes '*') do
+ for _, doc in ipairs(vm.getDocDefines()) do
if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then
results[#results+1] = {
label = doc[1],
@@ -1735,14 +1729,14 @@ local function buildLuaDocOfFunction(func)
local returns = {}
if func.args then
for _, arg in ipairs(func.args) do
- args[#args+1] = vm.getInferType(arg)
+ args[#args+1] = infer.searchAndViewInfers(arg)
end
end
if func.returns then
for _, rtns in ipairs(func.returns) do
for n = 1, #rtns do
if not returns[n] then
- returns[n] = vm.getInferType(rtns[n])
+ returns[n] = infer.searchAndViewInfers(rtns[n])
end
end
end
diff --git a/script/core/definition.lua b/script/core/definition.lua
index b26bb922..3ced05a2 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -1,14 +1,15 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local workspace = require 'workspace'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
+local guide = require 'parser.guide'
local function sortResults(results)
-- 先按照顺序排序
table.sort(results, function (a, b)
- local u1 = guide.getUri(a.target)
- local u2 = guide.getUri(b.target)
+ local u1 = searcher.getUri(a.target)
+ local u2 = searcher.getUri(b.target)
if u1 == u2 then
return a.target.start < b.target.start
else
@@ -20,7 +21,7 @@ local function sortResults(results)
for i = #results, 1, -1 do
local res = results[i].target
local f = res.finish
- local uri = guide.getUri(res)
+ local uri = searcher.getUri(res)
if lf and f > lf and uri == lu then
table.remove(results, i)
else
@@ -127,11 +128,11 @@ return function (uri, offset)
end
end
- local defs = vm.getDefs(source, 0)
+ local defs = vm.getDefs(source)
local values = {}
for _, src in ipairs(defs) do
- local value = guide.getObjectValue(src)
- if value and value ~= src then
+ local value = searcher.getObjectValue(src)
+ if value and value ~= src and guide.isLiteral(value) then
values[value] = true
end
end
@@ -148,9 +149,6 @@ return function (uri, offset)
goto CONTINUE
end
src = src.field or src.method or src.index or src
- if src.type == 'table' and src.parent.type ~= 'return' then
- goto CONTINUE
- end
if src.type == 'doc.class.name'
and source.type ~= 'doc.type.name'
and source.type ~= 'doc.extends.name'
diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua
index 19bb4f97..37815fb5 100644
--- a/script/core/diagnostics/ambiguity-1.lua
+++ b/script/core/diagnostics/ambiguity-1.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local opMap = {
diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua
index 702cd904..d2e26378 100644
--- a/script/core/diagnostics/circle-doc-class.lua
+++ b/script/core/diagnostics/circle-doc-class.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local vm = require 'vm'
+local guide = require 'parser.guide'
return function (uri, callback)
local state = files.getAst(uri)
@@ -40,7 +40,7 @@ return function (uri, callback)
end
if not mark[newName] then
mark[newName] = true
- local docs = vm.getDocTypes(newName)
+ local docs = vm.getDocDefines(newName)
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name' then
list[#list+1] = otherDoc.parent
diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua
index d1983c42..7828efe9 100644
--- a/script/core/diagnostics/close-non-object.lua
+++ b/script/core/diagnostics/close-non-object.lua
@@ -1,7 +1,6 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
return function (uri, callback)
local state = files.getAst(uri)
diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua
index f23755ea..f300a61a 100644
--- a/script/core/diagnostics/code-after-break.lua
+++ b/script/core/diagnostics/code-after-break.lua
@@ -1,7 +1,7 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
return function (uri, callback)
local state = files.getAst(uri)
diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua
index 65099af8..ee245781 100644
--- a/script/core/diagnostics/count-down-loop.lua
+++ b/script/core/diagnostics/count-down-loop.lua
@@ -1,6 +1,6 @@
-local files = require "files"
-local guide = require "core.guide"
-local lang = require 'language'
+local files = require "files"
+local guide = require "parser.guide"
+local lang = require 'language'
return function (uri, callback)
local state = files.getAst(uri)
diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua
index 60d60946..a6f8a47e 100644
--- a/script/core/diagnostics/deprecated.lua
+++ b/script/core/diagnostics/deprecated.lua
@@ -1,10 +1,10 @@
-local files = require 'files'
-local vm = require 'vm'
-local lang = require 'language'
-local guide = require 'core.guide'
-local config = require 'config'
-local define = require 'proto.define'
-local await = require 'await'
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local config = require 'config'
+local define = require 'proto.define'
+local await = require 'await'
return function (uri, callback)
local ast = files.getAst(uri)
@@ -20,7 +20,7 @@ return function (uri, callback)
return
end
if src.type == 'getglobal' then
- local key = guide.getKeyName(src)
+ local key = src[1]
if not key then
return
end
@@ -34,11 +34,11 @@ return function (uri, callback)
await.delay()
- if not vm.isDeprecated(src, 0) then
+ if not vm.isDeprecated(src, true) then
return
end
- local defs = vm.getDefs(src, 0)
+ local defs = vm.getDefs(src)
local validVersions
for _, def in ipairs(defs) do
if def.bindDocs then
diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua
index 8c6696a9..daecb836 100644
--- a/script/core/diagnostics/duplicate-doc-class.lua
+++ b/script/core/diagnostics/duplicate-doc-class.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local vm = require 'vm'
+local guide = require 'parser.guide'
return function (uri, callback)
local state = files.getAst(uri)
@@ -20,7 +20,7 @@ return function (uri, callback)
or doc.type == 'doc.alias' then
local name = guide.getKeyName(doc)
if not cache[name] then
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
cache[name] = {}
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name'
@@ -28,7 +28,7 @@ return function (uri, callback)
cache[name][#cache[name]+1] = {
start = otherDoc.start,
finish = otherDoc.finish,
- uri = guide.getUri(otherDoc),
+ uri = searcher.getUri(otherDoc),
}
end
end
diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua
index 5e63d39e..d1ba9261 100644
--- a/script/core/diagnostics/duplicate-index.lua
+++ b/script/core/diagnostics/duplicate-index.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local guide = require 'parser.guide'
+local lang = require 'language'
+local define = require 'proto.define'
+local vm = require 'vm'
return function (uri, callback)
local ast = files.getAst(uri)
diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua
index c1e2285a..e1883fe5 100644
--- a/script/core/diagnostics/duplicate-set-field.lua
+++ b/script/core/diagnostics/duplicate-set-field.lua
@@ -1,8 +1,8 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local lang = require 'language'
-local define = require 'proto.define'
-local vm = require 'vm'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local lang = require 'language'
+local define = require 'proto.define'
+local guide = require "parser.guide"
return function (uri, callback)
local ast = files.getAst(uri)
@@ -30,7 +30,7 @@ return function (uri, callback)
if not name then
goto CONTINUE
end
- local value = guide.getObjectValue(nxt)
+ local value = searcher.getObjectValue(nxt)
if not value or value.type ~= 'function' then
goto CONTINUE
end
diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua
index 690a4ca2..2024f4e3 100644
--- a/script/core/diagnostics/empty-block.lua
+++ b/script/core/diagnostics/empty-block.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua
index de23bc76..9a0d4f35 100644
--- a/script/core/diagnostics/global-in-nil-env.lua
+++ b/script/core/diagnostics/global-in-nil-env.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
-- TODO: 检查路径是否可达
diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua
index a2b831f7..1d1ab9af 100644
--- a/script/core/diagnostics/init.lua
+++ b/script/core/diagnostics/init.lua
@@ -62,14 +62,11 @@ local function check(uri, name, results)
end
return function (uri, response)
- local vm = require 'vm'
local ast = files.getAst(uri)
if not ast then
return nil
end
- local isOpen = files.isOpen(uri)
-
for _, name in ipairs(diagList) do
await.delay()
local results = {}
diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua
index 9c094701..8c7ae793 100644
--- a/script/core/diagnostics/lowercase-global.lua
+++ b/script/core/diagnostics/lowercase-global.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local config = require 'config'
local vm = require 'vm'
diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua
index 0727c2fd..75681cbc 100644
--- a/script/core/diagnostics/newfield-call.lua
+++ b/script/core/diagnostics/newfield-call.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua
index 807f76a2..159a60c9 100644
--- a/script/core/diagnostics/newline-call.lua
+++ b/script/core/diagnostics/newline-call.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua
index ffaab821..23af570a 100644
--- a/script/core/diagnostics/no-implicit-any.lua
+++ b/script/core/diagnostics/no-implicit-any.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
@@ -10,7 +10,7 @@ return function (uri, callback)
return
end
- guide.eachSource(ast.ast, function (source)
+ searcher.eachSource(ast.ast, function (source)
if source.type ~= 'local'
and source.type ~= 'setlocal'
and source.type ~= 'setglobal'
diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua
index 857d80d2..48093417 100644
--- a/script/core/diagnostics/redefined-local.lua
+++ b/script/core/diagnostics/redefined-local.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
@@ -13,6 +13,9 @@ return function (uri, callback)
or name == ast.ENVMode then
return
end
+ if source.tag == 'self' then
+ return
+ end
local exist = guide.getLocal(source, name, source.start-1)
if exist then
callback {
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
index c5bcd5a5..eca7fc91 100644
--- a/script/core/diagnostics/redundant-parameter.lua
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local define = require 'proto.define'
@@ -84,7 +84,7 @@ return function (uri, callback)
local funcArgs = cache[func]
if funcArgs == nil then
funcArgs = getFuncArgs(func) or false
- local refs = vm.getRefs(func, 0)
+ local refs = vm.getRefs(func)
for _, ref in ipairs(refs) do
cache[ref] = funcArgs
end
diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua
index 0a4b1d57..e54a6e60 100644
--- a/script/core/diagnostics/trailing-space.lua
+++ b/script/core/diagnostics/trailing-space.lua
@@ -1,6 +1,6 @@
local files = require 'files'
local lang = require 'language'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local function isInString(ast, offset)
local result = false
diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua
index b2b2800c..35aebb45 100644
--- a/script/core/diagnostics/unbalanced-assignments.lua
+++ b/script/core/diagnostics/unbalanced-assignments.lua
@@ -1,7 +1,7 @@
local files = require 'files'
local define = require 'proto.define'
local lang = require 'language'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
return function (uri, callback, code)
local ast = files.getAst(uri)
diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua
index a91cfa7f..d79f7ea4 100644
--- a/script/core/diagnostics/undefined-doc-class.lua
+++ b/script/core/diagnostics/undefined-doc-class.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
@@ -25,7 +25,7 @@ return function (uri, callback)
end
for _, ext in ipairs(doc.extends) do
local name = ext[1]
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
if cache[name] == nil then
cache[name] = false
for _, otherDoc in ipairs(docs) do
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
index d8a4363b..871f16e1 100644
--- a/script/core/diagnostics/undefined-doc-name.lua
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
@@ -22,7 +22,7 @@ return function (uri, callback)
if classCache[name] ~= nil then
return classCache[name]
end
- local docs = vm.getDocTypes(name)
+ local docs = vm.getDocDefines(name)
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class.name'
or otherDoc.type == 'doc.alias.name' then
diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua
index 0bf371e5..4a97947d 100644
--- a/script/core/diagnostics/undefined-doc-param.lua
+++ b/script/core/diagnostics/undefined-doc-param.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local define = require 'proto.define'
local vm = require 'vm'
diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua
index 89efb8c7..c97c3fe8 100644
--- a/script/core/diagnostics/undefined-env-child.lua
+++ b/script/core/diagnostics/undefined-env-child.lua
@@ -1,7 +1,7 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local vm = require 'vm'
-local lang = require 'language'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local guide = require 'parser.guide'
+local lang = require 'language'
return function (uri, callback)
local ast = files.getAst(uri)
@@ -13,7 +13,7 @@ return function (uri, callback)
if source.node.tag == '_ENV' then
return
end
- local defs = guide.requestDefinition(source)
+ local defs = searcher.requestDefinition(source)
if #defs > 0 then
return
end
diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua
index b10c9ab0..2d357d5b 100644
--- a/script/core/diagnostics/undefined-field.lua
+++ b/script/core/diagnostics/undefined-field.lua
@@ -2,9 +2,16 @@ local files = require 'files'
local vm = require 'vm'
local lang = require 'language'
local config = require 'config'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
+local SkipCheckClass = {
+ ['unknown'] = true,
+ ['any'] = true,
+ ['table'] = true,
+ ['nil'] = true,
+}
+
return function (uri, callback)
local ast = files.getAst(uri)
if not ast then
@@ -18,7 +25,7 @@ return function (uri, callback)
if cache[src] == nil then
tracy.ZoneBeginN('undefined-field getInfers')
infers = vm.getInfers(src, 0) or false
- local refs = vm.getRefs(src, 0)
+ local refs = vm.getRefs(src)
for _, ref in ipairs(refs) do
cache[ref] = infers
end
@@ -47,7 +54,7 @@ return function (uri, callback)
elseif inferSource.type == 'doc.class.name' then
addTo(allDocClass, inferSource.parent)
elseif inferSource.type == 'doc.type.name' then
- local docTypes = vm.getDocTypes(inferSource[1])
+ local docTypes = vm.getDocDefines(inferSource[1])
for _, docType in ipairs(docTypes) do
if docType.type == 'doc.class.name' then
addTo(allDocClass, docType.parent)
@@ -65,7 +72,7 @@ return function (uri, callback)
local empty = true
for _, docClass in ipairs(allDocClass) do
tracy.ZoneBeginN('undefined-field getDefFields')
- local refs = vm.getDefFields(docClass)
+ local refs = vm.getDefs(docClass, '*')
tracy.ZoneEnd()
for _, ref in ipairs(refs) do
@@ -87,35 +94,37 @@ return function (uri, callback)
end
local function checkUndefinedField(src)
- local fieldName = guide.getKeyName(src)
-
- local allDocClass = getAllDocClassFromInfer(src.node)
- if (not allDocClass) or (#allDocClass == 0) then
- return
- end
-
- local fields = getAllFieldsFromAllDocClass(allDocClass)
-
- -- 没找到任何 field,跳过检查
- if not fields then
+ if #vm.getDefs(src) > 0 then
return
end
-
- if not fields[fieldName] then
- local message = lang.script('DIAG_UNDEF_FIELD', fieldName)
- if src.type == 'getfield' and src.field then
- callback {
- start = src.field.start,
- finish = src.field.finish,
- message = message,
- }
- elseif src.type == 'getmethod' and src.method then
- callback {
- start = src.method.start,
- finish = src.method.finish,
- message = message,
- }
+ local node = src.node
+ if node then
+ local defs = vm.getDefs(node)
+ local ok
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class.name'
+ and not SkipCheckClass[def[1]] then
+ ok = true
+ break
+ end
end
+ if not ok then
+ return
+ end
+ end
+ local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src))
+ if src.type == 'getfield' and src.field then
+ callback {
+ start = src.field.start,
+ finish = src.field.finish,
+ message = message,
+ }
+ elseif src.type == 'getmethod' and src.method then
+ callback {
+ start = src.method.start,
+ finish = src.method.finish,
+ message = message,
+ }
end
end
guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField);
diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua
index 161d8856..825b14f1 100644
--- a/script/core/diagnostics/undefined-global.lua
+++ b/script/core/diagnostics/undefined-global.lua
@@ -1,9 +1,8 @@
-local files = require 'files'
-local vm = require 'vm'
-local lang = require 'language'
-local config = require 'config'
-local guide = require 'core.guide'
-local define = require 'proto.define'
+local files = require 'files'
+local vm = require 'vm'
+local lang = require 'language'
+local config = require 'config'
+local guide = require 'parser.guide'
local requireLike = {
['include'] = true,
@@ -20,7 +19,7 @@ return function (uri, callback)
-- 遍历全局变量,检查所有没有 set 模式的全局变量
guide.eachSourceType(ast.ast, 'getglobal', function (src)
- local key = guide.getKeyName(src)
+ local key = src[1]
if not key then
return
end
@@ -30,7 +29,11 @@ return function (uri, callback)
if config.config.runtime.special[key] then
return
end
- if #vm.getGlobalSets(key) == 0 then
+ local node = src.node
+ if node.tag ~= '_ENV' then
+ return
+ end
+ if #vm.getDefs(src) == 0 then
local message = lang.script('DIAG_UNDEF_GLOBAL', key)
if requireLike[key:lower()] then
message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua
index b6f92e60..41c239f9 100644
--- a/script/core/diagnostics/unused-function.lua
+++ b/script/core/diagnostics/unused-function.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local vm = require 'vm'
local define = require 'proto.define'
local lang = require 'language'
diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua
index e2d5e49a..e6d998ba 100644
--- a/script/core/diagnostics/unused-label.lua
+++ b/script/core/diagnostics/unused-label.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua
index fde90cb8..1a77a45f 100644
--- a/script/core/diagnostics/unused-local.lua
+++ b/script/core/diagnostics/unused-local.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
@@ -87,6 +87,9 @@ return function (uri, callback)
or name == ast.ENVMode then
return
end
+ if source.tag == 'self' then
+ return
+ end
if isToBeClosed(source) then
return
end
diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua
index ec0a05fb..74cc08e7 100644
--- a/script/core/diagnostics/unused-vararg.lua
+++ b/script/core/diagnostics/unused-vararg.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua
index cc87e3ca..e36ba29b 100644
--- a/script/core/document-symbol.lua
+++ b/script/core/document-symbol.lua
@@ -1,8 +1,8 @@
-local await = require 'await'
-local files = require 'files'
-local guide = require 'core.guide'
-local define = require 'proto.define'
-local util = require 'utility'
+local await = require 'await'
+local files = require 'files'
+local guide = require 'parser.guide'
+local define = require 'proto.define'
+local util = require 'utility'
local function buildName(source, text)
if source.type == 'setmethod'
diff --git a/script/core/find-source.lua b/script/core/find-source.lua
index b36306b6..edbb1e2c 100644
--- a/script/core/find-source.lua
+++ b/script/core/find-source.lua
@@ -1,4 +1,4 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local function isValidFunctionPos(source, offset)
for i = 1, #source.keyword // 2 do
diff --git a/script/core/folding.lua b/script/core/folding.lua
index 15678995..1bbae944 100644
--- a/script/core/folding.lua
+++ b/script/core/folding.lua
@@ -1,5 +1,5 @@
local files = require "files"
-local guide = require "core.guide"
+local searcher = require "core.searcher"
local util = require 'utility'
local Care = {
@@ -153,7 +153,7 @@ return function (uri)
local regions = {}
local status = {}
- guide.eachSource(ast.ast, function (source)
+ searcher.eachSource(ast.ast, function (source)
local tp = source.type
if Care[tp] then
Care[tp](source, text, regions)
diff --git a/script/core/generic.lua b/script/core/generic.lua
new file mode 100644
index 00000000..15950974
--- /dev/null
+++ b/script/core/generic.lua
@@ -0,0 +1,234 @@
+local guide = require 'parser.guide'
+local noder = require "core.noder"
+
+---@class generic.value
+---@field type string
+---@field closure generic.closure
+---@field proto parser.guide.object
+---@field parent parser.guide.object
+
+---@class generic.closure
+---@field type string
+---@field proto parser.guide.object
+---@field upvalues table<string, generic.value[]>
+---@field params generic.value[]
+---@field returns generic.value[]
+
+local m = {}
+
+---@param closure generic.closure
+---@param proto parser.guide.object
+local function instantValue(closure, proto)
+ ---@type generic.value
+ local value = {
+ type = 'generic.value',
+ closure = closure,
+ proto = proto,
+ parent = proto.parent,
+ }
+ closure.values[#closure.values+1] = value
+ return value
+end
+
+---递归实例化对象
+---@param proto parser.guide.object
+---@return generic.value
+local function createValue(closure, proto, callback, road)
+ if callback then
+ road = road or {}
+ end
+ if proto.type == 'doc.type' then
+ local types = {}
+ local hasGeneric
+ for i, tp in ipairs(proto.types) do
+ local genericValue = createValue(closure, tp, callback, road)
+ if genericValue then
+ hasGeneric = true
+ types[i] = genericValue
+ else
+ types[i] = tp
+ end
+ end
+ if not hasGeneric then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.types = types
+ noder.compileNode(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.name' then
+ if not proto.typeGeneric then
+ return nil
+ end
+ local key = proto[1]
+ local value = instantValue(closure, proto)
+ if callback then
+ callback(road, key, proto)
+ end
+ noder.compileNode(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.function' then
+ local hasGeneric
+ local args = {}
+ local returns = {}
+ for i, arg in ipairs(proto.args) do
+ local value = createValue(closure, arg, callback, road)
+ if value then
+ hasGeneric = true
+ end
+ args[i] = value or arg
+ end
+ for i, rtn in ipairs(proto.returns) do
+ local value = createValue(closure, rtn, callback, road)
+ if value then
+ hasGeneric = true
+ end
+ returns[i] = value or rtn
+ end
+ if not hasGeneric then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.args = args
+ value.returns = returns
+ value.isGeneric = true
+ noder.pushSource(noder.getNoders(proto), value)
+ return value
+ end
+ if proto.type == 'doc.type.array' then
+ if road then
+ road[#road+1] = noder.ANY_FIELD
+ end
+ local node = createValue(closure, proto.node, callback, road)
+ if road then
+ road[#road] = nil
+ end
+ if not node then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.node = node
+ return value
+ end
+ if proto.type == 'doc.type.table' then
+ road[#road+1] = noder.TABLE_KEY
+ local tkey = createValue(closure, proto.tkey, callback, road)
+ road[#road] = nil
+
+ road[#road+1] = noder.ANY_FIELD
+ local tvalue = createValue(closure, proto.tvalue, callback, road)
+ road[#road] = nil
+
+ if not tkey and not tvalue then
+ return nil
+ end
+ local value = instantValue(closure, proto)
+ value.tkey = tkey or proto.tkey
+ value.tvalue = tvalue or proto.tvalue
+ return value
+ end
+end
+
+local function buildValue(road, key, proto, param, upvalues)
+ local paramID
+ if proto.literal then
+ local str = param.type == 'string' and param[1]
+ if not str then
+ return
+ end
+ paramID = 'dn:' .. str
+ else
+ paramID = noder.getID(param)
+ end
+ if not paramID then
+ return
+ end
+ local myUri = guide.getUri(param)
+ local myHead = noder.URI_CHAR .. myUri .. noder.URI_CHAR
+ paramID = myHead .. paramID
+ if not upvalues[key] then
+ upvalues[key] = {}
+ end
+ upvalues[key][#upvalues[key]+1] = paramID .. table.concat(road)
+end
+
+-- 为所有的 param 与 return 创建副本
+---@param closure generic.closure
+local function buildValues(closure)
+ local protoFunction = closure.proto
+ local upvalues = closure.upvalues
+ local params = closure.call.args
+
+ if protoFunction.type == 'function' then
+ for _, doc in ipairs(protoFunction.bindDocs) do
+ if doc.type == 'doc.param' then
+ local extends = doc.extends
+ local index = extends.paramIndex
+ if index then
+ local param = params and params[index]
+ closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ buildValue(road, key, proto, param, upvalues)
+ end) or extends
+ end
+ end
+ end
+ for _, doc in ipairs(protoFunction.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ closure.returns[rtn.returnIndex] = createValue(closure, rtn) or rtn
+ end
+ end
+ end
+ end
+ if protoFunction.type == 'doc.type.function' then
+ for index, arg in ipairs(protoFunction.args) do
+ local extends = arg.extends
+ local param = params and params[index]
+ closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ buildValue(road, key, proto, param, upvalues)
+ end) or extends
+ end
+ for index, rtn in ipairs(protoFunction.returns) do
+ closure.returns[index] = createValue(closure, rtn) or rtn
+ end
+ end
+end
+
+---创建一个闭包
+---@param proto parser.guide.object|generic.value # 原型函数|泛型值
+---@return generic.closure
+function m.createClosure(proto, call)
+ local protoFunction, parentClosure
+ if proto.type == 'function' then
+ protoFunction = proto
+ elseif proto.type == 'doc.type.function' then
+ protoFunction = proto
+ elseif proto.type == 'generic.value' then
+ protoFunction = proto.proto
+ parentClosure = proto.closure
+ end
+ ---@type generic.closure
+ local closure = {
+ type = 'generic.closure',
+ parent = protoFunction.parent,
+ proto = protoFunction,
+ call = call,
+ upvalues = parentClosure and parentClosure.upvalues or {},
+ params = {},
+ returns = {},
+ values = {},
+ }
+ buildValues(closure)
+
+ if #closure.returns == 0 then
+ return nil
+ end
+
+ noder.compileNode(noder.getNoders(proto), closure)
+
+ return closure
+end
+
+return m
diff --git a/script/core/guide.lua b/script/core/guide2.lua
index e4871060..576c0c20 100644
--- a/script/core/guide.lua
+++ b/script/core/guide2.lua
@@ -292,8 +292,13 @@ end
---@param obj parser.guide.object
---@return parser.guide.object
function m.getRoot(obj)
+ local source = obj
+ if source._root then
+ return source._root
+ end
for _ = 1, 1000 do
if obj.type == 'main' then
+ source._root = obj
return obj
end
local parent = obj.parent
diff --git a/script/core/highlight.lua b/script/core/highlight.lua
index 12ec114f..45001134 100644
--- a/script/core/highlight.lua
+++ b/script/core/highlight.lua
@@ -1,31 +1,18 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local files = require 'files'
local vm = require 'vm'
local define = require 'proto.define'
local findSource = require 'core.find-source'
local util = require 'utility'
+local guide = require 'parser.guide'
local function eachRef(source, callback)
- local results = guide.requestReference(source)
+ local results = searcher.requestReference(source)
for i = 1, #results do
callback(results[i])
end
end
-local function eachField(source, callback)
- if not source then
- return
- end
- local isGlobal = guide.isGlobal(source)
- local results = guide.requestReference(source)
- for i = 1, #results do
- local res = results[i]
- if isGlobal == guide.isGlobal(res) then
- callback(res)
- end
- end
-end
-
local function eachLocal(source, callback)
callback(source)
if source.ref then
@@ -43,21 +30,21 @@ local function find(source, uri, callback)
eachLocal(source.node, callback)
elseif source.type == 'field'
or source.type == 'method' then
- eachField(source.parent, callback)
+ eachRef(source.parent, callback)
elseif source.type == 'getindex'
or source.type == 'setindex'
or source.type == 'tableindex' then
- eachField(source, callback)
+ eachRef(source, callback)
elseif source.type == 'setglobal'
or source.type == 'getglobal' then
- eachField(source, callback)
+ eachRef(source, callback)
elseif source.type == 'goto'
or source.type == 'label' then
eachRef(source, callback)
elseif source.type == 'string'
and source.parent
and source.parent.index == source then
- eachField(source.parent, callback)
+ eachRef(source.parent, callback)
elseif source.type == 'string'
or source.type == 'boolean'
or source.type == 'number'
@@ -238,6 +225,16 @@ local accept = {
['nil'] = true,
}
+local function isLiteralValue(source)
+ if not guide.isLiteral(source) then
+ return false
+ end
+ if source.parent.index == source then
+ return false
+ end
+ return true
+end
+
return function (uri, offset)
local ast = files.getAst(uri)
if not ast then
@@ -249,10 +246,25 @@ return function (uri, offset)
local source = findSource(ast, offset, accept)
if source then
+ local isGlobal = guide.isGlobal(source)
+ local isLiteral = isLiteralValue(source)
find(source, uri, function (target)
+ if not target then
+ return
+ end
if target.dummy then
return
end
+ if mark[target] then
+ return
+ end
+ mark[target] = true
+ if isGlobal ~= guide.isGlobal(target) then
+ return
+ end
+ if isLiteral ~= isLiteralValue(target) then
+ return
+ end
local kind
if target.type == 'getfield' then
target = target.field
@@ -315,13 +327,6 @@ return function (uri, offset)
else
return
end
- if not target then
- return
- end
- if mark[target] then
- return
- end
- mark[target] = true
results[#results+1] = {
start = target.start,
finish = target.finish,
diff --git a/script/core/hint.lua b/script/core/hint.lua
index 13d01dc7..43b8726e 100644
--- a/script/core/hint.lua
+++ b/script/core/hint.lua
@@ -1,7 +1,7 @@
-local files = require 'files'
-local guide = require 'core.guide'
-local vm = require 'vm'
-local config = require 'config'
+local files = require 'files'
+local searcher = require 'core.searcher'
+local vm = require 'vm'
+local config = require 'config'
local function typeHint(uri, edits, start, finish)
local ast = files.getAst(uri)
@@ -9,7 +9,7 @@ local function typeHint(uri, edits, start, finish)
return
end
local mark = {}
- guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ searcher.eachSourceBetween(ast.ast, start, finish, function (source)
if source.type ~= 'local'
and source.type ~= 'setglobal'
and source.type ~= 'tablefield'
@@ -21,7 +21,7 @@ local function typeHint(uri, edits, start, finish)
if source[1] == '_' then
return
end
- if source.value and guide.isLiteral(source.value) then
+ if source.value and searcher.isLiteral(source.value) then
return
end
if source.parent.type == 'funcargs' then
@@ -84,7 +84,7 @@ local function hasLiteralArgInCall(call)
return false
end
for _, arg in ipairs(call.args) do
- if guide.isLiteral(arg) then
+ if searcher.isLiteral(arg) then
return true
end
end
@@ -100,14 +100,14 @@ local function paramName(uri, edits, start, finish)
return
end
local mark = {}
- guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ searcher.eachSourceBetween(ast.ast, start, finish, function (source)
if source.type ~= 'call' then
return
end
if not hasLiteralArgInCall(source) then
return
end
- local defs = vm.getDefs(source.node, 0)
+ local defs = vm.getDefs(source.node)
if not defs then
return
end
@@ -130,7 +130,7 @@ local function paramName(uri, edits, start, finish)
table.remove(args, 1)
end
for i, arg in ipairs(source.args) do
- if not mark[arg] and guide.isLiteral(arg) then
+ if not mark[arg] and searcher.isLiteral(arg) then
mark[arg] = true
if args[i] and args[i] ~= '' then
edits[#edits+1] = {
diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua
index 324d28af..822be2b6 100644
--- a/script/core/hover/arg.lua
+++ b/script/core/hover/arg.lua
@@ -1,4 +1,5 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
+local infer = require 'core.infer'
local vm = require 'vm'
local function optionalArg(arg)
@@ -21,7 +22,7 @@ local function asFunction(source, oop)
methodDef = true
end
if methodDef then
- args[#args+1] = ('self: %s'):format(vm.getInferType(parent.node))
+ args[#args+1] = ('self: %s'):format(infer.searchAndViewInfers(parent.node))
end
if source.args then
for i = 1, #source.args do
@@ -34,10 +35,12 @@ local function asFunction(source, oop)
args[#args+1] = ('%s%s: %s'):format(
name,
optionalArg(arg) and '?' or '',
- vm.getInferType(arg)
+ infer.searchAndViewInfers(arg)
)
+ elseif arg.type == '...' then
+ args[#args+1] = '...'
else
- args[#args+1] = ('%s'):format(vm.getInferType(arg))
+ args[#args+1] = ('%s'):format(infer.searchAndViewInfers(arg))
end
::CONTINUE::
end
@@ -61,7 +64,7 @@ local function asDocFunction(source)
args[i] = ('%s%s: %s'):format(
name,
arg.optional and '?' or '',
- vm.getInferType(arg.extends)
+ infer.searchAndViewInfers(arg.extends)
)
else
args[i] = ('%s%s'):format(
diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua
index 401ca5a7..bcc3065a 100644
--- a/script/core/hover/description.lua
+++ b/script/core/hover/description.lua
@@ -2,11 +2,13 @@ local vm = require 'vm'
local ws = require 'workspace'
local furi = require 'file-uri'
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local markdown = require 'provider.markdown'
local config = require 'config'
local lang = require 'language'
local util = require 'utility'
+local guide = require 'parser.guide'
+local noder = require 'core.noder'
local function asStringInRequire(source, literal)
local rootPath = ws.path or ''
@@ -124,10 +126,10 @@ local function getBindComment(source, docGroup, base)
end
local function tryDocClassComment(source)
- for _, def in ipairs(vm.getDefs(source, 0)) do
+ for _, def in ipairs(vm.getDefs(source)) do
if def.type == 'doc.class.name'
or def.type == 'doc.alias.name' then
- local class = guide.getDocState(def)
+ local class = noder.getDocState(def)
local comment = getBindComment(class, class.bindGroup, class)
if comment then
return comment
@@ -180,7 +182,7 @@ local function isFunction(source)
if source.type == 'function' then
return true
end
- local value = guide.getObjectValue(source)
+ local value = searcher.getObjectValue(source)
if not value then
return false
end
@@ -223,13 +225,14 @@ local function getBindEnums(source, docGroup)
end
local function tryDocFieldUpComment(source)
- if source.type ~= 'doc.field' then
+ if source.type ~= 'doc.field.name' then
return
end
- if not source.bindGroup then
+ local docField = source.parent
+ if not docField.bindGroup then
return
end
- local comment = getBindComment(source, source.bindGroup, source)
+ local comment = getBindComment(docField, docField.bindGroup, docField)
return comment
end
diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua
index 0c8644ed..5dd00c43 100644
--- a/script/core/hover/init.lua
+++ b/script/core/hover/init.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local vm = require 'vm'
local getLabel = require 'core.hover.label'
local getDesc = require 'core.hover.description'
@@ -7,6 +7,7 @@ local util = require 'utility'
local findSource = require 'core.find-source'
local lang = require 'language'
local markdown = require 'provider.markdown'
+local infer = require 'core.infer'
local function eachFunctionAndOverload(value, callback)
callback(value)
@@ -24,7 +25,7 @@ local function getHoverAsValue(source)
local label = getLabel(source)
local desc = getDesc(source)
if not desc then
- local values = vm.getDefs(source, 0)
+ local values = vm.getDefs(source)
for _, def in ipairs(values) do
desc = getDesc(def)
if desc then
@@ -40,7 +41,7 @@ local function getHoverAsValue(source)
end
local function getHoverAsFunction(source)
- local values = vm.getDefs(source, 0)
+ local values = vm.getDefs(source)
local desc = getDesc(source)
local labels = {}
local defs = 0
@@ -48,7 +49,7 @@ local function getHoverAsFunction(source)
local other = 0
local mark = {}
for _, def in ipairs(values) do
- def = guide.getObjectValue(def) or def
+ def = searcher.getObjectValue(def) or def
if def.type == 'function'
or def.type == 'doc.type.function' then
eachFunctionAndOverload(def, function (value)
@@ -123,7 +124,7 @@ local function getHover(source)
if source.type == 'doc.type.name' then
return getHoverAsDocName(source)
end
- local isFunction = vm.hasInferType(source, 'function', 0)
+ local isFunction = infer.hasType(source, 'function')
if isFunction then
return getHoverAsFunction(source)
else
diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua
index d93b14e3..032f19c0 100644
--- a/script/core/hover/label.lua
+++ b/script/core/hover/label.lua
@@ -2,9 +2,10 @@ local buildName = require 'core.hover.name'
local buildArg = require 'core.hover.arg'
local buildReturn = require 'core.hover.return'
local buildTable = require 'core.hover.table'
+local infer = require 'core.infer'
local vm = require 'vm'
local util = require 'utility'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local lang = require 'language'
local config = require 'config'
local files = require 'files'
@@ -31,29 +32,28 @@ local function asDocFunction(source)
end
local function asDocTypeName(source)
- for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ local defs = searcher.requestDefinition(source)
+ for _, doc in ipairs(defs) do
if doc.type == 'doc.class.name' then
- return 'class ' .. source[1]
+ return 'class ' .. doc[1]
end
if doc.type == 'doc.alias.name' then
local extends = doc.parent.extends
- return lang.script('HOVER_EXTENDS', vm.getInferType(extends))
+ return lang.script('HOVER_EXTENDS', infer.searchAndViewInfers(extends))
end
end
end
local function asValue(source, title)
local name = buildName(source)
- local infers = vm.getInfers(source, 0)
- local type = vm.getInferType(source, 0)
- local class = vm.getClass(source, 0)
- local literal = vm.getInferLiteral(source, 0)
+ local type = infer.searchAndViewInfers(source)
+ local literal = infer.searchAndViewLiterals(source)
local cont
- if not vm.hasInferType(source, 'string', 0)
+ if not infer.hasType(source, 'string')
and not type:find('%[%]$')
and not type:find('%w%<') then
- if #vm.getFields(source, 0) > 0
- or vm.hasInferType(source, 'table', 0) then
+ if #vm.getRefs(source, '*') > 0
+ or infer.hasType(source, 'table') then
cont = buildTable(source)
end
end
@@ -66,11 +66,7 @@ local function asValue(source, title)
or type == 'nil') then
type = nil
end
- if class then
- pack[#pack+1] = class
- else
- pack[#pack+1] = type
- end
+ pack[#pack+1] = type
if literal then
pack[#pack+1] = '='
pack[#pack+1] = literal
@@ -123,30 +119,21 @@ local function asField(source)
return asValue(source, 'field')
end
-local function asDocField(source)
- local name = source.field[1]
+local function asDocFieldName(source)
+ local name = source[1]
+ local docField = source.parent
local class
- for _, doc in ipairs(source.bindGroup) do
+ for _, doc in ipairs(docField.bindGroup) do
if doc.type == 'doc.class' then
class = doc
break
end
end
- local infers = {}
- for _, infer in ipairs(vm.getInfers(source.extends) or {}) do
- infers[#infers+1] = infer
- end
+ local view = infer.searchAndViewInfers(docField.extends)
if not class then
- return ('field ?.%s: %s'):format(
- name,
- guide.viewInferType(infers)
- )
- end
- return ('field %s.%s: %s'):format(
- class.class[1],
- name,
- guide.viewInferType(infers)
- )
+ return ('field ?.%s: %s'):format(name, view)
+ end
+ return ('field %s.%s: %s'):format(class.class[1], name, view)
end
local function asString(source)
@@ -177,7 +164,7 @@ local function asNumber(source)
if type(num) ~= 'number' then
return nil
end
- local uri = guide.getUri(source)
+ local uri = searcher.getUri(source)
local text = files.getText(uri)
if not text then
return nil
@@ -215,7 +202,7 @@ return function (source, oop)
return asDocFunction(source)
elseif source.type == 'doc.type.name' then
return asDocTypeName(source)
- elseif source.type == 'doc.field' then
- return asDocField(source)
+ elseif source.type == 'doc.field.name' then
+ return asDocFieldName(source)
end
end
diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua
index d583f1e1..d2b9d30b 100644
--- a/script/core/hover/name.lua
+++ b/script/core/hover/name.lua
@@ -1,4 +1,6 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
+local infer = require 'core.infer'
+local guide = require 'parser.guide'
local vm = require 'vm'
local buildName
@@ -19,7 +21,7 @@ end
local function asField(source, oop)
local class
if source.node.type ~= 'getglobal' then
- class = vm.getClass(source.node, 0)
+ class = infer.getClass(source.node)
end
local node = class or guide.getKeyName(source.node) or '?'
local method = guide.getKeyName(source)
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
index c3e9656d..0f0d85e0 100644
--- a/script/core/hover/return.lua
+++ b/script/core/hover/return.lua
@@ -1,12 +1,4 @@
-local guide = require 'core.guide'
-local vm = require 'vm'
-
-local function mergeTypes(returns)
- if type(returns) == 'string' then
- return returns
- end
- return guide.mergeTypes(returns)
-end
+local infer = require 'core.infer'
local function getReturnDualByDoc(source)
local docs = source.bindDocs
@@ -55,24 +47,20 @@ local function asFunction(source)
local returns = {}
for i, rtn in ipairs(dual) do
local line = {}
- local types = {}
+ local infers = {}
if i == 1 then
line[#line+1] = ' -> '
else
line[#line+1] = ('% 3d. '):format(i)
end
for n = 1, #rtn do
- local values = vm.getInfers(rtn[n])
- for _, value in ipairs(values) do
- if value.type then
- for tp in value.type:gmatch '[^|]+' do
- types[tp] = true
- end
- end
+ local values = infer.searchInfers(rtn[n])
+ for tp in pairs(values) do
+ infers[tp] = true
end
end
- if next(types) or rtn[1] then
- local tp = mergeTypes(types) or 'any'
+ if next(infers) or rtn[1] then
+ local tp = infer.viewInfers(infers)
if rtn[1].name then
line[#line+1] = ('%s%s: %s'):format(
rtn[1].name[1],
@@ -103,7 +91,7 @@ local function asDocFunction(source)
local returns = {}
for i, rtn in ipairs(source.returns) do
local rtnText = ('%s%s'):format(
- vm.getInferType(rtn),
+ infer.searchAndViewInfers(rtn),
rtn.optional and '?' or ''
)
if i == 1 then
diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua
index edb7751b..159453e6 100644
--- a/script/core/hover/table.lua
+++ b/script/core/hover/table.lua
@@ -1,26 +1,12 @@
local vm = require 'vm'
local util = require 'utility'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local config = require 'config'
local lang = require 'language'
+local infer = require 'core.infer'
-local function getKey(src)
- local key = vm.getKeyName(src)
- if not key or #key <= 0 then
- if not src.index then
- return '[any]'
- end
- local class = vm.getClass(src.index)
- if class then
- return ('[%s]'):format(class)
- end
- local tp = vm.getInferType(src.index)
- if tp then
- return ('[%s]'):format(tp)
- end
- return '[any]'
- end
- if guide.getKeyType(src) == 'string' then
+local function formatKey(key)
+ if type(key) == 'string' then
if key:match '^[%a_][%w_]*$' then
return key
else
@@ -30,104 +16,16 @@ local function getKey(src)
return ('[%s]'):format(key)
end
-local function getFieldFull(src)
- local value = guide.getObjectValue(src) or src
- local tp = vm.getInferType(value, 0)
- --local class = vm.getClass(src)
- local literal = vm.getInferLiteral(value)
- if type(literal) == 'string' and #literal >= 50 then
- literal = literal:sub(1, 47) .. '...'
- end
- return tp, literal
-end
-
-local function getFieldFast(src)
- if src.bindDocs then
- return getFieldFull(src)
- end
- local value = guide.getObjectValue(src) or src
- if not value then
- return 'any'
- end
- if value.type == 'boolean' then
- return value.type, util.viewLiteral(value[1])
- end
- if value.type == 'number'
- or value.type == 'integer' then
- if math.tointeger(value[1]) then
- if config.config.runtime.version == 'Lua 5.3'
- or config.config.runtime.version == 'Lua 5.4' then
- return 'integer', util.viewLiteral(value[1])
- end
- end
- return value.type, util.viewLiteral(value[1])
- end
- if value.type == 'table'
- or value.type == 'function' then
- return value.type
- end
- if value.type == 'string' then
- local literal = value[1]
- if type(literal) == 'string' and #literal >= 50 then
- literal = literal:sub(1, 47) .. '...'
- end
- return value.type, util.viewLiteral(literal)
- end
- if value.type == 'doc.field' then
- return vm.getInferType(value)
- end
-end
-
-local function getField(src, timeUp, mark, key)
- if src.type == 'table'
- or src.type == 'function' then
- return nil
- end
- if src.parent then
- if src.type == 'string'
- or src.type == 'boolean'
- or src.type == 'number'
- or src.type == 'integer' then
- if src.parent.type == 'tableindex'
- or src.parent.type == 'setindex'
- or src.parent.type == 'getindex' then
- if src.parent.index == src then
- src = src.parent
- end
- end
- end
- end
- local tp, literal
- tp, literal = getFieldFast(src)
- if tp then
- return tp, literal
- end
- if timeUp or mark[key] then
- return nil
- end
- mark[key] = true
- tp, literal = getFieldFull(src)
- if tp then
- return tp, literal
- end
- return nil
-end
-
-local function buildAsHash(classes, literals, reachMax)
- local keys = {}
- for k in pairs(classes) do
- keys[#keys+1] = k
- end
- table.sort(keys)
+local function buildAsHash(keys, inferMap, literalMap, reachMax)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local class = classes[key]
- local literal = literals[key]
- if literal then
- lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ local inferView = inferMap[key]
+ local literalView = literalMap[key]
+ if literalView then
+ lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView)
else
- lines[#lines+1] = (' %s: %s,'):format(key, class)
+ lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView)
end
end
if reachMax then
@@ -137,23 +35,19 @@ local function buildAsHash(classes, literals, reachMax)
return table.concat(lines, '\n')
end
-local function buildAsConst(classes, literals, reachMax)
- local keys = {}
- for k in pairs(classes) do
- keys[#keys+1] = k
- end
+local function buildAsConst(keys, inferMap, literalMap, reachMax)
table.sort(keys, function (a, b)
- return tonumber(literals[a]) < tonumber(literals[b])
+ return tonumber(literalMap[a]) < tonumber(literalMap[b])
end)
local lines = {}
lines[#lines+1] = '{'
for _, key in ipairs(keys) do
- local class = classes[key]
- local literal = literals[key]
- if literal then
- lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal)
+ local inferView = inferMap[key]
+ local literalView = literalMap[key]
+ if literalView then
+ lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView)
else
- lines[#lines+1] = (' %s: %s,'):format(key, class)
+ lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView)
end
end
if reachMax then
@@ -163,111 +57,79 @@ local function buildAsConst(classes, literals, reachMax)
return table.concat(lines, '\n')
end
-local function mergeLiteral(literals)
- local results = {}
+local typeSorter = {
+ ['string'] = 1,
+ ['number'] = 2,
+ ['boolean'] = 3,
+}
+
+local function getKeyMap(fields)
+ local keys = {}
local mark = {}
- for _, value in ipairs(literals) do
- if not mark[value] then
- mark[value] = true
- results[#results+1] = value
+ for _, field in ipairs(fields) do
+ local key = vm.getKeyName(field)
+ local tp = vm.getKeyType(field)
+ if tp == 'number' then
+ key = tonumber(key)
+ elseif tp == 'boolean' then
+ key = key == 'true'
end
- end
- if #results == 0 then
- return nil
- end
- table.sort(results)
- return table.concat(results, '|')
-end
-
-local function mergeTypes(types)
- local results = {}
- local mark = {
- -- 讲道理table的keyvalue不会是nil
- ['nil'] = true,
- }
- for _, tv in ipairs(types) do
- for tp in tv:gmatch '[^|]+' do
- if not mark[tp] then
- mark[tp] = true
- results[tp] = true
- end
+ if key and not mark[key] then
+ mark[key] = true
+ keys[#keys+1] = key
end
end
- return guide.mergeTypes(results)
-end
-
-local function clearClasses(classes)
- classes['[nil]'] = nil
- classes['[any]'] = nil
- classes['[string]'] = nil
+ table.sort(keys, function (a, b)
+ local ta = typeSorter[type(a)]
+ local tb = typeSorter[type(b)]
+ if ta == tb then
+ return tostring(a) < tostring(b)
+ else
+ return ta < tb
+ end
+ end)
+ return keys
end
return function (source)
- if config.config.hover.previewFields <= 0 then
+ local maxFields = config.config.hover.previewFields
+ if maxFields <= 0 then
return 'table'
end
- local literals = {}
- local classes = {}
- local clock = os.clock()
- local timeUp
- local mark = {}
- local fields = vm.getFields(source, 0)
- local keyCount = 0
- local reachMax
- for _, src in ipairs(fields) do
- local key = getKey(src)
- if not key then
- goto CONTINUE
- end
- if not classes[key] then
- classes[key] = {}
- keyCount = keyCount + 1
- end
- if not literals[key] then
- literals[key] = {}
- end
- if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then
- timeUp = true
- end
- local class, literal = getField(src, timeUp, mark, key)
- if literal == 'nil' then
- literal = nil
- end
- classes[key][#classes[key]+1] = class
- literals[key][#literals[key]+1] = literal
- if keyCount >= config.config.hover.previewFields then
- reachMax = true
- break
- end
- ::CONTINUE::
- end
-
- clearClasses(classes)
- for key, class in pairs(classes) do
- literals[key] = mergeLiteral(literals[key])
- classes[key] = mergeTypes(class)
- end
+ local fields = vm.getRefs(source, '*')
+ local keys = getKeyMap(fields)
- if not next(classes) then
+ if #keys == 0 then
return '{}'
end
- local intValue = true
- for key, class in pairs(classes) do
- if class ~= 'integer' or not tonumber(literals[key]) then
- intValue = false
- break
+ local inferMap = {}
+ local literalMap = {}
+
+ local reachMax = maxFields < #keys
+
+ local isConsts = true
+ for i = 1, math.min(maxFields, #keys) do
+ local key = keys[i]
+ inferMap[key] = infer.searchAndViewInfers(source, key)
+ literalMap[key] = infer.searchAndViewLiterals(source, key)
+ if not tonumber(literalMap[key]) then
+ isConsts = false
end
end
+
local result
- if intValue then
- result = buildAsConst(classes, literals, reachMax)
+
+ if isConsts then
+ result = buildAsConst(keys, inferMap, literalMap, reachMax)
else
- result = buildAsHash(classes, literals, reachMax)
- end
- if timeUp then
- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ result = buildAsHash(keys, inferMap, literalMap, reachMax)
end
+
+ --if timeUp then
+ -- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result)
+ --end
+
return result
end
diff --git a/script/core/infer.lua b/script/core/infer.lua
new file mode 100644
index 00000000..77236811
--- /dev/null
+++ b/script/core/infer.lua
@@ -0,0 +1,634 @@
+local searcher = require 'core.searcher'
+local config = require 'config'
+local noder = require 'core.noder'
+local util = require 'utility'
+
+local STRING_OR_TABLE = {'STRING_OR_TABLE'}
+local BE_RETURN = {'BE_RETURN'}
+local BE_CONNACT = {'BE_CONNACT'}
+local CLASS = {'CLASS'}
+local TABLE = {'TABLE'}
+
+local TypeSort = {
+ ['boolean'] = 1,
+ ['string'] = 2,
+ ['integer'] = 3,
+ ['number'] = 4,
+ ['table'] = 5,
+ ['function'] = 6,
+ ['true'] = 101,
+ ['false'] = 102,
+ ['nil'] = 999,
+}
+
+local m = {}
+
+local function mergeTable(a, b)
+ if not b then
+ return
+ end
+ for v in pairs(b) do
+ a[v] = true
+ end
+end
+
+local function searchInferOfUnary(value, infers)
+ local op = value.op.type
+ if op == 'not' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '#' then
+ infers['integer'] = true
+ return
+ end
+ if op == '-' then
+ if m.hasType(value[1], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+end
+
+local function searchInferOfBinary(value, infers)
+ local op = value.op.type
+ if op == 'and' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[2]))
+ else
+ mergeTable(infers, m.searchInfers(value[1]))
+ end
+ return
+ end
+ if op == 'or' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[1]))
+ else
+ mergeTable(infers, m.searchInfers(value[2]))
+ end
+ return
+ end
+ if op == '=='
+ or op == '~='
+ or op == '<'
+ or op == '>'
+ or op == '<='
+ or op == '>=' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '&'
+ or op == '|' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers['string'] = true
+ return
+ end
+ if op == '^'
+ or op == '/' then
+ infers['number'] = true
+ return
+ end
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '%'
+ or op == '//' then
+ if m.hasType(value[1], 'integer')
+ and m.hasType(value[2], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+end
+
+local function searchInferOfValue(value, infers)
+ if value.type == 'string' then
+ infers['string'] = true
+ return true
+ end
+ if value.type == 'boolean' then
+ infers['boolean'] = true
+ return true
+ end
+ if value.type == 'table' then
+ if value.array then
+ local node = m.searchAndViewInfers(value.array)
+ local infer = node .. '[]'
+ infers[infer] = true
+ else
+ infers['table'] = true
+ end
+ return true
+ end
+ if value.type == 'number' then
+ if math.type(value[1]) == 'integer' then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return true
+ end
+ if value.type == 'nil' then
+ infers['nil'] = true
+ return true
+ end
+ if value.type == 'function' then
+ infers['function'] = true
+ return true
+ end
+ if value.type == 'unary' then
+ searchInferOfUnary(value, infers)
+ return true
+ end
+ if value.type == 'binary' then
+ searchInferOfBinary(value, infers)
+ return true
+ end
+ return false
+end
+
+local function searchLiteralOfValue(value, literals)
+ if value.type == 'string'
+ or value.type == 'boolean'
+ or value.type == 'number'
+ or value.type == 'integer' then
+ local v = value[1]
+ if v ~= nil then
+ literals[v] = true
+ end
+ return
+ end
+ if value.type == 'unary' then
+ local op = value.op.type
+ if op == '-' then
+ local subLiterals = m.searchLiterals(value[1])
+ if subLiterals then
+ for subLiteral in pairs(subLiterals) do
+ local num = tonumber(subLiteral)
+ if num then
+ literals[-num] = true
+ end
+ end
+ end
+ end
+ if op == '~' then
+ local subLiterals = m.searchLiterals(value[1])
+ if subLiterals then
+ for subLiteral in pairs(subLiterals) do
+ local num = math.tointeger(subLiteral)
+ if num then
+ literals[~num] = true
+ end
+ end
+ end
+ end
+ end
+ return
+end
+
+local function bindClassOrType(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.type' then
+ return true
+ end
+ end
+ return false
+end
+
+local function cleanInfers(infers)
+ local version = config.config.runtime.version
+ local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4'
+ infers['unknown'] = nil
+ if infers['any'] and infers['nil'] then
+ infers['nil'] = nil
+ end
+ if infers['number'] then
+ enableInteger = false
+ end
+ if not enableInteger and infers['integer'] then
+ infers['integer'] = nil
+ infers['number'] = true
+ end
+ -- stringlib 就是 string
+ if infers['stringlib'] and infers['string'] then
+ infers['stringlib'] = nil
+ end
+ -- 如果是通过 .. 来推测的,且结果里没有 number 与 integer,则推测为string
+ if infers[BE_CONNACT] then
+ infers[BE_CONNACT] = nil
+ if not infers['number'] and not infers['integer'] then
+ infers['string'] = true
+ end
+ end
+ -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型
+ if infers[STRING_OR_TABLE] then
+ infers[STRING_OR_TABLE] = nil
+ if not infers['table'] and not infers['string'] then
+ infers['table'] = true
+ infers['string'] = true
+ end
+ end
+ -- 如果有doc标记,则先移除table类型
+ if infers[CLASS] then
+ infers[CLASS] = nil
+ infers['table'] = nil
+ end
+ -- 用doc标记的table,加入table类型
+ if infers[TABLE] then
+ infers[TABLE] = nil
+ infers['table'] = true
+ end
+ if infers[BE_RETURN] then
+ infers[BE_RETURN] = nil
+ infers['nil'] = nil
+ end
+ infers['any'] = nil
+end
+
+---合并对象的推断类型
+---@param infers string[]
+---@return string
+function m.viewInfers(infers)
+ if infers[0] then
+ return infers[0]
+ end
+ -- 如果有显性的 any ,则直接显示为 any
+ if infers['any'] then
+ infers[0] = 'any'
+ return 'any'
+ end
+ local result = {}
+ local count = 0
+ for infer in pairs(infers) do
+ count = count + 1
+ result[count] = infer
+ end
+ -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any
+ if count == 0 then
+ infers[0] = 'any'
+ return 'any'
+ end
+ table.sort(result, function (a, b)
+ local sa = TypeSort[a] or 100
+ local sb = TypeSort[b] or 100
+ if sa == sb then
+ return a < b
+ else
+ return sa < sb
+ end
+ end)
+ infers[0] = table.concat(result, '|')
+ return infers[0]
+end
+
+---合并对象的值
+---@param literals string[]
+---@return string
+function m.viewLiterals(literals)
+ local result = {}
+ local count = 0
+ for infer in pairs(literals) do
+ count = count + 1
+ result[count] = util.viewLiteral(infer)
+ end
+ if count == 0 then
+ return nil
+ end
+ table.sort(result)
+ local view = table.concat(result, '|')
+ return view
+end
+
+function m.viewDocName(doc)
+ if not doc then
+ return nil
+ end
+ if doc.type == 'doc.type' then
+ local list = {}
+ for _, tp in ipairs(doc.types) do
+ list[#list+1] = m.getDocName(tp)
+ end
+ for _, enum in ipairs(doc.enums) do
+ list[#list+1] = m.getDocName(enum)
+ end
+ return table.concat(list, '|')
+ end
+ return m.getDocName(doc)
+end
+
+function m.getDocName(doc)
+ if not doc then
+ return nil
+ end
+ if doc.type == 'doc.class.name'
+ or doc.type == 'doc.type.name' then
+ local name = doc[1] or '?'
+ if doc.typeGeneric then
+ return '<' .. name .. '>'
+ else
+ return name
+ end
+ end
+ if doc.type == 'doc.type.array' then
+ local nodeName = m.viewDocName(doc.node) or '?'
+ return nodeName .. '[]'
+ end
+ if doc.type == 'doc.type.table' then
+ local key = m.viewDocName(doc.tkey) or '?'
+ local value = m.viewDocName(doc.tvalue) or '?'
+ return ('table<%s, %s>'):format(key, value)
+ end
+ if doc.type == 'doc.type.function' then
+ return 'function'
+ end
+ if doc.type == 'doc.type.enum'
+ or doc.type == 'doc.resume' then
+ local value = doc[1] or '?'
+ return value
+ end
+end
+
+function m.viewDocFunction(doc)
+ if doc.type ~= 'doc.type.function' then
+ return ''
+ end
+ local args = {}
+ for i, arg in ipairs(doc.args) do
+ args[i] = ('%s: %s'):format(arg.name[1], m.viewDocName(arg.extends))
+ end
+ local label = ('fun(%s)'):format(table.concat(args, ', '))
+ if #doc.returns > 0 then
+ local returns = {}
+ for i, rtn in ipairs(doc.returns) do
+ returns[i] = m.viewDocName(rtn)
+ end
+ label = ('%s:%s'):format(label, table.concat(returns))
+ end
+ return label
+end
+
+---显示对象的推断类型
+---@param source parser.guide.object
+---@return string
+local function searchInfer(source, infers)
+ if bindClassOrType(source) then
+ return
+ end
+ if searchInferOfValue(source, infers) then
+ return
+ end
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchInferOfValue(value, infers)
+ return
+ end
+ -- check LuaDoc
+ local docName = m.getDocName(source)
+ if docName then
+ infers[docName] = true
+ if docName ~= 'unknown' then
+ infers[CLASS] = true
+ end
+ if docName == 'table' then
+ infers[TABLE] = true
+ end
+ end
+ if source.parent.type == 'unary' then
+ local op = source.parent.op.type
+ -- # XX -> string | table
+ if op == '#' then
+ infers[STRING_OR_TABLE] = true
+ return
+ end
+ if op == '-' then
+ infers['number'] = true
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+ return
+ end
+ if source.parent.type == 'binary' then
+ local op = source.parent.op.type
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '/'
+ or op == '//'
+ or op == '^'
+ or op == '%' then
+ infers['number'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '|'
+ or op == '&' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers[BE_CONNACT] = true
+ return
+ end
+ end
+ -- X.a -> table
+ if source.next and source.next.node == source then
+ if source.next.type == 'setfield'
+ or source.next.type == 'setindex'
+ or source.next.type == 'setmethod'
+ or source.next.type == 'getfield'
+ or source.next.type == 'getindex' then
+ infers['table'] = true
+ end
+ if source.next.type == 'getmethod' then
+ infers[STRING_OR_TABLE] = true
+ end
+ end
+ -- return XX
+ if source.parent.type == 'return' then
+ infers[BE_RETURN] = true
+ end
+end
+
+local function searchLiteral(source, literals)
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchLiteralOfValue(value, literals)
+ return
+ end
+end
+
+---搜索对象的推断类型
+---@param source parser.guide.object
+---@param field? string
+---@return string[]
+function m.searchInfers(source, field)
+ if not source then
+ return nil
+ end
+ local defs = searcher.requestDefinition(source, field)
+ local infers = {}
+ local mark = {}
+ if not field then
+ mark[source] = true
+ searchInfer(source, infers)
+ local id = noder.getID(source)
+ if id then
+ local node = noder.getNodeByID(source, id)
+ if node and node.sources then
+ for _, src in ipairs(node.sources) do
+ if not mark[src] then
+ mark[src] = true
+ searchInfer(src, infers)
+ end
+ end
+ end
+ end
+ end
+ if source.type == 'field' or source.type == 'method' then
+ mark[source.parent] = true
+ searchInfer(source.parent, infers)
+ end
+ for _, def in ipairs(defs) do
+ if not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers)
+ end
+ end
+ if source.docParam then
+ local docType = source.docParam.extends
+ if docType.type == 'doc.type' then
+ for _, def in ipairs(docType.types) do
+ if def.typeGeneric and not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers)
+ end
+ end
+ end
+ end
+ if source.type == 'doc.type' then
+ if source.type == 'doc.type' then
+ for _, def in ipairs(source.types) do
+ if def.typeGeneric and not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers)
+ end
+ end
+ end
+ end
+ cleanInfers(infers)
+ return infers
+end
+
+---搜索对象的字面量值
+---@param source parser.guide.object
+---@param field? string
+---@return table
+function m.searchLiterals(source, field)
+ local defs = searcher.requestDefinition(source, field)
+ local literals = {}
+ local mark = {}
+ if not field then
+ mark[source] = true
+ searchLiteral(source, literals)
+ end
+ for _, def in ipairs(defs) do
+ if not mark[def] then
+ mark[def] = true
+ searchLiteral(def, literals)
+ end
+ end
+ return literals
+end
+
+---搜索并显示推断值
+---@param source parser.guide.object
+---@param field? string
+---@return string
+function m.searchAndViewLiterals(source, field)
+ if not source then
+ return nil
+ end
+ local literals = m.searchLiterals(source, field)
+ local view = m.viewLiterals(literals)
+ return view
+end
+
+---判断对象的推断值是否是 true
+---@param source parser.guide.object
+function m.isTrue(source)
+ if not source then
+ return false
+ end
+ local literals = m.searchLiterals(source)
+ for literal in pairs(literals) do
+ if literal ~= false then
+ return true
+ end
+ end
+ return false
+end
+
+---判断对象的推断类型是否包含某个类型
+function m.hasType(source, tp)
+ local infers = m.searchInfers(source)
+ return infers[tp] or false
+end
+
+---搜索并显示推断类型
+---@param source parser.guide.object
+---@param field? string
+---@return string
+function m.searchAndViewInfers(source, field)
+ if not source then
+ return 'any'
+ end
+ local infers = m.searchInfers(source, field)
+ local view = m.viewInfers(infers)
+ return view
+end
+
+---搜索并显示推断的class
+---@param source parser.guide.object
+---@return string?
+function m.getClass(source)
+ if not source then
+ return nil
+ end
+ local infers = {}
+ local defs = searcher.requestDefinition(source)
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.class.name' then
+ infers[def[1]] = true
+ end
+ end
+ local view = m.viewInfers(infers)
+ if view == 'any' then
+ return nil
+ end
+ return view
+end
+
+return m
diff --git a/script/core/keyword.lua b/script/core/keyword.lua
index 71ea4969..73892f18 100644
--- a/script/core/keyword.lua
+++ b/script/core/keyword.lua
@@ -1,6 +1,6 @@
local define = require 'proto.define'
-local guide = require 'core.guide'
local files = require 'files'
+local guide = require 'parser.guide'
local keyWordMap = {
{'do', function (info, results)
diff --git a/script/core/noder.lua b/script/core/noder.lua
new file mode 100644
index 00000000..c3679612
--- /dev/null
+++ b/script/core/noder.lua
@@ -0,0 +1,926 @@
+local util = require 'utility'
+local guide = require 'parser.guide'
+
+local LastIDCache = {}
+local FirstIDCache = {}
+local SPLIT_CHAR = '\x1F'
+local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$'
+local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*'
+local ANY_FIELD_CHAR = '*'
+local RETURN_INDEX = SPLIT_CHAR .. '#'
+local PARAM_INDEX = SPLIT_CHAR .. '&'
+local TABLE_KEY = SPLIT_CHAR .. '<'
+local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR
+local URI_CHAR = '@'
+local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(.*)'
+
+---@class node
+-- 当前节点的id
+---@field id string
+-- 使用该ID的单元
+---@field sources parser.guide.object[]
+-- 前进的关联ID
+---@field forward string[]
+-- 后退的关联ID
+---@field backward string[]
+-- 函数调用参数信息(用于泛型)
+---@field call parser.guide.object
+
+---@alias noders table<string, node[]>
+
+---创建source的链接信息
+---@param noders noders
+---@param id string
+---@return node
+local function getNode(noders, id)
+ if not noders[id] then
+ noders[id] = {
+ id = id,
+ }
+ end
+ return noders[id]
+end
+
+---获取语法树单元的key
+---@param source parser.guide.object
+---@return string? key
+---@return parser.guide.object? node
+local function getKey(source)
+ if source.type == 'local' then
+ return tostring(source.start), nil
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ return tostring(source.node.start), nil
+ elseif source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ local node = source.node
+ if node.tag == '_ENV' then
+ return ('%q'):format(source[1] or ''), nil
+ else
+ return ('%q'):format(source[1] or ''), node
+ end
+ elseif source.type == 'getfield'
+ or source.type == 'setfield' then
+ return ('%q'):format(source.field and source.field[1] or ''), source.node
+ elseif source.type == 'tablefield' then
+ return ('%q'):format(source.field and source.field[1] or ''), source.parent
+ elseif source.type == 'getmethod'
+ or source.type == 'setmethod' then
+ return ('%q'):format(source.method and source.method[1] or ''), source.node
+ elseif source.type == 'setindex'
+ or source.type == 'getindex' then
+ local index = source.index
+ if not index then
+ return ANY_FIELD_CHAR, source.node
+ end
+ if index.type == 'string'
+ or index.type == 'boolean'
+ or index.type == 'number' then
+ return ('%q'):format(index[1] or ''), source.node
+ elseif index.type ~= 'function'
+ and index.type ~= 'table' then
+ return ANY_FIELD_CHAR, source.node
+ end
+ elseif source.type == 'tableindex' then
+ local index = source.index
+ if not index then
+ return ANY_FIELD_CHAR, source.parent
+ end
+ if index.type == 'string'
+ or index.type == 'boolean'
+ or index.type == 'number' then
+ return ('%q'):format(index[1] or ''), source.parent
+ elseif index.type ~= 'function'
+ and index.type ~= 'table' then
+ return ANY_FIELD_CHAR, source.parent
+ end
+ elseif source.type == 'table' then
+ return source.start, nil
+ elseif source.type == 'label' then
+ return source.start, nil
+ elseif source.type == 'goto' then
+ if source.node then
+ return source.node.start, nil
+ end
+ return nil, nil
+ elseif source.type == 'function' then
+ return source.start, nil
+ elseif source.type == 'string' then
+ return '', nil
+ elseif source.type == 'integer'
+ or source.type == 'number'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return source.start, nil
+ elseif source.type == '...' then
+ return source.start, nil
+ elseif source.type == 'varargs' then
+ if source.node then
+ return source.node.start, nil
+ end
+ elseif source.type == 'select' then
+ return ('%d%s%d'):format(source.start, RETURN_INDEX, source.sindex)
+ elseif source.type == 'call' then
+ local node = source.node
+ if node.special == 'rawget'
+ or node.special == 'rawset' then
+ if not source.args then
+ return nil, nil
+ end
+ local tbl, key = source.args[1], source.args[2]
+ if not tbl or not key then
+ return nil, nil
+ end
+ if key.type == 'string' then
+ return ('%q'):format(key[1] or ''), tbl
+ else
+ return '', tbl
+ end
+ end
+ return source.finish, nil
+ elseif source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name'
+ or source.type == 'doc.see.name' then
+ local name = source[1]
+ return name, nil
+ elseif source.type == 'doc.type.name' then
+ local name = source[1]
+ if source.typeGeneric then
+ return source.typeGeneric[name][1].start, nil
+ else
+ return name, nil
+ end
+ elseif source.type == 'doc.class'
+ or source.type == 'doc.type'
+ or source.type == 'doc.param'
+ or source.type == 'doc.vararg'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.function' then
+ return source.start, nil
+ elseif source.type == 'doc.see.field' then
+ return ('%q'):format(source[1]), source.parent.name
+ elseif source.type == 'generic.closure' then
+ return source.call.start, nil
+ elseif source.type == 'generic.value' then
+ return ('%s|%s'):format(
+ source.closure.call.start,
+ getKey(source.proto)
+ )
+ end
+ return nil, nil
+end
+
+local function checkMode(source)
+ if source.type == 'table' then
+ return 't:'
+ end
+ if source.type == 'select' then
+ return 's:'
+ end
+ if source.type == 'function' then
+ return 'f:'
+ end
+ if source.type == 'string' then
+ return 'str:'
+ end
+ if source.type == 'number'
+ or source.type == 'integer'
+ or source.type == 'boolean'
+ or source.type == 'nil' then
+ return 'i:'
+ end
+ if source.type == 'call' then
+ return 'c:'
+ end
+ if source.type == '...'
+ or source.type == 'varargs' then
+ return 'va:'
+ end
+ if source.type == 'doc.class.name'
+ or source.type == 'doc.type.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name' then
+ if source.typeGeneric then
+ return 'dg:'
+ end
+ return 'dn:'
+ end
+ if source.type == 'doc.field.name' then
+ return 'dfn:'
+ end
+ if source.type == 'doc.see.name' then
+ return 'dsn:'
+ end
+ if source.type == 'doc.class' then
+ return 'dc:'
+ end
+ if source.type == 'doc.type' then
+ return 'dt:'
+ end
+ if source.type == 'doc.param' then
+ return 'dp:'
+ end
+ if source.type == 'doc.type.function' then
+ return 'dfun:'
+ end
+ if source.type == 'doc.type.table' then
+ return 'dtable:'
+ end
+ if source.type == 'doc.type.array' then
+ return 'darray:'
+ end
+ if source.type == 'doc.vararg' then
+ return 'dv:'
+ end
+ if source.type == 'doc.type.enum'
+ or source.type == 'doc.resume' then
+ return 'de:'
+ end
+ if source.type == 'generic.closure' then
+ return 'gc:'
+ end
+ if source.type == 'generic.value' then
+ local id = 'gv:'
+ if guide.getUri(source.closure.call) ~= guide.getUri(source.proto) then
+ id = id .. URI_CHAR .. guide.getUri(source.closure.call)
+ end
+ return id
+ end
+ if guide.isGlobal(source) then
+ return 'g:'
+ end
+ if source.type == 'getlocal'
+ or source.type == 'setlocal' then
+ source = source.node
+ end
+ if source.parent.type == 'funcargs' then
+ return 'p:'
+ end
+ return 'l:'
+end
+
+local IDList = {}
+---获取语法树单元的字符串ID
+---@param source parser.guide.object
+---@return string? id
+local function getID(source)
+ if not source then
+ return nil
+ end
+ if source._id ~= nil then
+ return source._id or nil
+ end
+ if source.type == 'field'
+ or source.type == 'method' then
+ source._id = false
+ return nil
+ end
+ local current = source
+ local index = 0
+ while true do
+ if current.type == 'paren' then
+ current = current.exp
+ goto CONTINUE
+ end
+ local id, node = getKey(current)
+ if not id then
+ break
+ end
+ index = index + 1
+ IDList[index] = id
+ if not node then
+ break
+ end
+ current = node
+ if current.special == '_G' then
+ for i = index, 2, -1 do
+ if IDList[i] == '"_G"' then
+ IDList[i] = nil
+ end
+ end
+ break
+ end
+ ::CONTINUE::
+ end
+ if index == 0 then
+ source._id = false
+ return nil
+ end
+ for i = index + 1, #IDList do
+ IDList[i] = nil
+ end
+ local mode = checkMode(current)
+ if not mode then
+ source._id = false
+ return nil
+ end
+ util.revertTable(IDList)
+ local id = mode .. table.concat(IDList, SPLIT_CHAR)
+ source._id = id
+ return id
+end
+
+---添加关联的前进ID
+---@param noders noders
+---@param id string
+---@param forwardID string
+local function pushForward(noders, id, forwardID, tag)
+ if not id
+ or not forwardID
+ or forwardID == ''
+ or id == forwardID then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.forward then
+ node.forward = {}
+ end
+ if node.forward[forwardID] ~= nil then
+ return
+ end
+ node.forward[forwardID] = tag or false
+ node.forward[#node.forward+1] = forwardID
+end
+
+---添加关联的后退ID
+---@param noders noders
+---@param id string
+---@param backwardID string
+local function pushBackward(noders, id, backwardID, tag)
+ if not id
+ or not backwardID
+ or backwardID == ''
+ or id == backwardID then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.backward then
+ node.backward = {}
+ end
+ if node.backward[backwardID] ~= nil then
+ return
+ end
+ node.backward[backwardID] = tag or false
+ node.backward[#node.backward+1] = backwardID
+end
+
+local m = {}
+
+m.SPLIT_CHAR = SPLIT_CHAR
+m.RETURN_INDEX = RETURN_INDEX
+m.PARAM_INDEX = PARAM_INDEX
+m.TABLE_KEY = TABLE_KEY
+m.ANY_FIELD = ANY_FIELD
+m.URI_CHAR = URI_CHAR
+
+--- 寻找doc的主体
+---@param obj parser.guide.object
+---@return parser.guide.object
+local function getDocStateWithoutCrossFunction(obj)
+ for _ = 1, 1000 do
+ local parent = obj.parent
+ if not parent then
+ return obj
+ end
+ if parent.type == 'doc' then
+ return obj
+ end
+ if parent.type == 'doc.type.function' then
+ return nil
+ end
+ obj = parent
+ end
+ error('guide.getDocState overstack')
+end
+
+---添加关联单元
+---@param noders noders
+---@param source parser.guide.object
+function m.pushSource(noders, source)
+ local id = m.getID(source)
+ if not id then
+ return
+ end
+ local node = getNode(noders, id)
+ if not node.sources then
+ node.sources = {}
+ end
+ node.sources[#node.sources+1] = source
+end
+
+---@param noders noders
+---@param source parser.guide.object
+---@return parser.guide.object[]
+function m.compileNode(noders, source)
+ local id = getID(source)
+ local value = source.value
+ if value then
+ local valueID = getID(value)
+ if valueID then
+ -- x = y : x -> y
+ pushForward(noders, id, valueID, 'set')
+ -- 参数禁止反向查找赋值
+ if valueID:sub(1, 2) ~= 'p:' then
+ pushBackward(noders, valueID, id, 'set')
+ end
+ end
+ end
+ -- self -> mt:xx
+ if source.type == 'local' and source[1] == 'self' then
+ local func = guide.getParentFunction(source)
+ if func.isGeneric then
+ return
+ end
+ local setmethod = func.parent
+ -- guess `self`
+ if setmethod and ( setmethod.type == 'setmethod'
+ or setmethod.type == 'setfield'
+ or setmethod.type == 'setindex') then
+ pushForward(noders, id, getID(setmethod.node), 'method')
+ pushBackward(noders, getID(setmethod.node), id, 'method')
+ end
+ end
+ -- 分解 @type
+ if source.type == 'doc.type' then
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
+ end
+ end
+ for _, enumUnit in ipairs(source.enums) do
+ pushForward(noders, id, getID(enumUnit))
+ end
+ for _, resumeUnit in ipairs(source.resumes) do
+ pushForward(noders, id, getID(resumeUnit))
+ end
+ for _, typeUnit in ipairs(source.types) do
+ local unitID = getID(typeUnit)
+ pushForward(noders, id, unitID)
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushBackward(noders, unitID, getID(src))
+ end
+ end
+ end
+ end
+ -- 分解 @alias
+ if source.type == 'doc.alias' then
+ pushForward(noders, getID(source.alias), getID(source.extends))
+ end
+ -- 分解 @class
+ if source.type == 'doc.class' then
+ pushForward(noders, id, getID(source.class))
+ pushForward(noders, getID(source.class), id)
+ if source.extends then
+ for _, ext in ipairs(source.extends) do
+ pushBackward(noders, id, getID(ext))
+ end
+ end
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
+ end
+ end
+ for _, field in ipairs(source.fields) do
+ local key = field.field[1]
+ if key then
+ local keyID = ('%s%s%q'):format(
+ id,
+ SPLIT_CHAR,
+ key
+ )
+ pushForward(noders, keyID, getID(field.field))
+ pushForward(noders, getID(field.field), keyID)
+ pushForward(noders, keyID, getID(field.extends))
+ pushBackward(noders, getID(field.extends), keyID)
+ end
+ end
+ end
+ if source.type == 'doc.param' then
+ pushForward(noders, id, getID(source.extends))
+ for _, src in ipairs(source.bindSources) do
+ if src.type == 'local' and src.parent.type == 'in' then
+ pushForward(noders, getID(src), id)
+ end
+ end
+ end
+ if source.type == 'doc.vararg' then
+ pushForward(noders, getID(source), getID(source.vararg))
+ end
+ if source.type == 'doc.see' then
+ local nameID = getID(source.name)
+ local classID = nameID:gsub('^dsn:', 'dn:')
+ pushForward(noders, nameID, classID)
+ if source.field then
+ local fieldID = getID(source.field)
+ local fieldClassID = fieldID:gsub('^dsn:', 'dn:')
+ pushForward(noders, fieldID, fieldClassID)
+ end
+ end
+ if source.type == 'call' then
+ local node = source.node
+ local nodeID = getID(node)
+ if not nodeID then
+ return
+ end
+ getNode(noders, id).call = source
+ -- 将 call 映射到 node#1 上
+ local callID = ('%s%s%s'):format(
+ nodeID,
+ RETURN_INDEX,
+ 1
+ )
+ pushForward(noders, id, callID)
+ -- 将setmetatable映射到 param1 以及 param2.__index 上
+ if node.special == 'setmetatable' then
+ local tblID = getID(source.args and source.args[1])
+ local metaID = getID(source.args and source.args[2])
+ local indexID
+ if metaID then
+ indexID = ('%s%s%q'):format(
+ metaID,
+ SPLIT_CHAR,
+ '__index'
+ )
+ end
+ pushForward(noders, id, callID)
+ pushBackward(noders, callID, id)
+ pushForward(noders, callID, tblID)
+ pushForward(noders, callID, indexID)
+ pushBackward(noders, tblID, callID)
+ --pushBackward(noders, indexID, callID)
+ end
+ if node.special == 'require' then
+ local arg1 = source.args and source.args[1]
+ if arg1 and arg1.type == 'string' then
+ getNode(noders, callID).require = arg1[1]
+ end
+ end
+ end
+ if source.type == 'select' then
+ if source.vararg.type == 'call' then
+ local call = source.vararg
+ local node = call.node
+ local nodeID = getID(node)
+ if not nodeID then
+ return
+ end
+ -- 将call的返回值接收映射到函数返回值上
+ local callXID = ('%s%s%s'):format(
+ nodeID,
+ RETURN_INDEX,
+ source.sindex
+ )
+ pushForward(noders, id, callXID)
+ pushBackward(noders, callXID, id)
+ getNode(noders, id).call = call
+ if node.special == 'pcall'
+ or node.special == 'xpcall' then
+ local index = source.sindex - 1
+ if index <= 0 then
+ return
+ end
+ local funcID = call.args and getID(call.args[1])
+ if not funcID then
+ return
+ end
+ local funcXID = ('%s%s%s'):format(
+ funcID,
+ RETURN_INDEX,
+ index
+ )
+ pushForward(noders, id, funcXID)
+ pushBackward(noders, funcXID, id)
+ end
+ end
+ if source.vararg.type == 'varargs' then
+ pushForward(noders, id, getID(source.vararg))
+ end
+ end
+ if source.type == 'doc.type.function' then
+ if source.returns then
+ for index, rtn in ipairs(source.returns) do
+ local returnID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ index
+ )
+ pushForward(noders, returnID, getID(rtn))
+ end
+ end
+ -- @type fun(x: T):T 的情况
+ local docType = getDocStateWithoutCrossFunction(source)
+ if docType and docType.type == 'doc.type' then
+ guide.eachSourceType(source, 'doc.type.name', function (typeName)
+ if typeName.typeGeneric then
+ source.isGeneric = true
+ return false
+ end
+ end)
+ end
+ end
+ if source.type == 'doc.type.table' then
+ if source.tkey then
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, getID(source.tkey))
+ end
+ if source.tvalue then
+ local valueID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, valueID, getID(source.tvalue))
+ end
+ end
+ if source.type == 'doc.type.array' then
+ if source.node then
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source.node))
+ end
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, 'dn:integer')
+ end
+ -- 将函数的返回值映射到具体的返回值上
+ if source.type == 'function' then
+ local hasDocReturn = {}
+ -- 检查 luadoc
+ if source.bindDocs then
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, rtn in ipairs(doc.returns) do
+ local fullID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ rtn.returnIndex
+ )
+ pushForward(noders, fullID, getID(rtn))
+ hasDocReturn[rtn.returnIndex] = true
+ end
+ end
+ if doc.type == 'doc.param' then
+ local paramName = doc.param[1]
+ if source.docParamMap then
+ local paramIndex = source.docParamMap[paramName]
+ local param = source.args[paramIndex]
+ if param then
+ pushForward(noders, getID(param), getID(doc))
+ param.docParam = doc
+ end
+ end
+ end
+ if doc.type == 'doc.vararg' then
+ for _, param in ipairs(source.args) do
+ if param.type == '...' then
+ pushForward(noders, getID(param), getID(doc))
+ end
+ end
+ end
+ if doc.type == 'doc.generic' then
+ source.isGeneric = true
+ end
+ if doc.type == 'doc.overload' then
+ pushForward(noders, id, getID(doc.overload))
+ end
+ end
+ end
+ -- 检查实体返回值
+ if source.returns then
+ local returns = {}
+ for _, rtn in ipairs(source.returns) do
+ for index, rtnObj in ipairs(rtn) do
+ if not hasDocReturn[index] then
+ if not returns[index] then
+ returns[index] = {}
+ end
+ returns[index][#returns[index]+1] = rtnObj
+ end
+ end
+ end
+ for index, rtnObjs in ipairs(returns) do
+ local returnID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ index
+ )
+ for _, rtnObj in ipairs(rtnObjs) do
+ pushForward(noders, returnID, getID(rtnObj))
+ if rtnObj.type == 'function'
+ or rtnObj.type == 'call' then
+ pushBackward(noders, getID(rtnObj), returnID)
+ end
+ end
+ end
+ end
+ end
+ if source.type == 'table' then
+ if #source == 1 and source[1].type == 'varargs' then
+ source.array = source[1]
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source[1]))
+ end
+ end
+ if source.type == 'main' then
+ if source.returns then
+ for _, rtn in ipairs(source.returns) do
+ local rtnObj = rtn[1]
+ if rtnObj then
+ pushForward(noders, 'mainreturn', getID(rtnObj))
+ pushBackward(noders, getID(rtnObj), 'mainreturn')
+ end
+ end
+ end
+ end
+ if source.type == 'generic.closure' then
+ for i, rtn in ipairs(source.returns) do
+ local closureID = ('%s%s%s'):format(
+ id,
+ RETURN_INDEX,
+ i
+ )
+ local returnID = getID(rtn)
+ pushForward(noders, closureID, returnID)
+ end
+ end
+ if source.type == 'generic.value' then
+ local proto = source.proto
+ local closure = source.closure
+ local upvalues = closure.upvalues
+ if proto.type == 'doc.type.name' then
+ local key = proto[1]
+ if upvalues[key] then
+ for _, paramID in ipairs(upvalues[key]) do
+ pushForward(noders, id, paramID)
+ pushBackward(noders, paramID, id)
+ end
+ end
+ end
+ if proto.type == 'doc.type' then
+ for _, tp in ipairs(source.types) do
+ pushForward(noders, id, getID(tp))
+ pushBackward(noders, getID(tp), id)
+ end
+ end
+ if proto.type == 'doc.type.array' then
+ local nodeID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, nodeID, getID(source.node))
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, 'dn:integer')
+ end
+ if proto.type == 'doc.type.table' then
+ if source.tkey then
+ local keyID = ('%s%s'):format(
+ id,
+ TABLE_KEY
+ )
+ pushForward(noders, keyID, getID(source.tkey))
+ end
+ if source.tvalue then
+ local valueID = ('%s%s'):format(
+ id,
+ ANY_FIELD
+ )
+ pushForward(noders, valueID, getID(source.tvalue))
+ end
+ end
+ end
+end
+
+---根据ID来获取所有的node
+---@param root parser.guide.object
+---@param id string
+---@return node?
+function m.getNodeByID(root, id)
+ root = guide.getRoot(root)
+ local noders = root._noders
+ if not noders then
+ return nil
+ end
+ return noders[id]
+end
+
+---根据ID来获取第一个节点的ID
+---@param id string
+---@return string
+function m.getFirstID(id)
+ if FirstIDCache[id] then
+ return FirstIDCache[id] or nil
+ end
+ local firstID, count = id:match(FIRST_REGEX)
+ if count == 0 then
+ FirstIDCache[id] = false
+ return nil
+ end
+ FirstIDCache[id] = firstID
+ return firstID
+end
+
+---根据ID来获取上个节点的ID
+---@param id string
+---@return string
+function m.getLastID(id)
+ if LastIDCache[id] then
+ return LastIDCache[id] or nil
+ end
+ local lastID, count = id:gsub(LAST_REGEX, '')
+ if count == 0 then
+ LastIDCache[id] = false
+ return nil
+ end
+ LastIDCache[id] = lastID
+ return lastID
+end
+
+---把形如 `@file:\\\XXXXX@gv:1|1`拆分成uri与id
+---@param id string
+---@return uri? string
+---@return string id
+function m.getUriAndID(id)
+ local uri, newID = id:match(URI_REGEX)
+ return uri, newID
+end
+
+---获取source的ID
+---@param source parser.guide.object
+---@return string
+function m.getID(source)
+ return getID(source)
+end
+
+---获取source的key
+---@param source parser.guide.object
+---@return string
+function m.getKey(source)
+ return getKey(source)
+end
+
+---清除临时id(用于泛型的临时对象)
+---@param root parser.guide.object
+---@param id string
+function m.removeID(root, id)
+ root = guide.getRoot(root)
+ local noders = root._noders
+ noders[id] = nil
+end
+
+---寻找doc的主体
+---@param doc parser.guide.object
+function m.getDocState(doc)
+ return getDocStateWithoutCrossFunction(doc)
+end
+
+---获取对象的noders
+---@param source parser.guide.object
+---@return noders
+function m.getNoders(source)
+ local root = guide.getRoot(source)
+ if not root._noders then
+ root._noders = {}
+ end
+ return root._noders
+end
+
+---编译整个文件的node
+---@param source parser.guide.object
+---@return table
+function m.compileNodes(source)
+ local root = guide.getRoot(source)
+ local noders = m.getNoders(source)
+ if next(noders) then
+ return noders
+ end
+ guide.eachSource(root, function (src)
+ m.pushSource(noders, src)
+ m.compileNode(noders, src)
+ end)
+ -- Special rule: ('').XX -> stringlib.XX
+ pushBackward(noders, 'str:', 'dn:stringlib')
+ pushBackward(noders, 'dn:string', 'dn:stringlib')
+ return noders
+end
+
+return m
diff --git a/script/core/reference.lua b/script/core/reference.lua
index 7620b09e..c3f3b349 100644
--- a/script/core/reference.lua
+++ b/script/core/reference.lua
@@ -1,4 +1,5 @@
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
+local guide = require 'parser.guide'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
@@ -6,8 +7,8 @@ local findSource = require 'core.find-source'
local function sortResults(results)
-- 先按照顺序排序
table.sort(results, function (a, b)
- local u1 = guide.getUri(a.target)
- local u2 = guide.getUri(b.target)
+ local u1 = searcher.getUri(a.target)
+ local u2 = searcher.getUri(b.target)
if u1 == u2 then
return a.target.start < b.target.start
else
@@ -19,7 +20,7 @@ local function sortResults(results)
for i = #results, 1, -1 do
local res = results[i].target
local f = res.finish
- local uri = guide.getUri(res)
+ local uri = searcher.getUri(res)
if lf and f > lf and uri == lu then
table.remove(results, i)
else
@@ -64,11 +65,23 @@ return function (uri, offset)
local metaSource = vm.isMetaFile(uri)
+ local refs = vm.getRefs(source)
+ local values = {}
+ for _, src in ipairs(refs) do
+ local value = searcher.getObjectValue(src)
+ if value and value ~= src and guide.isLiteral(value) then
+ values[value] = true
+ end
+ end
+
local results = {}
- for _, src in ipairs(vm.getRefs(source, 5)) do
+ for _, src in ipairs(refs) do
if src.dummy then
goto CONTINUE
end
+ if values[src] then
+ goto CONTINUE
+ end
local root = guide.getRoot(src)
if not root then
goto CONTINUE
diff --git a/script/core/rename.lua b/script/core/rename.lua
index da82b0a6..6b67d4be 100644
--- a/script/core/rename.lua
+++ b/script/core/rename.lua
@@ -1,11 +1,11 @@
local files = require 'files'
local vm = require 'vm'
-local guide = require 'core.guide'
local proto = require 'proto'
local define = require 'proto.define'
local util = require 'utility'
local findSource = require 'core.find-source'
-local ws = require 'workspace'
+local guide = require 'parser.guide'
+local noder = require 'core.noder'
local Forcing
@@ -185,7 +185,7 @@ local function renameField(source, newname, callback)
end
callback(source, source.start, source.finish, newname)
elseif parent.type == 'setmethod' then
- local uri = guide.getUri(source)
+ local uri = guide.getUri(source)
local text = files.getText(uri)
local func = parent.value
-- function mt:name () end --> mt['newname'] = function (self) end
@@ -292,14 +292,14 @@ local function ofField(source, newname, callback)
else
node = source.node
end
- for _, src in ipairs(vm.getFields(node, 5)) do
+ for _, src in ipairs(vm.getRefs(node, '*')) do
ofFieldThen(key, src, newname, callback)
end
end
local function ofGlobal(source, newname, callback)
local key = guide.getKeyName(source)
- for _, src in ipairs(vm.getRefs(source, 0)) do
+ for _, src in ipairs(vm.getRefs(source)) do
ofFieldThen(key, src, newname, callback)
end
end
@@ -308,24 +308,27 @@ local function ofLabel(source, newname, callback)
if not isValidName(newname) and not askForcing(newname)then
return false
end
- for _, src in ipairs(vm.getRefs(source, 0)) do
+ for _, src in ipairs(vm.getRefs(source)) do
callback(src, src.start, src.finish, newname)
end
end
local function ofDocTypeName(source, newname, callback)
- for _, doc in ipairs(vm.getDocTypes(source[1])) do
+ local oldname = source[1]
+ for _, doc in ipairs(vm.getRefs(source)) do
if doc.type == 'doc.class.name'
or doc.type == 'doc.type.name'
or doc.type == 'doc.alias.name' then
- callback(doc, doc.start, doc.finish, newname)
+ if oldname == doc[1] then
+ callback(doc, doc.start, doc.finish, newname)
+ end
end
end
end
local function ofDocParamName(source, newname, callback)
callback(source, source.start, source.finish, newname)
- local doc = guide.getDocState(source)
+ local doc = noder.getDocState(source)
if doc.bindSources then
for _, src in ipairs(doc.bindSources) do
if src.type == 'local'
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
new file mode 100644
index 00000000..11e00378
--- /dev/null
+++ b/script/core/searcher.lua
@@ -0,0 +1,728 @@
+local noder = require 'core.noder'
+local guide = require 'parser.guide'
+local files = require 'files'
+local generic = require 'core.generic'
+local ws = require 'workspace'
+local vm = require 'vm.vm'
+
+local NONE = {'NONE'}
+local LAST = {'LAST'}
+
+local ignoredIDs = {
+ ['dn:unknown'] = true,
+ ['dn:nil'] = true,
+ ['dn:any'] = true,
+ ['dn:boolean'] = true,
+ ['dn:string'] = true,
+ ['dn:table'] = true,
+ ['dn:number'] = true,
+ ['dn:integer'] = true,
+ ['dn:userdata'] = true,
+ ['dn:lightuserdata'] = true,
+ ['dn:function'] = true,
+ ['dn:thread'] = true,
+}
+
+local m = {}
+
+---@alias guide.searchmode '"ref"'|'"def"'
+
+---添加结果
+---@param status guide.status
+---@param mode guide.searchmode
+---@param source parser.guide.object
+---@param force boolean
+function m.pushResult(status, mode, source, force)
+ if not source then
+ return
+ end
+ local results = status.results
+ if results[source] then
+ return
+ end
+ results[source] = true
+ if force then
+ results[#results+1] = source
+ return
+ end
+ local parent = source.parent
+ if mode == 'def' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal'
+ or source.type == 'label'
+ or source.type == 'setfield'
+ or source.type == 'setmethod'
+ or source.type == 'setindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset' then
+ results[#results+1] = source
+ end
+ end
+ if parent.type == 'return' then
+ if noder.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
+ elseif mode == 'ref' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal'
+ or source.type == 'label'
+ or source.type == 'goto'
+ or source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod'
+ or source.type == 'setindex'
+ or source.type == 'getindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number'
+ or source.type == 'nil'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.type.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.resume'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset'
+ or source.node.special == 'rawget' then
+ results[#results+1] = source
+ end
+ end
+ if parent.type == 'return' then
+ if noder.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
+ end
+end
+
+---获取uri
+---@param obj parser.guide.object
+---@return uri
+function m.getUri(obj)
+ if obj.uri then
+ return obj.uri
+ end
+ local root = guide.getRoot(obj)
+ if root then
+ return root.uri
+ end
+ return ''
+end
+
+---@param obj parser.guide.object
+---@return parser.guide.object?
+function m.getObjectValue(obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return nil
+ end
+ end
+ if obj.type == 'boolean'
+ or obj.type == 'number'
+ or obj.type == 'integer'
+ or obj.type == 'string' then
+ return obj
+ end
+ if obj.value then
+ return obj.value
+ end
+ if obj.type == 'field'
+ or obj.type == 'method' then
+ return obj.parent and obj.parent.value
+ end
+ if obj.type == 'call' then
+ if obj.node.special == 'rawset' then
+ return obj.args and obj.args[3]
+ else
+ return obj
+ end
+ end
+ if obj.type == 'select' then
+ return obj
+ end
+ return nil
+end
+
+local function crossSearch(status, uri, expect, mode)
+ m.searchRefsByID(status, uri, expect, mode)
+end
+
+local function getLock(status, uri, expect, mode)
+ local slock = status.lock
+ local ulock = slock[uri]
+ if not ulock then
+ ulock = {}
+ slock[uri] = ulock
+ end
+ local mlock = ulock[mode]
+ if not mlock then
+ mlock = {}
+ ulock[mode] = mlock
+ end
+ if mlock[expect] then
+ return false
+ end
+ mlock[expect] = true
+ return true
+end
+
+function m.searchRefsByID(status, uri, expect, mode)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ if not getLock(status, uri, expect, mode) then
+ return
+ end
+ local root = ast.ast
+ local searchStep
+ noder.compileNodes(root)
+
+ status.id = expect
+
+ local callStack = status.callStack
+
+ local mark = {}
+
+ local function search(id, field)
+ local firstID = noder.getFirstID(id)
+ if ignoredIDs[firstID] and (field or firstID ~= id) then
+ return
+ end
+ local cmark = mark[id]
+ if not cmark then
+ cmark = {}
+ mark[id] = cmark
+ end
+ log.debug('search:', id, field)
+ if field then
+ if cmark[field] then
+ return
+ end
+ cmark[field] = true
+ searchStep(id, field)
+ cmark[field] = nil
+ else
+ if cmark[NONE] then
+ return
+ end
+ cmark[NONE] = true
+ searchStep(id, nil)
+ cmark[NONE] = nil
+ end
+ log.debug('pop:', id, field)
+ end
+
+ local function checkLastID(id, field)
+ local cmark = mark[id]
+ if not cmark then
+ cmark = {}
+ mark[id] = cmark
+ end
+ if cmark[LAST] then
+ return
+ end
+ local lastID = noder.getLastID(id)
+ if not lastID then
+ return
+ end
+ local newField = id:sub(#lastID + 1)
+ if field then
+ newField = newField .. field
+ end
+ cmark[LAST] = true
+ search(lastID, newField)
+ cmark[LAST] = nil
+ return lastID
+ end
+
+ local function searchID(id, field)
+ if not id then
+ return
+ end
+ if field then
+ id = id .. field
+ end
+ search(id, nil)
+ end
+
+ local function isCallID(field)
+ if not field then
+ return false
+ end
+ if field:sub(1, 2) == noder.RETURN_INDEX then
+ return true
+ end
+ return false
+ end
+
+ local function findLastCall()
+ for i = #callStack, 1, -1 do
+ local call = callStack[i]
+ if call then
+ -- 标记此处的call失效,等待在堆栈平衡时弹出
+ callStack[i] = false
+ return call
+ end
+ end
+ return nil
+ end
+
+ local genericCallArgs = {}
+ local closureCache = {}
+ local function checkGeneric(source, field)
+ if not source.isGeneric then
+ return
+ end
+ if not isCallID(field) then
+ return
+ end
+ local call = findLastCall()
+ if not call then
+ return
+ end
+
+ if call.args then
+ for _, arg in ipairs(call.args) do
+ genericCallArgs[arg] = true
+ end
+ end
+
+ local cacheID = noder.getID(source) .. noder.getID(call)
+ local closure = closureCache[cacheID]
+ if closure == false then
+ return
+ end
+ if not closure then
+ closure = generic.createClosure(source, call)
+ closureCache[cacheID] = closure or false
+ if not closure then
+ return
+ end
+ end
+ local id = noder.getID(closure)
+ searchID(id, field)
+ end
+
+ local function checkENV(source, field)
+ if not field then
+ return
+ end
+ if source.special ~= '_G' then
+ return
+ end
+ local newID = 'g:' .. field:sub(2)
+ searchID(newID)
+ end
+
+ local forwardTag = {}
+ local backwardTag = {}
+ local function checkForward(id, node, field)
+ for _, forwardID in ipairs(node.forward) do
+ local tag = node.forward[forwardID]
+ if tag then
+ if backwardTag[tag] and backwardTag[tag] > 0 then
+ goto CONTINUE
+ end
+ forwardTag[tag] = (forwardTag[tag] or 0) + 1
+ end
+ local targetUri, targetID = noder.getUriAndID(forwardID)
+ if targetUri and not files.eq(targetUri, uri) then
+ crossSearch(status, targetUri, targetID .. (field or ''), mode)
+ else
+ searchID(targetID or forwardID, field)
+ end
+ if tag then
+ forwardTag[tag] = forwardTag[tag] - 1
+ end
+ ::CONTINUE::
+ end
+ end
+
+ local function checkBackward(id, node, field)
+ if mode ~= 'ref' and not field then
+ return
+ end
+ for _, backwardID in ipairs(node.backward) do
+ local tag = node.backward[backwardID]
+ if tag then
+ if forwardTag[tag] and forwardTag[tag] > 0 then
+ goto CONTINUE
+ end
+ backwardTag[tag] = (backwardTag[tag] or 0) + 1
+ end
+ local targetUri, targetID = noder.getUriAndID(backwardID)
+ if targetUri and not files.eq(targetUri, uri) then
+ crossSearch(status, targetUri, targetID .. (field or ''), mode)
+ else
+ searchID(targetID or backwardID, field)
+ end
+ if tag then
+ backwardTag[tag] = backwardTag[tag] - 1
+ end
+ ::CONTINUE::
+ end
+ end
+
+ local function checkRequire(requireName, field)
+ local tid = 'mainreturn' .. (field or '')
+ local uris = ws.findUrisByRequirePath(requireName)
+ for _, ruri in ipairs(uris) do
+ if not files.eq(uri, ruri) then
+ crossSearch(status, ruri, tid, mode)
+ end
+ end
+ end
+
+ local function checkGlobal(id, node, field)
+ if id:sub(1, 2) ~= 'g:' then
+ return
+ end
+ local firstID = noder.getFirstID(id)
+ if status.crossed[firstID] then
+ return
+ end
+ status.crossed[firstID] = true
+ local tid = id .. (field or '')
+ for guri in files.eachFile() do
+ if not files.eq(uri, guri) then
+ crossSearch(status, guri, tid, mode)
+ end
+ end
+ end
+
+ local function checkClass(id, node, field)
+ if id:sub(1, 3) ~= 'dn:' then
+ return
+ end
+ local firstID = noder.getFirstID(id)
+ if status.crossed[firstID] then
+ return
+ end
+ status.crossed[firstID] = true
+ local tid = id .. (field or '')
+ for guri in files.eachFile() do
+ if not files.eq(uri, guri) then
+ crossSearch(status, guri, tid, mode)
+ end
+ end
+ end
+
+ local function checkMainReturn(id, node, field)
+ if id ~= 'mainreturn' then
+ return
+ end
+ if mode ~= 'ref' and not field then
+ return
+ end
+ local calls = vm.getLinksTo(uri)
+ for _, call in ipairs(calls) do
+ local turi = guide.getUri(call)
+ if not files.eq(turi, uri) then
+ local tid = noder.getID(call) .. (field or '')
+ crossSearch(status, turi, tid, mode)
+ end
+ end
+ end
+
+ local function searchNode(id, node, field)
+ if node.call then
+ callStack[#callStack+1] = node.call
+ end
+ if field == nil and node.sources then
+ for _, source in ipairs(node.sources) do
+ local force = genericCallArgs[source]
+ m.pushResult(status, mode, source, force)
+ end
+ end
+ if node.forward then
+ checkForward(id, node, field)
+ end
+ if node.backward then
+ checkBackward(id, node, field)
+ end
+
+ if node.sources then
+ checkGeneric(node.sources[1], field)
+ checkENV(node.sources[1], field)
+ end
+
+ if node.require then
+ checkRequire(node.require, field)
+ end
+
+ checkMainReturn(id, node, field)
+
+ if node.call then
+ callStack[#callStack] = nil
+ end
+ end
+
+ local function checkCrossUri(id, field)
+ local targetUri, newID = noder.getUriAndID(id)
+ if not targetUri then
+ return false
+ end
+ crossSearch(status, targetUri, newID .. (field or ''), mode)
+ return true
+ end
+
+ local stepCount = 0
+ function searchStep(id, field)
+ stepCount = stepCount + 1
+ if stepCount > 1000 then
+ error('too large')
+ end
+ local node = noder.getNodeByID(root, id)
+ if node then
+ searchNode(id, node, field)
+ end
+ checkGlobal(id, node, field)
+ checkClass(id, node, field)
+ local lastID = checkLastID(id, field)
+ if not lastID then
+ return
+ end
+ local originField = id:sub(#lastID + 1)
+ if originField == noder.TABLE_KEY then
+ return
+ end
+ local anyFieldID = lastID .. noder.ANY_FIELD
+ local anyFieldNode = noder.getNodeByID(root, anyFieldID)
+ if anyFieldNode then
+ searchNode(anyFieldID, anyFieldNode, field)
+ end
+ end
+
+ search(expect)
+
+ --清除来自泛型的临时对象
+ for _, closure in pairs(closureCache) do
+ noder.removeID(root, noder.getID(closure))
+ if closure then
+ for _, value in ipairs(closure.values) do
+ noder.removeID(root, noder.getID(value))
+ end
+ end
+ end
+end
+
+local function prepareSearch(source)
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ local root = guide.getRoot(source)
+ noder.compileNodes(root)
+ local uri = guide.getUri(source)
+ local id = noder.getID(source)
+ return uri, id
+end
+
+local function getField(status, source, mode)
+ if source.type == 'table' then
+ for _, field in ipairs(source) do
+ m.pushResult(status, mode, field)
+ end
+ return
+ end
+ if source.type == 'doc.class.name' then
+ local class = source.parent
+ for _, field in ipairs(class.fields) do
+ m.pushResult(status, mode, field.field)
+ end
+ return
+ end
+ local field = source.next
+ if field then
+ if field.type == 'getmethod'
+ or field.type == 'setmethod'
+ or field.type == 'getfield'
+ or field.type == 'setfield'
+ or field.type == 'getindex'
+ or field.type == 'setindex' then
+ m.pushResult(status, mode, field)
+ end
+ return
+ end
+end
+
+local function searchAllGlobalByUri(status, mode, uri, fullID)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local root = ast.ast
+ noder.compileNodes(root)
+ local noders = noder.getNoders(root)
+ if fullID then
+ for id, node in pairs(noders) do
+ if node.sources
+ and id == fullID then
+ for _, source in ipairs(node.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ end
+ else
+ for id, node in pairs(noders) do
+ if node.sources
+ and id:sub(1, 2) == 'g:'
+ and not id:find(noder.SPLIT_CHAR) then
+ for _, source in ipairs(node.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ end
+ end
+end
+
+local function searchAllGlobals(status, mode, fullID)
+ for uri in files.eachFile() do
+ searchAllGlobalByUri(status, mode, uri, fullID)
+ end
+end
+
+---搜索对象的引用
+---@param status guide.status
+---@param source parser.guide.object
+---@param mode guide.searchmode
+function m.searchRefs(status, source, mode)
+ local uri, id = prepareSearch(source)
+ if not id then
+ return
+ end
+ log.debug('searchRefs:', id)
+ m.searchRefsByID(status, uri, id, mode)
+end
+
+function m.findGlobals(uri, mode, name)
+ local status = m.status()
+
+ if name then
+ local fullID = ('g:%q'):format(name)
+ searchAllGlobalByUri(status, mode, uri, fullID)
+ else
+ searchAllGlobalByUri(status, mode, uri)
+ end
+
+ return status.results
+end
+
+---搜索对象的field
+---@param status guide.status
+---@param source parser.guide.object
+---@param mode guide.searchmode
+---@param field string
+function m.searchFields(status, source, mode, field)
+ local uri, id = prepareSearch(source)
+ if not id then
+ return
+ end
+ log.debug('searchFields:', id, field)
+ if field == '*' then
+ if source.special == '_G' then
+ searchAllGlobals(status, mode)
+ else
+ local newStatus = m.status(status)
+ m.searchRefsByID(newStatus, uri, id, mode)
+ for _, def in ipairs(newStatus.results) do
+ getField(status, def, mode)
+ end
+ end
+ else
+ if source.special == '_G' then
+ local fullID = ('g:%q'):format(field)
+ m.searchRefsByID(status, uri, fullID, mode)
+ else
+ local fullID = ('%s%s%q'):format(id, noder.SPLIT_CHAR, field)
+ m.searchRefsByID(status, uri, fullID, mode)
+ end
+ end
+end
+
+---@class guide.status
+---搜索结果
+---@field results parser.guide.object[]
+
+---创建搜索状态
+---@param parentStatus guide.status
+---@return guide.status
+function m.status(parentStatus)
+ local status = {
+ --mark = parentStatus and parentStatus.mark or {},
+ callStack = {},
+ crossed = {},
+ lock = {},
+ results = {},
+ }
+ return status
+end
+
+--- 请求对象的引用
+---@param obj parser.guide.object
+---@param field? string
+---@return parser.guide.object[]
+function m.requestReference(obj, field)
+ local status = m.status()
+
+ if field then
+ m.searchFields(status, obj, 'ref', field)
+ else
+ m.searchRefs(status, obj, 'ref')
+ end
+
+ return status.results
+end
+
+--- 请求对象的定义
+---@param obj parser.guide.object
+---@param field? string
+---@return parser.guide.object[]
+function m.requestDefinition(obj, field)
+ local status = m.status()
+
+ if field then
+ m.searchFields(status, obj, 'def', field)
+ else
+ m.searchRefs(status, obj, 'def')
+ end
+
+ return status.results
+end
+
+return m
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index f8feaa09..5e9ee9b1 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local await = require 'await'
local define = require 'proto.define'
local vm = require 'vm'
@@ -221,7 +221,7 @@ return function (uri, start, finish)
local results = {}
local count = 0
- guide.eachSourceBetween(ast.ast, start, finish, function (source)
+ searcher.eachSourceBetween(ast.ast, start, finish, function (source)
local method = Care[source.type]
if not method then
return
diff --git a/script/core/signature.lua b/script/core/signature.lua
index eb740784..915310c0 100644
--- a/script/core/signature.lua
+++ b/script/core/signature.lua
@@ -1,8 +1,9 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local vm = require 'vm'
local hoverLabel = require 'core.hover.label'
local hoverDesc = require 'core.hover.description'
+local guide = require 'parser.guide'
local function findNearCall(uri, ast, pos)
local text = files.getText(uri)
@@ -96,10 +97,10 @@ local function makeSignatures(call, pos)
index = 1
end
local signs = {}
- local defs = vm.getDefs(node, 0)
+ local defs = vm.getDefs(node)
local mark = {}
for _, src in ipairs(defs) do
- src = guide.getObjectValue(src) or src
+ src = searcher.getObjectValue(src) or src
if src.type == 'function'
or src.type == 'doc.type.function' then
if not mark[src] then
diff --git a/script/core/type-formatting.lua b/script/core/type-formatting.lua
index c2290ef3..49a721e5 100644
--- a/script/core/type-formatting.lua
+++ b/script/core/type-formatting.lua
@@ -1,6 +1,6 @@
local files = require 'files'
local lookBackward = require 'core.look-backward'
-local guide = require 'core.guide'
+local guide = require "parser.guide"
local function insertIndentation(uri, offset, edits)
local lines = files.getLines(uri)
diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua
index ae420d32..2df23a4d 100644
--- a/script/core/workspace-symbol.lua
+++ b/script/core/workspace-symbol.lua
@@ -1,5 +1,5 @@
local files = require 'files'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local matchKey = require 'core.matchkey'
local define = require 'proto.define'
local await = require 'await'
@@ -52,7 +52,7 @@ local function searchFile(uri, key, results)
return
end
- guide.eachSource(ast.ast, function (source)
+ searcher.eachSource(ast.ast, function (source)
buildSource(uri, source, key, results)
end)
end
diff --git a/script/files.lua b/script/files.lua
index 9cc6b549..bb143250 100644
--- a/script/files.lua
+++ b/script/files.lua
@@ -9,7 +9,7 @@ local await = require 'await'
local timer = require 'timer'
local plugin = require 'plugin'
local util = require 'utility'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local smerger = require 'string-merger'
local progress = require "progress"
@@ -345,6 +345,7 @@ function m.getAllUris()
i = i + 1
files[i] = uri
end
+ table.sort(files)
end
return m._pairsCache
end
diff --git a/script/parser/ast.lua b/script/parser/ast.lua
index 45d77631..40b5788e 100644
--- a/script/parser/ast.lua
+++ b/script/parser/ast.lua
@@ -110,7 +110,7 @@ local function getSelect(vararg, index)
start = vararg.start,
finish = vararg.finish,
vararg = vararg,
- index = index,
+ sindex = index,
}
end
@@ -1460,8 +1460,14 @@ local Defs = {
local values
if func then
local call = createCall(exp, func.finish + 1, exp.finish)
+ if #exp == 0 then
+ exp[1] = getSelect(func, 2)
+ exp[2] = getSelect(func, 3)
+ exp[3] = getSelect(func, 4)
+ end
call.node = func
- call.start = func.start
+ call.start = inA
+ call.finish = doB - 1
func.next = call
func.iterator = true
values = { call }
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index a7e0dc1f..21be406d 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -125,6 +125,7 @@ local vmMap = {
vararg.ref = {}
end
vararg.ref[#vararg.ref+1] = obj
+ obj.node = vararg
end
end
end,
@@ -150,8 +151,8 @@ local vmMap = {
local value = obj.value
local localself = {
type = 'local',
- start = 0,
- finish = 0,
+ start = value.start,
+ finish = value.finish,
method = obj,
effect = obj.finish,
tag = 'self',
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 2369e84f..8d2708cf 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -1,34 +1,8 @@
-local util = require 'utility'
local error = error
local type = type
-local next = next
-local tostring = tostring
-local print = print
-local ipairs = ipairs
-local tableInsert = table.insert
-local tableUnpack = table.unpack
-local tableRemove = table.remove
-local tableMove = table.move
-local tableSort = table.sort
-local tableConcat = table.concat
-local mathType = math.type
-local pairs = pairs
-local setmetatable = setmetatable
-local assert = assert
-local select = select
-local osClock = os.clock
-local tonumber = tonumber
-local tointeger = math.tointeger
-local DEVELOP = _G.DEVELOP
-local log = log
-local _G = _G
---@class parser.guide.object
-local function logWarn(...)
- log.warn(...)
-end
-
---@class guide
---@field debugMode boolean
local m = {}
@@ -91,7 +65,7 @@ m.childMap = {
['doc'] = {'#'},
['doc.class'] = {'class', '#extends', 'comment'},
- ['doc.type'] = {'#types', '#enums', 'name', 'comment'},
+ ['doc.type'] = {'#types', '#enums', '#resumes', 'name', 'comment'},
['doc.alias'] = {'alias', 'extends', 'comment'},
['doc.param'] = {'param', 'extends', 'comment'},
['doc.return'] = {'#returns', 'comment'},
@@ -100,9 +74,9 @@ m.childMap = {
['doc.generic.object'] = {'generic', 'extends', 'comment'},
['doc.vararg'] = {'vararg', 'comment'},
['doc.type.array'] = {'node'},
- ['doc.type.table'] = {'node', 'key', 'value', 'comment'},
+ ['doc.type.table'] = {'tkey', 'tvalue', 'comment'},
['doc.type.function'] = {'#args', '#returns', 'comment'},
- ['doc.type.typeliteral'] = {'node'},
+ ['doc.type.literal'] = {'node'},
['doc.type.arg'] = {'extends'},
['doc.overload'] = {'overload', 'comment'},
['doc.see'] = {'name', 'field'},
@@ -123,19 +97,31 @@ m.actionMap = {
['funcargs'] = {'#'},
}
-local TypeSort = {
- ['boolean'] = 1,
- ['string'] = 2,
- ['integer'] = 3,
- ['number'] = 4,
- ['table'] = 5,
- ['function'] = 6,
- ['true'] = 101,
- ['false'] = 102,
- ['nil'] = 999,
-}
+local inf = 1 / 0
+local nan = 0 / 0
+
+local function isInteger(n)
+ if math.type then
+ return math.type(n) == 'integer'
+ else
+ return type(n) == 'number' and n % 1 == 0
+ end
+end
-local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
+local function formatNumber(n)
+ if n == inf
+ or n == -inf
+ or n == nan
+ or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则
+ return ('%q'):format(n)
+ end
+ if isInteger(n) then
+ return tostring(n)
+ end
+ local str = ('%.10f'):format(n)
+ str = str:gsub('%.?0*$', '')
+ return str
+end
--- 是否是字面量
---@param obj parser.guide.object
@@ -182,23 +168,6 @@ function m.getParentFunction(obj)
return nil
end
---- 寻找父的table类型 doc.type.table
----@param obj parser.guide.object
----@return parser.guide.object
-function m.getParentDocTypeTable(obj)
- for _ = 1, 1000 do
- local parent = obj.parent
- if not parent then
- return nil
- end
- if parent.type == 'doc.type.table' then
- return obj
- end
- obj = parent
- end
- error('guide.getParentDocTypeTable overstack')
-end
-
--- 寻找所在区块
---@param obj parser.guide.object
---@return parser.guide.object
@@ -293,10 +262,19 @@ end
---@param obj parser.guide.object
---@return parser.guide.object
function m.getRoot(obj)
+ local source = obj
+ if source._root then
+ return source._root
+ end
for _ = 1, 1000 do
if obj.type == 'main' then
+ source._root = obj
return obj
end
+ if obj._root then
+ source._root = obj._root
+ return source._root
+ end
local parent = obj.parent
if not parent then
return nil
@@ -501,8 +479,8 @@ function m.addChilds(list, obj, map)
for i = 1, #keys do
local key = keys[i]
if key == '#' then
- for i = 1, #obj do
- list[#list+1] = obj[i]
+ for j = 1, #obj do
+ list[#list+1] = obj[j]
end
elseif obj[key] then
list[#list+1] = obj[key]
@@ -510,8 +488,8 @@ function m.addChilds(list, obj, map)
and key:sub(1, 1) == '#' then
key = key:sub(2)
if obj[key] then
- for i = 1, #obj[key] do
- list[#list+1] = obj[key][i]
+ for j = 1, #obj[key] do
+ list[#list+1] = obj[key][j]
end
end
end
@@ -613,9 +591,16 @@ function m.eachSource(ast, callback)
index = index + 1
if not mark[obj] then
mark[obj] = true
- callback(obj)
+ local res = callback(obj)
+ if res == true then
+ goto CONTINUE
+ end
+ if res == false then
+ return
+ end
m.addChilds(list, obj, m.childMap)
end
+ ::CONTINUE::
end
end
@@ -718,4 +703,288 @@ function m.lineData(lines, row)
return lines[row]
end
+function m.isSet(source)
+ local tp = source.type
+ if tp == 'setglobal'
+ or tp == 'local'
+ or tp == 'setlocal'
+ or tp == 'setfield'
+ or tp == 'setmethod'
+ or tp == 'setindex'
+ or tp == 'tablefield'
+ or tp == 'tableindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(source.node)
+ if special == 'rawset' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.isGet(source)
+ local tp = source.type
+ if tp == 'getglobal'
+ or tp == 'getlocal'
+ or tp == 'getfield'
+ or tp == 'getmethod'
+ or tp == 'getindex' then
+ return true
+ end
+ if tp == 'call' then
+ local special = m.getSpecial(source.node)
+ if special == 'rawget' then
+ return true
+ end
+ end
+ return false
+end
+
+function m.getSpecial(source)
+ if not source then
+ return nil
+ end
+ return source.special
+end
+
+function m.getKeyNameOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'field'
+ or tp == 'method' then
+ return obj[1]
+ elseif tp == 'string' then
+ local s = obj[1]
+ if s then
+ return s
+ end
+ elseif tp == 'number' then
+ local n = obj[1]
+ if n then
+ return ('%s'):format(formatNumber(obj[1]))
+ end
+ elseif tp == 'boolean' then
+ local b = obj[1]
+ if b then
+ return tostring(b)
+ end
+ end
+end
+
+function m.getKeyName(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return obj[1]
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return obj[1]
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ if obj.field then
+ return obj.field[1]
+ end
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ if obj.method then
+ return obj.method[1]
+ end
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getKeyNameOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method'
+ or tp == 'doc.see.field' then
+ return obj[1]
+ elseif tp == 'doc.class' then
+ return obj.class[1]
+ elseif tp == 'doc.alias' then
+ return obj.alias[1]
+ elseif tp == 'doc.field' then
+ return obj.field[1]
+ elseif tp == 'doc.field.name' then
+ return obj[1]
+ elseif tp == 'dummy' then
+ return obj[1]
+ end
+ return m.getKeyNameOfLiteral(obj)
+end
+
+function m.getKeyTypeOfLiteral(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'field'
+ or tp == 'method' then
+ return 'string'
+ elseif tp == 'string' then
+ return 'string'
+ elseif tp == 'number' then
+ return 'number'
+ elseif tp == 'boolean' then
+ return 'boolean'
+ end
+end
+
+function m.getKeyType(obj)
+ if not obj then
+ return nil
+ end
+ local tp = obj.type
+ if tp == 'getglobal'
+ or tp == 'setglobal' then
+ return 'string'
+ elseif tp == 'local'
+ or tp == 'getlocal'
+ or tp == 'setlocal' then
+ return 'local'
+ elseif tp == 'getfield'
+ or tp == 'setfield'
+ or tp == 'tablefield' then
+ return 'string'
+ elseif tp == 'getmethod'
+ or tp == 'setmethod' then
+ return 'string'
+ elseif tp == 'getindex'
+ or tp == 'setindex'
+ or tp == 'tableindex' then
+ return m.getKeyTypeOfLiteral(obj.index)
+ elseif tp == 'field'
+ or tp == 'method'
+ or tp == 'doc.see.field' then
+ return 'string'
+ elseif tp == 'doc.class' then
+ return 'string'
+ elseif tp == 'doc.alias' then
+ return 'string'
+ elseif tp == 'doc.field' then
+ return 'string'
+ elseif tp == 'dummy' then
+ return 'string'
+ end
+ if tp == 'doc.field.name' then
+ return 'string'
+ end
+ return m.getKeyTypeOfLiteral(obj)
+end
+
+--- 测试 a 到 b 的路径(不经过函数,不考虑 goto),
+--- 每个路径是一个 block 。
+---
+--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>`
+---
+--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>`
+---
+--- 否则返回 `false`
+---
+--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。
+---@param a table
+---@param b table
+---@return string|boolean mode
+---@return table pathA?
+---@return table pathB?
+function m.getPath(a, b, sameFunction)
+ --- 首先测试双方在同一个函数内
+ if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
+ return false
+ end
+ local mode
+ local objA
+ local objB
+ if a.finish < b.start then
+ mode = 'before'
+ objA = a
+ objB = b
+ elseif a.start > b.finish then
+ mode = 'after'
+ objA = b
+ objB = a
+ else
+ return 'equal', {}, {}
+ end
+ local pathA = {}
+ local pathB = {}
+ for _ = 1, 1000 do
+ objA = m.getParentBlock(objA)
+ pathA[#pathA+1] = objA
+ if (not sameFunction and objA.type == 'function') or objA.type == 'main' then
+ break
+ end
+ end
+ for _ = 1, 1000 do
+ objB = m.getParentBlock(objB)
+ pathB[#pathB+1] = objB
+ if (not sameFunction and objA.type == 'function') or objB.type == 'main' then
+ break
+ end
+ end
+ -- pathA: {1, 2, 3, 4, 5}
+ -- pathB: {5, 6, 2, 3}
+ local top = #pathB
+ local start
+ for i = #pathA, 1, -1 do
+ local currentBlock = pathA[i]
+ if currentBlock == pathB[top] then
+ start = i
+ break
+ end
+ end
+ if not start then
+ return nil
+ end
+ -- pathA: { 1, 2, 3}
+ -- pathB: {5, 6, 2, 3}
+ local extra = 0
+ local align = top - start
+ for i = start, 1, -1 do
+ local currentA = pathA[i]
+ local currentB = pathB[i+align]
+ if currentA ~= currentB then
+ extra = i
+ break
+ end
+ end
+ -- pathA: {1}
+ local resultA = {}
+ for i = extra, 1, -1 do
+ resultA[#resultA+1] = pathA[i]
+ end
+ -- pathB: {5, 6}
+ local resultB = {}
+ for i = extra + align, 1, -1 do
+ resultB[#resultB+1] = pathB[i]
+ end
+ return mode, resultA, resultB
+end
+
+---是否是全局变量(包括 _G.XXX 形式)
+---@param source parser.guide.object
+---@return boolean
+function m.isGlobal(source)
+ if source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ if source.node and source.node.tag == '_ENV' then
+ return true
+ end
+ end
+ if source.type == 'field' then
+ source = source.parent
+ end
+ if source.special == '_G' then
+ return true
+ end
+ return false
+end
+
return m
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index ae8e3f34..335c8f24 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1,7 +1,7 @@
local m = require 'lpeglabel'
local re = require 'parser.relabel'
local lines = require 'parser.lines'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local grammar = require 'parser.grammar'
local TokenTypes, TokenStarts, TokenFinishs, TokenContents
@@ -194,6 +194,7 @@ local function parseClass(parent)
local result = {
type = 'doc.class',
parent = parent,
+ fields = {},
}
result.class = parseName('doc.class.name', result)
if not result.class then
@@ -300,8 +301,8 @@ local function parseTypeUnitTable(parent, node)
node.parent = result;
result.finish = getFinish()
- result.key = key
- result.value = value
+ result.tkey = key
+ result.tvalue = value
return result
end
@@ -425,9 +426,10 @@ local function parseTypeUnit(parent, content)
return result
end
-local function parseResume()
+local function parseResume(parent)
local result = {
- type = 'doc.resume'
+ type = 'doc.resume',
+ parent = parent,
}
if checkToken('symbol', '>', 1) then
@@ -456,7 +458,6 @@ local function parseResume()
return result
end
-local LastType
function parseType(parent)
local result = {
type = 'doc.type',
@@ -484,13 +485,7 @@ function parseType(parent)
break
end
-- TypeLiteral,指代类型的字面值。比如,对于类 Cat 来说,它的 TypeLiteral 是 "Cat"
- typeLiteral = {
- type = 'doc.type.typeliteral',
- parent = result,
- start = getStart(),
- finish = nil,
- node = nil,
- }
+ typeLiteral = true
end
if tp == 'name' then
@@ -501,10 +496,7 @@ function parseType(parent)
end
if typeLiteral then
nextToken()
- typeLiteral.finish = getFinish()
- typeLiteral.node = typeUnit
- typeUnit.parent = typeLiteral
- typeUnit = typeLiteral
+ typeUnit.literal = true
end
result.types[#result.types+1] = typeUnit
if not result.start then
@@ -566,7 +558,7 @@ function parseType(parent)
row = row + i + 1
local finishPos = nextComm.text:find('#', 3) or #nextComm.text
parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1)
- local resume = parseResume()
+ local resume = parseResume(result)
if resume then
if comments then
resume.comment = table.concat(comments, '\n')
@@ -1122,17 +1114,25 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
end
local src = sources[index]
if src.start < start then
- left = index
+ left = index + 1
else
right = index
end
end
- for i = index - 1, max do
+
+ -- 从前往后进行绑定
+ for i = index, max do
local src = sources[i]
if src then
if src.start > finish then
break
end
+ -- 遇到table后中断,处理以下情况:
+ -- ---@type AAA
+ -- local t = {x = 1, y = 2}
+ if src.type == 'table' then
+ break
+ end
if src.start >= start then
src.bindDocs = binded
bindSources[#bindSources+1] = src
@@ -1152,21 +1152,22 @@ local function bindParamAndReturnIndex(binded)
if not func then
return
end
- if not func.args then
- return
- end
- local paramIndex = 0
- local paramMap = {}
- for _, param in ipairs(func.args) do
- paramIndex = paramIndex + 1
- if param[1] then
- paramMap[param[1]] = paramIndex
+ local paramMap
+ if func.args then
+ local paramIndex = 0
+ paramMap = {}
+ for _, param in ipairs(func.args) do
+ paramIndex = paramIndex + 1
+ if param[1] then
+ paramMap[param[1]] = paramIndex
+ end
end
+ func.docParamMap = paramMap
end
local returnIndex = 0
for _, doc in ipairs(binded) do
if doc.type == 'doc.param' then
- if doc.extends then
+ if paramMap and doc.extends then
doc.extends.paramIndex = paramMap[doc.param[1]]
end
elseif doc.type == 'doc.return' then
@@ -1178,6 +1179,24 @@ local function bindParamAndReturnIndex(binded)
end
end
+local function bindClassAndFields(binded)
+ local class
+ for _, doc in ipairs(binded) do
+ if doc.type == 'doc.class' then
+ -- 多个class连续写在一起,只有最后一个class可以绑定source
+ if class then
+ class.bindSources = nil
+ end
+ class = doc
+ elseif doc.type == 'doc.field' then
+ if class then
+ class.fields[#class.fields+1] = doc
+ doc.class = class
+ end
+ end
+ end
+end
+
local function bindDoc(sources, lns, binded)
if not binded then
return
@@ -1200,6 +1219,7 @@ local function bindDoc(sources, lns, binded)
bindDocsBetween(sources, binded, bindSources, nstart, nfinish)
end
bindParamAndReturnIndex(binded)
+ bindClassAndFields(binded)
end
local function bindDocs(state)
@@ -1214,6 +1234,7 @@ local function bindDocs(state)
or src.type == 'tablefield'
or src.type == 'tableindex'
or src.type == 'function'
+ or src.type == 'table'
or src.type == '...' then
sources[#sources+1] = src
end
diff --git a/script/vm/eachDef.lua b/script/vm/eachDef.lua
index d72c8f01..6f7af295 100644
--- a/script/vm/eachDef.lua
+++ b/script/vm/eachDef.lua
@@ -1,49 +1,7 @@
---@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local files = require 'files'
-local util = require 'utility'
-local await = require 'await'
-local config = require 'config'
+local vm = require 'vm.vm'
+local searcher = require 'core.searcher'
-local function getDefs(source, deep)
- local results = {}
- local lock = vm.lock('eachDef', source)
- if not lock then
- return results
- end
-
- await.delay()
-
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- local clock = os.clock()
- local myResults, count = guide.requestDefinition(source, vm.interface, deep)
- if DEVELOP and os.clock() - clock > 0.1 then
- log.warn('requestDefinition', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
- end
- vm.mergeResults(results, myResults)
-
- lock()
-
- return results
-end
-
-function vm.getDefs(source, deep)
- deep = deep or -999
- if guide.isGlobal(source) then
- local key = guide.getKeyName(source)
- if not key then
- return {}
- end
- return vm.getGlobalSets(key)
- else
- local cache = vm.getCache('eachDef')[source]
- if not cache or cache.deep < deep then
- cache = getDefs(source, deep)
- cache.deep = deep
- vm.getCache('eachDef')[source] = cache
- end
- return cache
- end
+function vm.getDefs(source, field)
+ return searcher.requestDefinition(source, field)
end
diff --git a/script/vm/eachField.lua b/script/vm/eachField.lua
deleted file mode 100644
index 59f35f0c..00000000
--- a/script/vm/eachField.lua
+++ /dev/null
@@ -1,109 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local await = require 'await'
-local config = require 'config'
-
-local function getFields(source, deep, filterKey)
- local unlock = vm.lock('eachField', source)
- if not unlock then
- return {}
- end
-
- while source.type == 'paren' do
- source = source.exp
- if not source then
- return {}
- end
- end
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- await.delay()
- local results = guide.requestFields(source, vm.interface, deep, filterKey)
-
- unlock()
- return results
-end
-
-local function getDefFields(source, deep, filterKey)
- local unlock = vm.lock('eachDefField', source)
- if not unlock then
- return {}
- end
-
- while source.type == 'paren' do
- source = source.exp
- if not source then
- return {}
- end
- end
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- await.delay()
- local results = guide.requestDefFields(source, vm.interface, deep, filterKey)
-
- unlock()
- return results
-end
-
-local function getFieldsBySource(source, deep, filterKey)
- deep = deep or -999
- local cache = vm.getCache('eachField')[source]
- if not cache or cache.deep < deep then
- cache = getFields(source, deep, filterKey)
- cache.deep = deep
- if not filterKey then
- vm.getCache('eachField')[source] = cache
- end
- end
- return cache
-end
-
-local function getDefFieldsBySource(source, deep, filterKey)
- deep = deep or -999
- local cache = vm.getCache('eachDefField')[source]
- if not cache or cache.deep < deep then
- cache = getDefFields(source, deep, filterKey)
- cache.deep = deep
- if not filterKey then
- vm.getCache('eachDefField')[source] = cache
- end
- end
- return cache
-end
-
-function vm.getFields(source, deep)
- if source.special == '_G' then
- return vm.getGlobals '*'
- end
- if guide.isGlobal(source) then
- local name = guide.getKeyName(source)
- if not name then
- return {}
- end
- local cache = vm.getCache('eachFieldOfGlobal')[name]
- or getFieldsBySource(source, deep)
- vm.getCache('eachFieldOfGlobal')[name] = cache
- return cache
- else
- return getFieldsBySource(source, deep)
- end
-end
-
-function vm.getDefFields(source, deep)
- if source.special == '_G' then
- return vm.getGlobalSets '*'
- end
- if guide.isGlobal(source) then
- local name = guide.getKeyName(source)
- if not name then
- return {}
- end
- local cache = vm.getCache('eachDefFieldOfGlobal')[name]
- or getDefFieldsBySource(source, deep)
- vm.getCache('eachDefFieldOfGlobal')[name] = cache
- return cache
- else
- return getDefFieldsBySource(source, deep)
- end
-end
diff --git a/script/vm/eachRef.lua b/script/vm/eachRef.lua
index 9d0f061c..5aca198e 100644
--- a/script/vm/eachRef.lua
+++ b/script/vm/eachRef.lua
@@ -1,48 +1,7 @@
---@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local util = require 'utility'
-local await = require 'await'
-local config = require 'config'
+local vm = require 'vm.vm'
+local searcher = require 'core.searcher'
-local function getRefs(source, deep)
- local results = {}
- local lock = vm.lock('eachRef', source)
- if not lock then
- return results
- end
-
- await.delay()
-
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- local clock = os.clock()
- local myResults, count = guide.requestReference(source, vm.interface, deep)
- if DEVELOP and os.clock() - clock > 0.1 then
- log.warn('requestReference', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
- end
- vm.mergeResults(results, myResults)
-
- lock()
-
- return results
-end
-
-function vm.getRefs(source, deep)
- deep = deep or -999
- if guide.isGlobal(source) then
- local key = guide.getKeyName(source)
- if not key then
- return {}
- end
- return vm.getGlobals(key)
- else
- local cache = vm.getCache('eachRef')[source]
- if not cache or cache.deep < deep then
- cache = getRefs(source, deep)
- cache.deep = deep
- vm.getCache('eachRef')[source] = cache
- end
- return cache
- end
+function vm.getRefs(source, field)
+ return searcher.requestReference(source, field)
end
diff --git a/script/vm/getClass.lua b/script/vm/getClass.lua
deleted file mode 100644
index 5c68e0bb..00000000
--- a/script/vm/getClass.lua
+++ /dev/null
@@ -1,64 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-
-local function lookUpDocClass(source)
- local infers = vm.getInfers(source, 0)
- for _, infer in ipairs(infers) do
- if infer.source.type == 'doc.class'
- or infer.source.type == 'doc.type' then
- return guide.viewInferType(infers)
- end
- end
- return nil
-end
-
-local function getClass(source, classes, depth, deep)
- local docClass = lookUpDocClass(source)
- if docClass then
- classes[docClass] = true
- return
- end
- if depth > 3 then
- return
- end
- local value = guide.getObjectValue(source) or source
- if not deep then
- if value and value.type == 'string' then
- classes[value[1]] = true
- end
- else
- for _, src in ipairs(vm.getDefFields(value)) do
- local key = vm.getKeyName(src)
- if not key then
- goto CONTINUE
- end
- local lkey = key:lower()
- if lkey == 'type'
- or lkey == '__name'
- or lkey == 'name'
- or lkey == 'class' then
- local value = guide.getObjectValue(src)
- if value and value.type == 'string' then
- classes[value[1]] = true
- end
- end
- ::CONTINUE::
- end
- end
- if next(classes) then
- return
- end
- vm.eachMeta(source, function (mt)
- getClass(mt, classes, depth + 1, deep)
- end)
-end
-
-function vm.getClass(source, deep)
- local classes = {}
- getClass(source, classes, 1, deep)
- if not next(classes) then
- return nil
- end
- return guide.mergeTypes(classes)
-end
diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua
index cfa9326f..dbb8b4fd 100644
--- a/script/vm/getDocs.lua
+++ b/script/vm/getDocs.lua
@@ -1,148 +1,51 @@
-local files = require 'files'
-local util = require 'utility'
-local guide = require 'core.guide'
+local files = require 'files'
+local guide = require 'parser.guide'
---@type vm
-local vm = require 'vm.vm'
-local config = require 'config'
+local vm = require 'vm.vm'
+local config = require 'config'
+local searcher = require 'core.searcher'
-local function getTypesOfFile(uri)
- local types = {}
- local ast = files.getAst(uri)
- if not ast or not ast.ast.docs then
- return types
- end
- guide.eachSource(ast.ast.docs, function (src)
- if src.type == 'doc.type.name'
- or src.type == 'doc.class.name'
- or src.type == 'doc.extends.name'
- or src.type == 'doc.alias.name' then
- if src.type == 'doc.type.name' then
- if guide.getParentDocTypeTable(src) then
- return
- end
+local function getDocDefinesInAst(results, root, name)
+ for _, doc in ipairs(root.docs) do
+ if doc.type == 'doc.class' then
+ if not name or name == doc.class[1] then
+ results[#results+1] = doc.class
end
- local name = src[1]
- if name then
- if not types[name] then
- types[name] = {}
- end
- types[name][#types[name]+1] = src
+ elseif doc.type == 'doc.alias' then
+ if not name or name == doc.alias[1] then
+ results[#results+1] = doc.alias
end
end
- end)
- return types
+ end
end
-local function getDocTypes(name)
+---获取class与alias
+---@param name? string
+---@return parser.guide.object[]
+function vm.getDocDefines(name)
local results = {}
- if name == 'any'
- or name == 'nil' then
- return results
- end
for uri in files.eachFile() do
- local cache = files.getCache(uri)
- cache.classes = cache.classes or getTypesOfFile(uri)
- if name == '*' then
- for _, sources in util.sortPairs(cache.classes) do
- for _, source in ipairs(sources) do
- results[#results+1] = source
- end
- end
- else
- if cache.classes[name] then
- for _, source in ipairs(cache.classes[name]) do
- results[#results+1] = source
- end
- end
- end
+ local ast = files.getAst(uri)
+ getDocDefinesInAst(results, ast.ast, name)
end
return results
end
-function vm.getDocEnums(doc, mark, results)
+function vm.getDocEnums(doc)
if not doc then
return nil
end
- mark = mark or {}
- if mark[doc] then
- return nil
- end
- mark[doc] = true
- results = results or {}
- for _, enum in ipairs(doc.enums) do
- results[#results+1] = enum
- end
- for _, resume in ipairs(doc.resumes) do
- results[#results+1] = resume
- end
- for _, unit in ipairs(doc.types) do
- if unit.type == 'doc.type.name' then
- for _, other in ipairs(vm.getDocTypes(unit[1])) do
- if other.type == 'doc.alias.name' then
- vm.getDocEnums(other.parent.extends, mark, results)
- end
- end
- end
- end
- return results
-end
+ local defs = searcher.requestDefinition(doc)
+ local results = {}
-function vm.getDocTypeUnits(doc, mark, results)
- if not doc then
- return nil
- end
- mark = mark or {}
- if mark[doc] then
- return nil
- end
- mark[doc] = true
- results = results or {}
- for _, enum in ipairs(doc.enums) do
- results[#results+1] = enum
- end
- for _, resume in ipairs(doc.resumes) do
- results[#results+1] = resume
- end
- for _, unit in ipairs(doc.types) do
- if unit.type == 'doc.type.name' then
- for _, other in ipairs(vm.getDocTypes(unit[1])) do
- if other.type == 'doc.alias.name' then
- vm.getDocTypeUnits(other.parent.extends, mark, results)
- elseif other.type == 'doc.class.name' then
- results[#results+1] = other
- end
- end
- else
- results[#results+1] = unit
+ for _, def in ipairs(defs) do
+ if def.type == 'doc.type.enum'
+ or def.type == 'doc.resume' then
+ results[#results+1] = def
end
end
- return results
-end
-
-function vm.getDocTypes(name)
- local cache = vm.getCache('getDocTypes')[name]
- if cache ~= nil then
- return cache
- end
- cache = getDocTypes(name)
- vm.getCache('getDocTypes')[name] = cache
- return cache
-end
-function vm.getDocClass(name)
- local cache = vm.getCache('getDocClass')[name]
- if cache ~= nil then
- return cache
- end
- cache = {}
- local results = getDocTypes(name)
- for _, doc in ipairs(results) do
- if doc.type == 'doc.class.name' then
- cache[#cache+1] = doc
- end
- end
- vm.getCache('getDocClass')[name] = cache
- return cache
+ return results
end
function vm.isMetaFile(uri)
@@ -224,7 +127,7 @@ end
function vm.isDeprecated(value, deep)
if deep then
- local defs = vm.getDefs(value, 0)
+ local defs = vm.getDefs(value)
if #defs == 0 then
return false
end
diff --git a/script/vm/getGlobals.lua b/script/vm/getGlobals.lua
index 2752ce09..bea192ef 100644
--- a/script/vm/getGlobals.lua
+++ b/script/vm/getGlobals.lua
@@ -1,5 +1,6 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local await = require "await"
+local searcher = require "core.searcher"
---@type vm
local vm = require 'vm.vm'
local files = require 'files'
@@ -17,12 +18,8 @@ local function getGlobalsOfFile(uri)
end
local globals = {}
cache.globals = globals
- local ast = files.getAst(uri)
- if not ast then
- return globals
- end
tracy.ZoneBeginN 'getGlobalsOfFile'
- local results = guide.findGlobals(ast.ast)
+ local results = searcher.findGlobals(uri)
local subscribe = ws.getCache 'globalSubscribe'
subscribe[uri] = {}
local mark = {}
@@ -34,7 +31,7 @@ local function getGlobalsOfFile(uri)
goto CONTINUE
end
mark[res] = true
- local name = guide.getSimpleName(res)
+ local name = guide.getKeyName(res)
if name then
if not globals[name] then
globals[name] = {}
@@ -59,12 +56,8 @@ local function getGlobalSetsOfFile(uri)
end
local globals = {}
cache.globalSets = globals
- local ast = files.getAst(uri)
- if not ast then
- return globals
- end
tracy.ZoneBeginN 'getGlobalSetsOfFile'
- local results = guide.findGlobals(ast.ast)
+ local results = searcher.findGlobals(uri, 'def')
local subscribe = ws.getCache 'globalSetsSubscribe'
subscribe[uri] = {}
local mark = {}
@@ -76,16 +69,14 @@ local function getGlobalSetsOfFile(uri)
goto CONTINUE
end
mark[res] = true
- if vm.isSet(res) then
- local name = guide.getSimpleName(res)
- if name then
- if not globals[name] then
- globals[name] = {}
- subscribe[uri][#subscribe[uri]+1] = name
- end
- globals[name][#globals[name]+1] = res
- globals['*'][#globals['*']+1] = res
+ local name = guide.getKeyName(res)
+ if name then
+ if not globals[name] then
+ globals[name] = {}
+ subscribe[uri][#subscribe[uri]+1] = name
end
+ globals[name][#globals[name]+1] = res
+ globals['*'][#globals['*']+1] = res
end
::CONTINUE::
end
@@ -265,7 +256,7 @@ files.watch(function (ev, uri)
end
needUpdateGlobals[uri] = true
elseif ev == 'create' then
- getGlobalsOfFile(uri)
- getGlobalSetsOfFile(uri)
+ --getGlobalsOfFile(uri)
+ --getGlobalSetsOfFile(uri)
end
end)
diff --git a/script/vm/getInfer.lua b/script/vm/getInfer.lua
deleted file mode 100644
index 5447ca23..00000000
--- a/script/vm/getInfer.lua
+++ /dev/null
@@ -1,104 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-local guide = require 'core.guide'
-local util = require 'utility'
-local await = require 'await'
-local config = require 'config'
-
-NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
-
---- 是否包含某种类型
-function vm.hasType(source, type, deep)
- local defs = vm.getDefs(source, deep)
- for i = 1, #defs do
- local def = defs[i]
- local value = guide.getObjectValue(def) or def
- if value.type == type then
- return true
- end
- end
- return false
-end
-
---- 是否包含某种类型
-function vm.hasInferType(source, type, deep)
- local infers = vm.getInfers(source, deep)
- for i = 1, #infers do
- local infer = infers[i]
- if infer.type == type then
- return true
- end
- end
- return false
-end
-
-function vm.getInferType(source, deep)
- local infers = vm.getInfers(source, deep)
- return guide.viewInferType(infers)
-end
-
-function vm.getInferLiteral(source, deep)
- local infers = vm.getInfers(source, deep)
- local literals = {}
- local mark = {}
- for _, infer in ipairs(infers) do
- local value = infer.value
- if value and not mark[value] then
- mark[value] = true
- literals[#literals+1] = util.viewLiteral(value)
- end
- end
- if #literals == 0 then
- return nil
- end
- table.sort(literals)
- return table.concat(literals, '|')
-end
-
-local function getInfers(source, deep)
- local results = {}
- local lock = vm.lock('getInfers', source)
- if not lock then
- return results
- end
-
- deep = config.config.intelliSense.searchDepth + (deep or 0)
-
- await.delay()
-
- local clock = os.clock()
- local myResults, count = guide.requestInfer(source, vm.interface, deep)
- if DEVELOP and os.clock() - clock > 0.1 then
- log.warn('requestInfer', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 }))
- end
- vm.mergeResults(results, myResults)
-
- lock()
-
- return results
-end
-
-local function getInfersBySource(source, deep)
- deep = deep or -999
- local cache = vm.getCache('getInfers')[source]
- if not cache or cache.deep < deep then
- cache = getInfers(source, deep)
- cache.deep = deep
- vm.getCache('getInfers')[source] = cache
- end
- return cache
-end
-
---- 获取对象的值
---- 会尝试穿透函数调用
-function vm.getInfers(source, deep)
- if guide.isGlobal(source) then
- local name = guide.getKeyName(source)
- local cache = vm.getCache('getInfersOfGlobal')[name]
- or getInfersBySource(source, deep)
- vm.getCache('getInfersOfGlobal')[name] = cache
- return cache
- else
- return getInfersBySource(source, deep)
- end
-end
diff --git a/script/vm/getLibrary.lua b/script/vm/getLibrary.lua
index b52f7240..a3c8feb0 100644
--- a/script/vm/getLibrary.lua
+++ b/script/vm/getLibrary.lua
@@ -1,8 +1,11 @@
---@type vm
local vm = require 'vm.vm'
-function vm.getLibraryName(source, deep)
- local defs = vm.getDefs(source, deep)
+function vm.getLibraryName(source)
+ if source.special then
+ return source.special
+ end
+ local defs = vm.getDefs(source)
for _, def in ipairs(defs) do
if def.special then
return def.special
diff --git a/script/vm/getLinks.lua b/script/vm/getLinks.lua
index 91a5f1a0..51a18d58 100644
--- a/script/vm/getLinks.lua
+++ b/script/vm/getLinks.lua
@@ -1,5 +1,4 @@
-local guide = require 'core.guide'
----@type vm
+local guide = require 'parser.guide'
local vm = require 'vm.vm'
local files = require 'files'
@@ -33,11 +32,17 @@ local function getFileLinks(uri)
return links
end
+local function getFileLinksOrCache(uri)
+ local cache = files.getCache(uri)
+ cache.links = cache.links or getFileLinks(uri)
+ return cache.links
+end
+
local function getLinksTo(uri)
uri = files.asKey(uri)
local links = {}
for u in files.eachFile() do
- local ls = vm.getFileLinks(u)
+ local ls = getFileLinksOrCache(u)
if ls[uri] then
for _, l in ipairs(ls[uri]) do
links[#links+1] = l
@@ -56,9 +61,3 @@ function vm.getLinksTo(uri)
vm.getCache('getLinksTo')[uri] = cache
return cache
end
-
-function vm.getFileLinks(uri)
- local cache = files.getCache(uri)
- cache.links = cache.links or getFileLinks(uri)
- return cache.links
-end
diff --git a/script/vm/getMeta.lua b/script/vm/getMeta.lua
deleted file mode 100644
index 44d1874a..00000000
--- a/script/vm/getMeta.lua
+++ /dev/null
@@ -1,53 +0,0 @@
----@type vm
-local vm = require 'vm.vm'
-
-local function eachMetaOfArg1(source, callback)
- local node, index = vm.getArgInfo(source)
- local special = vm.getSpecial(node)
- if special == 'setmetatable' and index == 1 then
- local mt = node.next.args[2]
- if mt then
- callback(mt)
- end
- end
-end
-
-local function eachMetaOfRecv(source, callback)
- if not source or source.type ~= 'select' then
- return
- end
- if source.index ~= 1 then
- return
- end
- local call = source.vararg
- if not call or call.type ~= 'call' then
- return
- end
- local special = vm.getSpecial(call.node)
- if special ~= 'setmetatable' then
- return
- end
- local mt = call.args[2]
- if mt then
- callback(mt)
- end
-end
-
-function vm.eachMetaValue(source, callback)
- vm.eachMeta(source, function (mt)
- for _, src in ipairs(vm.getDefFields(mt)) do
- if vm.getKeyName(src) == '__index' then
- if src.value then
- for _, valueSrc in ipairs(vm.getDefFields(src.value)) do
- callback(valueSrc)
- end
- end
- end
- end
- end)
-end
-
-function vm.eachMeta(source, callback)
- eachMetaOfArg1(source, callback)
- eachMetaOfRecv(source.value, callback)
-end
diff --git a/script/vm/guideInterface.lua b/script/vm/guideInterface.lua
index ae060481..e59fc6e3 100644
--- a/script/vm/guideInterface.lua
+++ b/script/vm/guideInterface.lua
@@ -2,7 +2,7 @@
local vm = require 'vm.vm'
local files = require 'files'
local ws = require 'workspace'
-local guide = require 'core.guide'
+local searcher = require 'core.searcher'
local await = require 'await'
local config = require 'config'
@@ -27,7 +27,7 @@ function m.require(args, index)
return nil
end
local results = {}
- local myUri = guide.getUri(args[1])
+ local myUri = searcher.getUri(args[1])
local uris = ws.findUrisByRequirePath(reqName)
for _, uri in ipairs(uris) do
if not files.eq(myUri, uri) then
@@ -47,7 +47,7 @@ function m.dofile(args, index)
return
end
local results = {}
- local myUri = guide.getUri(args[1])
+ local myUri = searcher.getUri(args[1])
local uris = ws.findUrisByFilePath(reqName)
for _, uri in ipairs(uris) do
if not files.eq(myUri, uri) then
@@ -87,9 +87,9 @@ function vm.interface.global(name, onlyDef)
end
end
-function vm.interface.docType(name)
+function vm.interface.doc(name, type)
await.delay()
- return vm.getDocTypes(name)
+ return vm.getDocNames(name, type)
end
function vm.interface.link(uri)
diff --git a/script/vm/init.lua b/script/vm/init.lua
index b9e8e147..c38f01d5 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -2,10 +2,6 @@ local vm = require 'vm.vm'
require 'vm.getGlobals'
require 'vm.getDocs'
require 'vm.getLibrary'
-require 'vm.getInfer'
-require 'vm.getClass'
-require 'vm.getMeta'
-require 'vm.eachField'
require 'vm.eachDef'
require 'vm.eachRef'
require 'vm.getLinks'
diff --git a/script/vm/vm.lua b/script/vm/vm.lua
index 0248ad8c..ebd0102b 100644
--- a/script/vm/vm.lua
+++ b/script/vm/vm.lua
@@ -1,18 +1,14 @@
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
local util = require 'utility'
local files = require 'files'
local timer = require 'timer'
local setmetatable = setmetatable
-local assert = assert
-local require = require
-local type = type
local running = coroutine.running
local ipairs = ipairs
local log = log
local xpcall = xpcall
local mathHuge = math.huge
-local collectgarbage = collectgarbage
_ENV = nil
@@ -63,10 +59,6 @@ function m.getArgInfo(source)
return nil
end
-function m.getSpecial(source)
- return guide.getSpecial(source)
-end
-
function m.getKeyName(source)
if not source then
return nil
diff --git a/test.lua b/test.lua
index bbc5f208..fe998a89 100644
--- a/test.lua
+++ b/test.lua
@@ -65,22 +65,22 @@ local function testAll()
test 'references'
test 'definition'
test 'type_inference'
+ test 'hover'
+ test 'completion'
+ test 'crossfile'
test 'diagnostics'
test 'highlight'
test 'rename'
- test 'hover'
- test 'completion'
test 'signature'
test 'document_symbol'
test 'code_action'
test 'type_formatting'
- test 'crossfile'
--test 'other'
end
local function main()
debug.setcstacklimit(1000)
- require 'core.guide'.debugMode = true
+ require 'core.searcher'.debugMode = true
require 'language' 'zh-cn'
require 'utility'.enableCloseFunction()
diff --git a/test/basic/init.lua b/test/basic/init.lua
index a3a11f62..1b698493 100644
--- a/test/basic/init.lua
+++ b/test/basic/init.lua
@@ -1,219 +1,2 @@
-local files = require 'files'
-local tm = require 'text-merger'
-
-local function TEST(source)
- return function (expect)
- return function (changes)
- files.removeAll()
- files.setText('', source)
- local text = tm('', changes)
- assert(text == expect)
- end
- end
-end
-
-TEST [[
-
-
-function Test(self)
-
-end
-]][[
-
-
-function Test(self)
-
-end
-
-asser]]{
- [1] = {
- range = {
- ["end"] = {
- character = 0,
- line = 5,
- },
- start = {
- character = 0,
- line = 5,
- },
- },
- rangeLength = 0,
- text = "\
-",
- },
- [2] = {
- range = {
- ["end"] = {
- character = 0,
- line = 6,
- },
- start = {
- character = 0,
- line = 6,
- },
- },
- rangeLength = 0,
- text = "a",
- },
- [3] = {
- range = {
- ["end"] = {
- character = 1,
- line = 6,
- },
- start = {
- character = 1,
- line = 6,
- },
- },
- rangeLength = 0,
- text = "s",
- },
- [4] = {
- range = {
- ["end"] = {
- character = 2,
- line = 6,
- },
- start = {
- character = 2,
- line = 6,
- },
- },
- rangeLength = 0,
- text = "s",
- },
- [5] = {
- range = {
- ["end"] = {
- character = 3,
- line = 6,
- },
- start = {
- character = 3,
- line = 6,
- },
- },
- rangeLength = 0,
- text = "e",
- },
- [6] = {
- range = {
- ["end"] = {
- character = 4,
- line = 6,
- },
- start = {
- character = 4,
- line = 6,
- },
- },
- rangeLength = 0,
- text = "r",
- },
-}
-
-TEST [[
-local mt = {}
-
-function mt['xxx']()
-
-
-
-end
-]] [[
-local mt = {}
-
-function mt['xxx']()
-
-end
-]] {
- [1] = {
- range = {
- ["end"] = {
- character = 4,
- line = 5,
- },
- start = {
- character = 4,
- line = 3,
- },
- },
- rangeLength = 8,
- text = "",
- },
-}
-
-TEST [[
-local mt = {}
-
-function mt['xxx']()
-
-end
-]] [[
-local mt = {}
-
-function mt['xxx']()
- p
-end
-]] {
- [1] = {
- range = {
- ["end"] = {
- character = 4,
- line = 3,
- },
- start = {
- character = 4,
- line = 3,
- },
- },
- rangeLength = 0,
- text = "p",
- },
-}
-
-TEST [[
-print(12345)
-]] [[
-print(123
-45)
-]] {
- [1] = {
- range = {
- ["end"] = {
- character = 9,
- line = 0,
- },
- start = {
- character = 9,
- line = 0,
- },
- },
- rangeLength = 0,
- text = "\
-",
- },
-}
-
-TEST [[
-print(123
-45)
-]] [[
-print(12345)
-]] {
- [1] = {
- range = {
- ["end"] = {
- character = 0,
- line = 1,
- },
- start = {
- character = 9,
- line = 0,
- },
- },
- rangeLength = 2,
- text = "",
- },
-}
+require 'basic.textmerger'
+require 'basic.noder'
diff --git a/test/basic/linker.txt b/test/basic/linker.txt
new file mode 100644
index 00000000..ea3ba180
--- /dev/null
+++ b/test/basic/linker.txt
@@ -0,0 +1,141 @@
+ast -> linkers = {
+ ['g|"X"|"Y"|"Z"'] = {src1, src2, src3},
+ ['g|"X"|"Y"'] = {src4, src5, src6},
+ ['g|"X"'] = {src7, src8, src9},
+ ['l|7'] = {src10},
+ ['l|7|"x"'] = {src11},
+ ['l|11|"k"'] = {src12},
+}
+
+```lua
+x.y.<?z?> = <!f!>
+
+<?g?> = x.y.z
+
+t.<!z!> = 1
+x.y = t
+
+x = {
+ y = {
+ <!z!> = 1
+ }
+}
+```
+
+expect: 'l|x|y|z'
+forward: 'l|x|y|z' -> f
+backward: 'l|x|y|z' -> g
+last: 'l|x|y' + 'z'
+
+expect: 'l|x|y' + '|z'
+forward: 'l|t' + '|z' -> 'l|t|z' -> t.z
+backward: nil
+last: 'l|x' + '|y|z'
+
+expect: 'l|x' + '|y|z'
+forward: 'l|0' + '|y|z' -> 'l|0|y|z'
+backward: nil
+last: nil
+
+expect: 'l|0|y|z'
+forward: nil
+backward: nil
+last: 'l|0|y' + '|z'
+
+expect: 'l|0|y' + '|z'
+forward: 'l|1'+ '|z' -> 'l|1|z' -> field z
+backward: nil
+last: 'l|0' + '|y|z'
+
+
+```lua
+a = {
+ b = {
+ <?c?> = 1,
+ }
+}
+
+print(a.b.<!c!>)
+```
+
+expect: 't|3|c'
+forward: nil
+backward: nil
+last: 't|3' + '|c'
+
+expect: 't|3' + '|c'
+forward: nil
+backward: 't|2|b' + '|c'
+last: nil
+
+expect: 't|2|b|c'
+forward: nil
+backward: 't|2|b' + '|c'
+last: nil
+
+```lua
+---@return <?A?>
+local function f()
+end
+
+local <!x!> = f()
+```
+
+'d|A'
+'f|1|#1'
+'f|1' + '|#1'
+'l|1' + '|#1'
+'s|1' + '|#1'
+
+```lua
+---@generic T
+---@param a T
+---@return T
+local function f(a) end
+
+local <?c?>
+
+local <!v!> = f(c)
+```
+
+'l1'
+'l2|@1'
+'f|1|@1'
+'f|1|#1'
+
+```
+---@generic T
+---@param p T
+---@return T
+local function f(p) end
+
+local <?r?> = f(<!k!>)
+```
+
+l:r
+s:1#1 call
+l:f#1 call
+f:1#1 call -> f:1&T = l:k
+l:f@1 --> 从保存的call信息里找到 f:1&T = l:k
+l:k
+
+
+
+```
+---@generic T, V
+---@param p T
+---@return fun(V):T, V
+local function f(p) end
+
+local f2 = f(<!k!>)
+local <?r?> = f2()
+```
+
+l:r
+s:2|#1 call1
+l:f2|#1 call1
+f:2|#1 call1
+s:1#1|#1 call2
+f:1#1|#1 call2 -> f:1&T = l:k
+dfun:1|#1
+dn:V -> f:1&T = l:k
diff --git a/test/basic/noder.lua b/test/basic/noder.lua
new file mode 100644
index 00000000..3e5e9f25
--- /dev/null
+++ b/test/basic/noder.lua
@@ -0,0 +1,146 @@
+local noder = require 'core.noder'
+local files = require 'files'
+local util = require 'utility'
+local guide = require 'parser.guide'
+
+local function getSource(pos)
+ local ast = files.getAst('')
+ return guide.eachSourceContain(ast.ast, pos, function (source)
+ if source.type == 'local'
+ or source.type == 'getlocal'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal'
+ or source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod'
+ or source.type == 'tablefield'
+ or source.type == 'setindex'
+ or source.type == 'getindex'
+ or source.type == 'tableindex'
+ or source.type == 'label'
+ or source.type == 'goto' then
+ return source
+ end
+ end)
+end
+
+local CARE = {}
+local function TEST(script)
+ return function (expect)
+ files.removeAll()
+ local start = script:find('<?', 1, true)
+ local finish = script:find('?>', 1, true)
+ local pos = (start + finish) // 2 + 1
+ local newScript = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ')
+ files.setText('', newScript)
+ local source = getSource(pos)
+ assert(source)
+ noder.compileNodes(source)
+ local result = {
+ id = noder.getID(source),
+ }
+
+ expect['id'] = expect['id']:gsub('|', '\x1F')
+
+ for key in pairs(CARE) do
+ assert(result[key] == expect[key])
+ end
+ end
+end
+
+CARE['id'] = true
+TEST [[
+local <?x?>
+]] {
+ id = 'l:9',
+}
+
+TEST [[
+local x
+print(<?x?>)
+]] {
+ id = 'l:7',
+}
+
+TEST [[
+local x
+<?x?> = 1
+]] {
+ id = 'l:7',
+}
+
+TEST [[
+print(<?X?>)
+]] {
+ id = 'g:"X"',
+}
+
+TEST [[
+print(<?X?>)
+]] {
+ id = 'g:"X"',
+}
+
+TEST [[
+local x
+print(x.y.<?z?>)
+]] {
+ id = 'l:7|"y"|"z"',
+}
+
+TEST [[
+local x
+function x:<?f?>() end
+]] {
+ id = 'l:7|"f"',
+}
+
+TEST [[
+print(X.Y.<?Z?>)
+]] {
+ id = 'g:"X"|"Y"|"Z"',
+}
+
+TEST [[
+function x:<?f?>() end
+]] {
+ id = 'g:"x"|"f"',
+}
+
+TEST [[
+{
+ <?x?> = 1,
+}
+]] {
+ id = 't:1|"x"',
+}
+
+TEST [[
+return <?X?>
+]] {
+ id = 'g:"X"',
+}
+
+TEST [[
+function f()
+ return <?X?>
+end
+]] {
+ id = 'g:"X"',
+}
+
+TEST [[
+::<?label?>::
+goto label
+]] {
+ id = 'l:5',
+}
+
+TEST [[
+::label::
+goto <?label?>
+]] {
+ id = 'l:3',
+}
diff --git a/test/basic/textmerger.lua b/test/basic/textmerger.lua
new file mode 100644
index 00000000..a3a11f62
--- /dev/null
+++ b/test/basic/textmerger.lua
@@ -0,0 +1,219 @@
+local files = require 'files'
+local tm = require 'text-merger'
+
+local function TEST(source)
+ return function (expect)
+ return function (changes)
+ files.removeAll()
+ files.setText('', source)
+ local text = tm('', changes)
+ assert(text == expect)
+ end
+ end
+end
+
+TEST [[
+
+
+function Test(self)
+
+end
+]][[
+
+
+function Test(self)
+
+end
+
+asser]]{
+ [1] = {
+ range = {
+ ["end"] = {
+ character = 0,
+ line = 5,
+ },
+ start = {
+ character = 0,
+ line = 5,
+ },
+ },
+ rangeLength = 0,
+ text = "\
+",
+ },
+ [2] = {
+ range = {
+ ["end"] = {
+ character = 0,
+ line = 6,
+ },
+ start = {
+ character = 0,
+ line = 6,
+ },
+ },
+ rangeLength = 0,
+ text = "a",
+ },
+ [3] = {
+ range = {
+ ["end"] = {
+ character = 1,
+ line = 6,
+ },
+ start = {
+ character = 1,
+ line = 6,
+ },
+ },
+ rangeLength = 0,
+ text = "s",
+ },
+ [4] = {
+ range = {
+ ["end"] = {
+ character = 2,
+ line = 6,
+ },
+ start = {
+ character = 2,
+ line = 6,
+ },
+ },
+ rangeLength = 0,
+ text = "s",
+ },
+ [5] = {
+ range = {
+ ["end"] = {
+ character = 3,
+ line = 6,
+ },
+ start = {
+ character = 3,
+ line = 6,
+ },
+ },
+ rangeLength = 0,
+ text = "e",
+ },
+ [6] = {
+ range = {
+ ["end"] = {
+ character = 4,
+ line = 6,
+ },
+ start = {
+ character = 4,
+ line = 6,
+ },
+ },
+ rangeLength = 0,
+ text = "r",
+ },
+}
+
+TEST [[
+local mt = {}
+
+function mt['xxx']()
+
+
+
+end
+]] [[
+local mt = {}
+
+function mt['xxx']()
+
+end
+]] {
+ [1] = {
+ range = {
+ ["end"] = {
+ character = 4,
+ line = 5,
+ },
+ start = {
+ character = 4,
+ line = 3,
+ },
+ },
+ rangeLength = 8,
+ text = "",
+ },
+}
+
+TEST [[
+local mt = {}
+
+function mt['xxx']()
+
+end
+]] [[
+local mt = {}
+
+function mt['xxx']()
+ p
+end
+]] {
+ [1] = {
+ range = {
+ ["end"] = {
+ character = 4,
+ line = 3,
+ },
+ start = {
+ character = 4,
+ line = 3,
+ },
+ },
+ rangeLength = 0,
+ text = "p",
+ },
+}
+
+TEST [[
+print(12345)
+]] [[
+print(123
+45)
+]] {
+ [1] = {
+ range = {
+ ["end"] = {
+ character = 9,
+ line = 0,
+ },
+ start = {
+ character = 9,
+ line = 0,
+ },
+ },
+ rangeLength = 0,
+ text = "\
+",
+ },
+}
+
+TEST [[
+print(123
+45)
+]] [[
+print(12345)
+]] {
+ [1] = {
+ range = {
+ ["end"] = {
+ character = 0,
+ line = 1,
+ },
+ start = {
+ character = 9,
+ line = 0,
+ },
+ },
+ rangeLength = 2,
+ text = "",
+ },
+}
diff --git a/test/crossfile/hover.lua b/test/crossfile/hover.lua
index c27cd3dd..e81494ff 100644
--- a/test/crossfile/hover.lua
+++ b/test/crossfile/hover.lua
@@ -202,20 +202,20 @@ TEST {
path = 'a.lua',
content = [[
t = {
- [{}] = 1,
+ [1] = 1,
}
]],
},
{
path = 'b.lua',
content = [[
- <?t?>[{}] = 2
+ <?t?>[1] = 2
]]
},
hover = {
label = [[
global t: {
- [table]: integer = 1|2,
+ [1]: integer = 1|2,
}]],
name = 't',
},
@@ -226,20 +226,20 @@ TEST {
path = 'a.lua',
content = [[
t = {
- [{}] = 1,
+ [1] = 1,
}
]],
},
{
path = 'a.lua',
content = [[
- <?t?>[{}] = 2
+ <?t?>[1] = 2
]]
},
hover = {
label = [[
global t: {
- [table]: integer = 2,
+ [1]: integer = 2,
}]],
name = 't',
},
@@ -729,7 +729,7 @@ food.secondField = 2
]]
},
hover = {
- label = 'field Food.firstField: integer = 0',
+ label = 'field Food.firstField: number = 0',
name = 'food.firstField',
}}
diff --git a/test/definition/init.lua b/test/definition/init.lua
index 6e6d0a9a..85bcd5d5 100644
--- a/test/definition/init.lua
+++ b/test/definition/init.lua
@@ -36,6 +36,7 @@ end
function TEST(script)
files.removeAll()
+ script = script:gsub('\n', '\r\n')
local target = catch_target(script)
local start = script:find('<?', 1, true)
local finish = script:find('?>', 1, true)
@@ -51,8 +52,14 @@ function TEST(script)
positions[i] = { result.target.start, result.target.finish }
end
end
+ if not founded(target, positions) then
+ core('', pos)
+ end
assert(founded(target, positions))
else
+ if #target ~= 0 then
+ core('', pos)
+ end
assert(#target == 0)
end
end
@@ -65,6 +72,6 @@ require 'definition.table'
require 'definition.method'
require 'definition.label'
require 'definition.call'
-require 'definition.bug'
require 'definition.special'
+require 'definition.bug'
require 'definition.luadoc'
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index ff54546b..5531e2e3 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -87,6 +87,11 @@ TEST [[
]]
TEST [[
+---@type <!fun():void!>
+local <?<!f!>?>
+]]
+
+TEST [[
---@param f <!fun():void!>
function t(<?<!f!>?>) end
]]
@@ -97,7 +102,7 @@ function f(<?...?>) end
]]
TEST [[
----@overload fun(y: boolean)
+---@overload <!fun(y: boolean)!>
---@param x number
---@param y boolean
---@param z string
@@ -108,7 +113,7 @@ print(<?f?>)
TEST [[
local function f()
- return 1
+ return <!1!>
end
---@class Class
@@ -204,6 +209,23 @@ TEST [[
]]
TEST [[
+---@return <!fun()!>
+local function f() end
+
+local <?<!r!>?> = f()
+]]
+
+TEST [[
+---@generic T
+---@param p T
+---@return T
+local function f(p) end
+
+local <!k!>
+local <?<!r!>?> = f(<!k!>)
+]]
+
+TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
@@ -260,6 +282,26 @@ print(v1.<?bar1?>)
]]
TEST [[
+---@class A
+local <!t!>
+
+---@type A[]
+local b
+
+local <?<!c!>?> = b[1]
+]]
+
+TEST [[
+---@class A
+local <!t!>
+
+---@type table<number, A>
+local b
+
+local <?<!c!>?> = b[1]
+]]
+
+TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
@@ -299,13 +341,223 @@ print(v1[1].<?bar1?>)
--]]
TEST [[
+---@type fun():<!fun()!>
+local f
+
+local <?<!f2!>?> = f()
+]]
+
+TEST [[
+---@generic T
+---@type fun(x: T):T
+local f
+
+local <?<!v2!>?> = f(<!{}!>)
+]]
+
+TEST [[
+---@generic T
+---@param x T
+---@return fun():T
+local function f(x) end
+
+local v1 = f(<!{}!>)
+local <?<!v2!>?> = v1()
+]]
+
+TEST [[
+---@generic T
+---@type fun(x: T):fun():T
+local f
+
+local v1 = f(<!{}!>)
+local <?<!v2!>?> = v1()
+]]
+
+TEST [[
+---@generic V
+---@return fun(x: V):V
+local function f(x) end
+
+local v1 = f()
+local <?<!v2!>?> = v1(<!{}!>)
+]]
+
+TEST [[
+---@generic V
+---@param x V[]
+---@return V
+local function f(x) end
+
+---@class A
+local <!a!>
+
+---@type A[]
+local b
+
+local <?<!c!>?> = f(b)
+]]
+
+TEST [[
+---@generic V
+---@param x table<number, V>
+---@return V
+local function f(x) end
+
+---@class A
+local <!a!>
+
+---@type table<number, A>
+local b
+
+local <?<!c!>?> = f(b)
+]]
+
+TEST [[
+---@generic V
+---@param x V[]
+---@return V
+local function f(x) end
+
+---@class A
+local <!a!>
+
+---@type table<number, A>
+local b
+
+local <?<!c!>?> = f(b)
+]]
+
+TEST [[
+---@generic V
+---@param x table<number, V>
+---@return V
+local function f(x) end
+
+---@class A
+local <!a!>
+
+---@type A[]
+local b
+
+local <?<!c!>?> = f(b)
+]]
+
+TEST [[
+---@generic K
+---@param x table<K, number>
+---@return K
+local function f(x) end
+
+---@class A
+local <!a!>
+
+---@type table<A, number>
+local b
+
+local <?<!c!>?> = f(b)
+]]
+
+TEST [[
+---@generic V
+---@return fun(t: V[]):V
+local function f() end
+
+---@class A
+local <!a!>
+
+---@type A[]
+local b
+
+local f2 = f()
+
+local <?<!c!>?> = f2(b)
+]]
+
+TEST [[
+---@generic T, V
+---@param t T
+---@return fun(t: V[]):V
+---@return T
+local function f(t) end
+
+---@class A
+local <!a!>
+
+---@type A[]
+local b
+
+local f2, c = f(b)
+
+local <?<!d!>?> = f2(c)
+]]
+
+TEST [[
+---@class C
+local <!v1!>
+
+---@generic V, T
+---@param t T
+---@return fun(t: V): V
+---@return T
+local function iterator(t) end
+
+for <!v!> in iterator(<!v1!>) do
+ print(<?v?>)
+end
+]]
+
+TEST [[
+---@class C
+local <!v!>
+
+---@type C
+local <!v1!>
+
+---@generic V, T
+---@param t T
+---@return fun(t: V): V
+---@return T
+local function iterator(t) end
+
+for <!v!> in iterator(<!v1!>) do
+ print(<?v?>)
+end
+]]
+
+TEST [[
+---@class C
+local <!v!>
+
+---@type C[]
+local v1
+
+---@generic V, T
+---@param t T
+---@return fun(t: V[]): V
+---@return T
+local function iterator(t) end
+
+for <!v!> in iterator(v1) do
+ print(<?v?>)
+end
+]]
+
+TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end
---@type table<number, Foo>
local v1
-local ipairs = ipairs
+
+---@generic T: table, V
+---@param t T
+---@return fun(table: V[], i?: integer):integer, V
+---@return T
+---@return integer i
+local function ipairs(t) end
+
for i, v in ipairs(v1) do
print(v.<?bar1?>)
end
@@ -318,6 +570,35 @@ function Foo:<!bar1!>() end
---@type table<Foo, Foo>
local v1
+
+---@generic T: table, K, V
+---@param t T
+---@return fun(table: table<K, V>, index: K):K, V
+---@return T
+---@return nil
+local function pairs(t) end
+
+for k, v in pairs(v1) do
+ print(k.bar1)
+ print(v.<?bar1?>)
+end
+]]
+
+TEST [[
+---@class Foo
+local Foo = {}
+function Foo:<!bar1!>() end
+
+---@type table<Foo, Foo>
+local v1
+
+---@generic T: table, K, V
+---@param t T
+---@return fun(table: table<K, V>, index: K):K, V
+---@return T
+---@return nil
+local function pairs(t) end
+
for k, v in pairs(v1) do
print(k.<?bar1?>)
print(v.bar1)
@@ -329,6 +610,13 @@ TEST [[
local Foo = {}
function Foo:<!bar1!>() end
+---@generic T: table, V
+---@param t T
+---@return fun(table: V[], i?: integer):integer, V
+---@return T
+---@return integer i
+local function ipairs(t) end
+
---@type table<number, table<number, Foo>>
local v1
for i, v in ipairs(v1) do
diff --git a/test/diagnostics/init.lua b/test/diagnostics/init.lua
index d4bffdb5..0f5880ae 100644
--- a/test/diagnostics/init.lua
+++ b/test/diagnostics/init.lua
@@ -79,7 +79,8 @@ local <!x!>
]]
TEST [[
-local x <close> = print
+local y
+local x <close> = y
]]
TEST [[
@@ -135,11 +136,11 @@ end
)
TEST [[
+local print, _G
print(<!x!>)
print(<!log!>)
print(<!X!>)
print(<!Log!>)
-print(_VERSION)
print(<!y!>)
print(Z)
print(_G)
diff --git a/test/full/example.lua b/test/full/example.lua
index 1eb66060..8633318a 100644
--- a/test/full/example.lua
+++ b/test/full/example.lua
@@ -5,6 +5,7 @@ local diag = require 'core.diagnostics'
local config = require 'config'
local fs = require 'bee.filesystem'
local luadoc = require "parser.luadoc"
+local noder = require 'core.noder'
-- 临时
local function testIfExit(path)
@@ -19,6 +20,7 @@ local function testIfExit(path)
local parseClock = 0
local compileClock = 0
local luadocClock = 0
+ local noderClock = 0
local total
for i = 1, max do
vm = TEST(buf)
@@ -26,21 +28,26 @@ local function testIfExit(path)
luadoc(nil, vm)
local luadocPassed = os.clock() - luadocStart
local passed = os.clock() - clock
- parseClock = parseClock + vm.parseClock
+ local noderStart = os.clock()
+ noder.compileNodes(vm.ast)
+ local noderPassed = os.clock() - noderStart
+ parseClock = parseClock + vm.parseClock
compileClock = compileClock + vm.compileClock
luadocClock = luadocClock + luadocPassed
+ noderClock = noderClock + noderPassed
if passed >= 1.0 or i == max then
need = passed / i
total = i
break
end
end
- print(('基准编译测试[%s]单次耗时:%.10f(解析:%.10f, 编译:%.10f, LuaDoc: %.10f)'):format(
+ print(('基准编译测试[%s]单次耗时:%.10f(解析:%.10f, 编译:%.10f, LuaDoc: %.10f, Noder: %.10f)'):format(
path:filename():string(),
need,
parseClock / total,
compileClock / total,
- luadocClock / total
+ luadocClock / total,
+ noderClock / total
))
local clock = os.clock()
diff --git a/test/hover/init.lua b/test/hover/init.lua
index d0e50036..2c68fef5 100644
--- a/test/hover/init.lua
+++ b/test/hover/init.lua
@@ -54,39 +54,39 @@ obj:<?init?>(1, '测试')
function mt:init(a: any, b: any, c: any)
]]
-TEST [[
-local mt = {}
-mt.__index = mt
-mt.type = 'Class'
-
-function mt:init(a, b, c)
- return
-end
-
-local obj = setmetatable({}, mt)
-
-obj:<?init?>(1, '测试')
-]]
-[[
-function Class:init(a: any, b: any, c: any)
-]]
-
-TEST [[
-local mt = {}
-mt.__index = mt
-mt.__name = 'Class'
-
-function mt:init(a, b, c)
- return
-end
-
-local obj = setmetatable({}, mt)
+--TEST [[
+--local mt = {}
+--mt.__index = mt
+--mt.type = 'Class'
+--
+--function mt:init(a, b, c)
+-- return
+--end
+--
+--local obj = setmetatable({}, mt)
+--
+--obj:<?init?>(1, '测试')
+--]]
+--[[
+--function Class:init(a: any, b: any, c: any)
+--]]
-obj:<?init?>(1, '测试')
-]]
-[[
-function Class:init(a: any, b: any, c: any)
-]]
+--TEST [[
+--local mt = {}
+--mt.__index = mt
+--mt.__name = 'Class'
+--
+--function mt:init(a, b, c)
+-- return
+--end
+--
+--local obj = setmetatable({}, mt)
+--
+--obj:<?init?>(1, '测试')
+--]]
+--[[
+--function Class:init(a: any, b: any, c: any)
+--]]
TEST [[
local mt = {}
@@ -170,55 +170,55 @@ local <?obj?> = {}
]]
"local obj: {}"
-TEST [[
-local mt = {}
-mt.__name = 'class'
-
-local <?obj?> = setmetatable({}, mt)
-]]
-"local obj: class {}"
-
-TEST [[
-local mt = {}
-mt.name = 'class'
-mt.__index = mt
-
-local <?obj?> = setmetatable({}, mt)
-]]
-[[
-local obj: class {
- __index: table,
- name: string = "class",
-}
-]]
-
-TEST [[
-local mt = {}
-mt.TYPE = 'class'
-mt.__index = mt
+--TEST [[
+--local mt = {}
+--mt.__name = 'class'
+--
+--local <?obj?> = setmetatable({}, mt)
+--]]
+--"local obj: class {}"
-local <?obj?> = setmetatable({}, mt)
-]]
-[[
-local obj: class {
- TYPE: string = "class",
- __index: table,
-}
-]]
+--TEST [[
+--local mt = {}
+--mt.name = 'class'
+--mt.__index = mt
+--
+--local <?obj?> = setmetatable({}, mt)
+--]]
+--[[
+--local obj: class {
+-- __index: table,
+-- name: string = "class",
+--}
+--]]
-TEST [[
-local mt = {}
-mt.Class = 'class'
-mt.__index = mt
+--TEST [[
+--local mt = {}
+--mt.TYPE = 'class'
+--mt.__index = mt
+--
+--local <?obj?> = setmetatable({}, mt)
+--]]
+--[[
+--local obj: class {
+-- TYPE: string = "class",
+-- __index: table,
+--}
+--]]
-local <?obj?> = setmetatable({}, mt)
-]]
-[[
-local obj: class {
- Class: string = "class",
- __index: table,
-}
-]]
+--TEST [[
+--local mt = {}
+--mt.Class = 'class'
+--mt.__index = mt
+--
+--local <?obj?> = setmetatable({}, mt)
+--]]
+--[[
+--local obj: class {
+-- Class: string = "class",
+-- __index: table,
+--}
+--]]
-- TODO 支持自定义的函数库
--TEST[[
@@ -422,8 +422,6 @@ local t: {
[1]: integer = 2,
[true]: integer = 3,
[5.5]: integer = 4,
- [table]: integer = 5,
- [function]: integer = 6,
b: integer = 7,
["012"]: integer = 8,
}
@@ -438,9 +436,7 @@ local any = collectgarbage()
t[any] = any
]]
[[
-local t: {
- [number]: integer = 1,
-}
+local t: {}
]]
TEST[[
@@ -492,7 +488,7 @@ local <?self?> = setmetatable({
}, mt)
]]
[[
-local self: obj {
+local self: {
__index: table,
__name: string = "obj",
id: integer = 1,
@@ -860,15 +856,15 @@ print(<?x?>)
local x <close>: integer = 1
]]
-TEST [[
-local function <?a?>(b)
- return (b.c and a(b.c) or b)
-end
-]]
-[[
-function a(b: table)
- -> table
-]]
+--TEST [[
+--local function <?a?>(b)
+-- return (b.c and a(b.c) or b)
+--end
+--]]
+--[[
+--function a(b: table)
+-- -> table
+--]]
TEST [[
local <?t?> = {
@@ -927,7 +923,7 @@ field x: Class
]]
TEST[[
----@type Class
+---@class Class
local <?x?> = class()
]]
[[
@@ -935,7 +931,7 @@ local x: Class
]]
TEST[[
----@type Class
+---@class Class
<?x?> = class()
]]
[[
@@ -943,16 +939,10 @@ global x: Class
]]
TEST[[
-local t = {
- ---@type Class
- <?x?> = class()
-}
-]]
-[[
-field x: Class
-]]
+---@class A
+---@class B
+---@class C
-TEST[[
---@type A|B|C
local <?x?> = class()
]]
@@ -994,7 +984,7 @@ function f(t)
end
]]
[[
-local t: Class {}
+local t: Class
]]
TEST [[
@@ -1020,6 +1010,10 @@ local v: Class
]]
TEST [[
+---@class A
+---@class B
+---@class C
+
---@return A|B
---@return C
local function <?f?>()
@@ -1078,6 +1072,8 @@ function f(x: number, y: boolean)
]]
TEST [[
+---@class Class
+
---@vararg Class
local function f(...)
local _, <?x?> = ...
@@ -1089,6 +1085,21 @@ local x: Class
]]
TEST [[
+---@class Class
+
+---@vararg Class
+local function f(...)
+ local t = {...}
+ local <?v?> = t[1]
+end
+]]
+[[
+local v: Class
+]]
+
+TEST [[
+---@class Class
+
---@vararg Class
local function f(...)
local <?t?> = {...}
@@ -1164,23 +1175,29 @@ local x: table<ClassA, ClassB>
]]
--TEST [[
+-----@class ClassA
+-----@class ClassB
+--
-----@type table<ClassA, ClassB>
--local t
--for _, <?x?> in pairs(t) do
--end
--]]
--[[
---local x: *ClassB
+--local x: ClassB
--]]
--TEST [[
+-----@class ClassA
+-----@class ClassB
+--
-----@type table<ClassA, ClassB>
--local t
--for <?k?>, v in pairs(t) do
--end
--]]
--[[
---local k: *ClassA
+--local k: ClassA
--]]
TEST [[
@@ -1202,6 +1219,8 @@ local r: boolean
]]
TEST [[
+---@class void
+
---@param f fun():void
function t(<?f?>) end
]]
@@ -1492,6 +1511,12 @@ TEST [[
---@field x string
local t
+---@generic T
+---@param v T
+---@param message any
+---@return T
+local function assert(v, message) end
+
local <?v?> = assert(t)
]]
[[
diff --git a/test/references/all.lua b/test/references/all.lua
new file mode 100644
index 00000000..a9442ae1
--- /dev/null
+++ b/test/references/all.lua
@@ -0,0 +1,213 @@
+local core = require 'core.reference'
+local files = require 'files'
+
+local function catch_target(script)
+ local list = {}
+ local cur = 1
+ while true do
+ local start, finish = script:find('<[!?].-[!?]>', cur)
+ if not start then
+ break
+ end
+ list[#list+1] = { start + 2, finish - 2 }
+ cur = finish + 1
+ end
+ return list
+end
+
+local function founded(targets, results)
+ if #targets ~= #results then
+ return false
+ end
+ for _, target in ipairs(targets) do
+ for _, result in ipairs(results) do
+ if target[1] == result[1] and target[2] == result[2] then
+ goto NEXT
+ end
+ end
+ do return false end
+ ::NEXT::
+ end
+ return true
+end
+
+function TEST(script)
+ files.removeAll()
+ local expect = catch_target(script)
+ local start = script:find('<[?~]')
+ local finish = script:find('[?~]>')
+ local pos = (start + finish) // 2 + 1
+ local new_script = script:gsub('<[!?~]', ' '):gsub('[!?~]>', ' ')
+ files.setText('', new_script)
+
+ local results = core('', pos)
+ if results then
+ local positions = {}
+ for i, result in ipairs(results) do
+ positions[i] = { result.target.start, result.target.finish }
+ end
+ assert(founded(expect, positions))
+ else
+ assert(#expect == 0)
+ end
+end
+
+TEST [[
+---@class A
+local a = {}
+a.<?x?> = 1
+
+---@return A
+local function f() end
+
+local b = f()
+return b.<!x!>
+]]
+
+TEST [[
+---@class A
+local a = {}
+a.<?x?> = 1
+
+---@return table
+---@return A
+local function f() end
+
+local a, b = f()
+return a.x, b.<!x!>
+]]
+
+TEST [[
+local <?mt?> = {}
+function <!mt!>:x()
+ <!self!>:x()
+end
+]]
+
+TEST [[
+local mt = {}
+function mt:<?x?>()
+ self:<!x!>()
+end
+]]
+
+TEST [[
+---@class Dog
+local mt = {}
+function mt:<?eat?>()
+end
+
+---@class Master
+local mt2 = {}
+function mt2:init()
+ ---@type Dog
+ local foo = self:doSomething()
+ ---@type Dog
+ self.dog = getDog()
+end
+function mt2:feed()
+ self.dog:<!eat!>()
+end
+function mt2:doSomething()
+end
+]]
+
+-- 泛型的反向搜索
+TEST [[
+---@class Dog
+local <?Dog?> = {}
+
+---@generic T
+---@param type1 T
+---@return T
+function foobar(type1)
+end
+
+local <!v1!> = foobar(<!Dog!>)
+]]
+
+TEST [[
+---@class Dog
+local Dog = {}
+function Dog:<?eat?>()
+end
+
+---@generic T
+---@param type1 T
+---@return T
+function foobar(type1)
+ return {}
+end
+
+local v1 = foobar(Dog)
+v1:<!eat!>()
+]]
+
+TEST [[
+---@class Dog
+local Dog = {}
+function Dog:<?eat?>()
+end
+
+---@class Master
+local Master = {}
+
+---@generic T
+---@param type1 string
+---@param type2 T
+---@return T
+function Master:foobar(type1, type2)
+ return {}
+end
+
+local v1 = Master:foobar("", Dog)
+v1.<!eat!>()
+]]
+
+TEST [[
+---@class A
+local <?A?>
+
+---@generic T
+---@param self T
+---@return T
+function m.f(self) end
+
+local <!b!> = m.f(<!A!>)
+]]
+
+TEST [[
+---@class A
+local <?A?>
+
+---@generic T
+---@param self T
+---@return T
+function m:f() end
+
+local <!b!> = m.f(<!A!>)
+]]
+
+TEST [[
+---@class A
+local <?A?>
+
+---@generic T
+---@param self T
+---@return T
+function <!A!>.f(self) end
+
+local <!b!> = <!A!>:f()
+]]
+
+TEST [[
+---@class A
+local <?A?>
+
+---@generic T
+---@param self T
+---@return T
+function <!A!>:f() end
+
+local <!b!> = <!A!>:f()
+]]
diff --git a/test/references/init.lua b/test/references/init.lua
index c4e5018a..e90cb2a8 100644
--- a/test/references/init.lua
+++ b/test/references/init.lua
@@ -1,4 +1,4 @@
-local core = require 'core.reference'
+local core = require 'core.reference'
local files = require 'files'
local function catch_target(script)
@@ -33,7 +33,7 @@ end
function TEST(script)
files.removeAll()
- local target = catch_target(script)
+ local expect = catch_target(script)
local start = script:find('<[?~]')
local finish = script:find('[?~]>')
local pos = (start + finish) // 2 + 1
@@ -46,9 +46,9 @@ function TEST(script)
for i, result in ipairs(results) do
positions[i] = { result.target.start, result.target.finish }
end
- assert(founded(target, positions))
+ assert(founded(expect, positions))
else
- assert(#target == 0)
+ assert(#expect == 0)
end
end
@@ -96,6 +96,16 @@ local <?a?> = 1
]]
TEST [[
+local <!a!>
+local <?b?> = <!a!>
+]]
+
+TEST [[
+local <?a?>
+local <!b!> = <!a!>
+]]
+
+TEST [[
local t = {
<!a!> = 1
}
@@ -166,7 +176,7 @@ local y = f()()
TEST [[
local t = {}
t.<?x?> = 1
-t[a.b.x] = 1
+t[<!a.b.c!>] = 1
]]
TEST [[
@@ -208,13 +218,6 @@ end
]]
TEST [[
-local <?mt?> = {}
-function <!mt!>:x()
- <!self!>:x()
-end
-]]
-
-TEST [[
local mt = {}
function mt:<!x!>()
self:<?x?>()
@@ -222,13 +225,6 @@ end
]]
TEST [[
-local mt = {}
-function mt:<?x?>()
- self:<!x!>()
-end
-]]
-
-TEST [[
a.<!b!>.c = 1
print(a.<?b?>.c)
]]
@@ -252,7 +248,7 @@ a.<!t!> = <?f?>
]]
TEST [[
-<!t!>.f = <?t?>
+<!t!>.<!f!> = <?t?>
]]
TEST [[
@@ -302,135 +298,3 @@ TEST [[
---@return <?xxx?>
function f() end
]]
-
-TEST [[
----@class Dog
-local mt = {}
-function mt:<?eat?>()
-end
-
----@class Master
-local mt2 = {}
-function mt2:init()
- ---@type Dog
- local foo = self:doSomething()
- ---@type Dog
- self.dog = getDog()
-end
-function mt2:feed()
- self.dog:<!eat!>()
-end
-function mt2:doSomething()
-end
-]]
-
-TEST [[
----@class A
-local a = {}
-a.<?x?> = 1
-
----@return A
-local function f() end
-
-local b = f()
-return b.<!x!>
-]]
-
-TEST [[
----@class A
-local a = {}
-a.<?x?> = 1
-
----@return table
----@return A
-local function f() end
-
-local a, b = f()
-return a.x, b.<!x!>
-]]
-
-TEST [[
----@class Dog
-local Dog = {}
-function Dog:<?eat?>()
-end
-
----@generic T
----@param type1 T
----@return T
-function foobar(type1)
- return {}
-end
-
-local v1 = foobar(Dog)
-v1:<!eat!>()
-]]
-
-TEST [[
----@class Dog
-local Dog = {}
-function Dog:<?eat?>()
-end
-
----@class Master
-local Master = {}
-
----@generic T
----@param type1 string
----@param type2 T
----@return T
-function Master:foobar(type1, type2)
- return {}
-end
-
-local v1 = Master:foobar("", Dog)
-v1.<!eat!>()
-]]
-
-TEST [[
----@class A
-local <?A?>
-
----@generic T
----@param self T
----@return T
-function m.f(self) end
-
-local <!b!> = m.f(<!A!>)
-]]
-
-TEST [[
----@class A
-local <?A?>
-
----@generic T
----@param self T
----@return T
-function m:f() end
-
-local <!b!> = m.f(<!A!>)
-]]
-
-TEST [[
----@class A
-local <?A?>
-
----@generic T
----@param self T
----@return T
-function <!A!>.f(self) end
-
-local <!b!> = <!A!>:f()
-]]
-
-TEST [[
----@class A
-local <?A?>
-
----@generic T
----@param self T
----@return T
-function <!A!>:f() end
-
-local <!b!> = <!A!>:f()
-]]
diff --git a/test/rename/init.lua b/test/rename/init.lua
index 88f83269..4b10756e 100644
--- a/test/rename/init.lua
+++ b/test/rename/init.lua
@@ -18,7 +18,7 @@ end
function TEST(oldName, newName)
return function (oldScript)
- return function (newScript)
+ return function (expectScript)
files.removeAll()
files.setText('', oldScript)
local pos = oldScript:find('[^%w_]'..oldName..'[^%w_]')
@@ -29,7 +29,7 @@ function TEST(oldName, newName)
if positions then
script = replace(script, positions)
end
- assert(script == newScript)
+ assert(script == expectScript)
end
end
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 02355a94..4c8817f2 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1,7 +1,7 @@
local files = require 'files'
-local config = require 'config'
local vm = require 'vm'
-local guide = require 'core.guide'
+local guide = require 'parser.guide'
+local infer = require 'core.infer'
rawset(_G, 'TEST', true)
@@ -30,7 +30,10 @@ function TEST(wanted)
files.setText('', newScript)
local source = getSource(pos)
assert(source)
- local result = vm.getInferType(source, 0)
+ local result = infer.searchAndViewInfers(source)
+ if wanted ~= result then
+ infer.searchAndViewInfers(source)
+ end
assert(wanted == result)
end
end
@@ -140,22 +143,44 @@ TEST 'number' [[
]]
TEST 'tablelib' [[
+---@class tablelib
+table = {}
+
<?table?>()
]]
TEST 'string' [[
+_VERSION = 'Lua 5.4'
+
<?x?> = _VERSION
]]
TEST 'function' [[
+---@class stringlib
+local string
+
+string.sub = function () end
+
return ('x').<?sub?>
]]
TEST 'function' [[
+---@class stringlib
+local string
+
+string.sub = function () end
+
<?x?> = ('x').sub
]]
TEST 'function' [[
+---@class stringlib
+local string
+
+string.sub = function () end
+
+_VERSION = 'Lua 5.4'
+
<?x?> = _VERSION.sub
]]
@@ -200,8 +225,11 @@ end
_, <?y?> = pcall(x)
]]
-TEST 'oslib' [[
-local <?os?> = require 'os'
+TEST 'integer' [[
+local function x()
+ return 1
+end
+_, <?y?> = xpcall(x)
]]
TEST 'string|table' [[
@@ -218,9 +246,14 @@ local function f(<?a?>, b)
end
]]
-TEST 'string' [[
+TEST 'A' [[
+---@class A
+
+---@return A
+local function f2() end
+
local function f()
- return string.sub()
+ return f2()
end
local <?x?> = f()
@@ -238,14 +271,6 @@ local <?x?> = f()
--setmetatable(<?b?>)
--]]
-TEST 'function' [[
-string.<?sub?>()
-]]
-
-TEST 'function' [[
-(''):<?sub?>()
-]]
-
-- 不根据对方函数内的使用情况来推测
TEST 'any' [[
local function x(a)
@@ -270,12 +295,6 @@ end
local _, _, _, <?b?>, _ = x(nil, true, 1, 'yy')
]]
--- TODO 暂不支持这些特殊情况,之后用其他语法定义
---TEST 'integer' [[
---for <?i?> in ipairs(t) do
---end
---]]
-
TEST 'any' [[
local <?x?> = next()
]]
@@ -297,16 +316,23 @@ local <?x?>
]]
TEST 'string' [[
+---@class string
+
---@type string
local <?x?>
]]
TEST 'string[]' [[
+---@class string
+
---@type string[]
local <?x?>
]]
TEST 'string|table' [[
+---@class string
+---@class table
+
---@type string | table
local <?x?>
]]
@@ -322,6 +348,9 @@ local <?x?>
]]
TEST 'table<string, number>' [[
+---@class string
+---@class number
+
---@type table<string, number>
local <?x?>
]]
@@ -331,12 +360,16 @@ self.<?t?>[#self.t+1] = {}
]]
TEST 'string' [[
+---@class string
+
---@type string[]
local x
local <?y?> = x[1]
]]
TEST 'string' [[
+---@class string
+
---@return string[]
local function f() end
local x = f()
@@ -387,6 +420,15 @@ print(t.<?a?>)
]]
TEST 'integer' [[
+---@class integer
+
+---@generic T: table, V
+---@param t T
+---@return fun(table: V[], i?: integer):integer, V
+---@return T
+---@return integer i
+local function ipairs() end
+
for <?i?> in ipairs() do
end
]]
@@ -404,6 +446,8 @@ local k, v = next(<?t?>)
]]
TEST 'string' [[
+---@class string
+
---@generic K, V
---@param t table<K, V>
---@return K
@@ -416,6 +460,8 @@ local <?k?>, v = next(t)
]]
TEST 'boolean' [[
+---@class boolean
+
---@generic K, V
---@param t table<K, V>
---@return K
@@ -436,6 +482,8 @@ local <?r?> = f(true)
]]
TEST 'string' [[
+---@class string
+
---@generic K, V
---@type fun(arg: table<K, V>):K, V
local f
@@ -447,6 +495,8 @@ local <?k?>, v = f(t)
]]
TEST 'boolean' [[
+---@class boolean
+
---@generic K, V
---@type fun(arg: table<K, V>):K, V
local f
@@ -472,6 +522,8 @@ local <?r?> = f()
]]
TEST 'string' [[
+---@class string
+
---@generic K, V
---@return fun(arg: table<K, V>):K, V
local function f() end
@@ -485,6 +537,8 @@ local <?k?>, v = f2(t)
]]
TEST 'string' [[
+---@class string
+
---@generic T: table, K, V
---@param t T
---@return fun(table: table<K, V>, index: K):K, V
@@ -502,11 +556,12 @@ end
]]
TEST 'boolean' [[
+---@class boolean
+
---@generic T: table, K, V
---@param t T
----@return fun(table: table<K, V>, index: K):K, V
+---@return fun(table: table<K, V>, index?: K):K, V
---@return T
----@return nil
local function pairs(t) end
local f = pairs(t)
@@ -519,11 +574,12 @@ end
]]
TEST 'string' [[
+---@class string
+
---@generic T: table, K, V
---@param t T
----@return fun(table: table<K, V>, index: K):K, V
+---@return fun(table: table<K, V>, index?: K):K, V
---@return T
----@return nil
local function pairs(t) end
---@type table<string, boolean>
@@ -534,6 +590,8 @@ end
]]
TEST 'boolean' [[
+---@class boolean
+
---@generic T: table, K, V
---@param t T
---@return fun(table: table<K, V>, index: K):K, V
@@ -549,6 +607,8 @@ end
]]
TEST 'boolean' [[
+---@class boolean
+
---@generic T: table, V
---@param t T
---@return fun(table: V[], i?: integer):integer, V
@@ -564,6 +624,8 @@ end
]]
TEST 'boolean' [[
+---@class boolean
+
---@generic T: table, K, V
---@param t T
---@return fun(table: table<K, V>, index: K):K, V
@@ -579,11 +641,12 @@ end
]]
TEST 'integer' [[
+---@class integer
+
---@generic T: table, K, V
---@param t T
----@return fun(table: table<K, V>, index: K):K, V
+---@return fun(table: table<K, V>, index?: K):K, V
---@return T
----@return nil
local function pairs(t) end
---@type boolean[]