diff options
95 files changed, 4719 insertions, 1795 deletions
diff --git a/meta/template/basic.lua b/meta/template/basic.lua index 785819a4..7a42ab74 100644 --- a/meta/template/basic.lua +++ b/meta/template/basic.lua @@ -135,9 +135,8 @@ function next(table, index) end ---#DES 'paris' ---@generic T: table, K, V ---@param t T ----@return fun(table: table<K, V>, index: K):K, V +---@return fun(table: table<K, V>, index?: K):K, V ---@return T ----@return nil function pairs(t) end ---#DES 'pcall' diff --git a/meta/template/builtin.lua b/meta/template/builtin.lua index 2b547d1d..45ac24af 100644 --- a/meta/template/builtin.lua +++ b/meta/template/builtin.lua @@ -1,6 +1,7 @@ ---@meta ----@class any +---@class unknown +---@class any: unknown ---@class nil: any ---@class boolean: any ---@class number: any diff --git a/meta/template/package.lua b/meta/template/package.lua index ae7def31..8c18e10b 100644 --- a/meta/template/package.lua +++ b/meta/template/package.lua @@ -3,13 +3,13 @@ ---#if VERSION >=5.4 then ---#DES 'require>5.4' ---@param modname string ----@return any ----@return any loaderdata +---@return unknown +---@return unknown loaderdata function require(modname) end ---#else ---#DES 'require<5.3' ---@param modname string ----@return any +---@return unknown function require(modname) end ---#end diff --git a/script/core/code-action.lua b/script/core/code-action.lua index bae3df81..9ed000e9 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -1,10 +1,9 @@ -local files = require 'files' -local lang = require 'language' -local define = require 'proto.define' -local guide = require 'core.guide' -local util = require 'utility' -local sp = require 'bee.subprocess' -local vm = require 'vm' +local files = require 'files' +local lang = require 'language' +local util = require 'utility' +local sp = require 'bee.subprocess' +local vm = require 'vm' +local guide = require "parser.guide" local function checkDisableByLuaDocExits(uri, row, mode, code) local lines = files.getLines(uri) diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua index 527af8d5..ba1ee8eb 100644 --- a/script/core/command/removeSpace.lua +++ b/script/core/command/removeSpace.lua @@ -1,11 +1,10 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local proto = require 'proto' +local lang = require 'language' local function isInString(ast, offset) - return guide.eachSourceContain(ast.ast, offset, function (source) + return searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'string' then return true end @@ -23,10 +22,10 @@ return function (data) local textEdit = {} for i = 1, #lines do - local line = guide.lineContent(lines, text, i, true) + local line = searcher.lineContent(lines, text, i, true) local pos = line:find '[ \t]+$' if pos then - local start, finish = guide.lineRange(lines, i, true) + local start, finish = searcher.lineRange(lines, i, true) start = start + pos - 1 if isInString(ast, start) then goto NEXT_LINE diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua index 995a2109..dc23e7af 100644 --- a/script/core/command/solve.lua +++ b/script/core/command/solve.lua @@ -1,8 +1,7 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' local opMap = { ['+'] = true, diff --git a/script/core/completion.lua b/script/core/completion.lua index e3980eca..acaaa276 100644 --- a/script/core/completion.lua +++ b/script/core/completion.lua @@ -1,19 +1,14 @@ local define = require 'proto.define' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local matchKey = require 'core.matchkey' local vm = require 'vm' -local getLabel = require 'core.hover.label' local getName = require 'core.hover.name' local getArg = require 'core.hover.arg' -local getReturn = require 'core.hover.return' -local getDesc = require 'core.hover.description' local getHover = require 'core.hover' local config = require 'config' local util = require 'utility' local markdown = require 'provider.markdown' -local findSource = require 'core.find-source' -local await = require 'await' local parser = require 'parser' local keyWordMap = require 'core.keyword' local workspace = require 'workspace' @@ -21,6 +16,8 @@ local furi = require 'file-uri' local rpath = require 'workspace.require-path' local lang = require 'language' local lookBackward = require 'core.look-backward' +local guide = require 'parser.guide' +local infer = require 'core.infer' local DiagnosticModes = { 'disable-next-line', @@ -135,8 +132,8 @@ local function buildFunctionSnip(source, value, oop) end local function buildDetail(source) - local types = vm.getInferType(source, 0) - local literals = vm.getInferLiteral(source, 0) + local types = infer.searchAndViewInfers(source) + local literals = infer.searchAndViewLiterals(source) if literals then return types .. ' = ' .. literals else @@ -149,9 +146,9 @@ local function getSnip(source) if context <= 0 then return nil end - local defs = vm.getRefs(source, 0) + local defs = vm.getRefs(source) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def ~= source and def.type == 'function' then local uri = guide.getUri(def) local text = files.getText(uri) @@ -273,8 +270,8 @@ local function checkLocal(ast, word, offset, results) if not matchKey(word, name) then goto CONTINUE end - if vm.hasType(source, 'function') then - for _, def in ipairs(vm.getDefs(source, 0)) do + if infer.hasType(source, 'function') then + for _, def in ipairs(vm.getDefs(source)) do if def.type == 'function' or def.type == 'doc.type.function' then local funcLabel = name .. getParams(def, false) @@ -417,7 +414,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset) end local function checkFieldThen(name, src, word, start, offset, parent, oop, results) - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src local kind = define.CompletionItemKind.Field if value.type == 'function' or value.type == 'doc.type.function' then @@ -492,7 +489,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res end local funcLabel if config.config.completion.showParams then - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src if value.type == 'function' or value.type == 'doc.type.function' then funcLabel = name .. getParams(value, oop) @@ -539,16 +536,16 @@ end local function checkGlobal(ast, word, start, offset, parent, oop, results) local locals = guide.getVisibleLocals(ast.ast, offset) - local refs = vm.getGlobalSets '*' - checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global') + local globals = vm.getGlobalSets '*' + checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results, locals, 'global') end local function checkField(ast, word, start, offset, parent, oop, results) if parent.tag == '_ENV' or parent.special == '_G' then - local refs = vm.getGlobalSets '*' - checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) + local globals = vm.getGlobalSets '*' + checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results) else - local refs = vm.getFields(parent, 0) + local refs = vm.getRefs(parent, '*') checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) end end @@ -1043,14 +1040,14 @@ local function mergeEnums(a, b, source) end end -local function checkTypingEnum(ast, text, offset, infers, str, results) +local function checkTypingEnum(ast, text, offset, defs, str, results) local enums = {} - for _, infer in ipairs(infers) do - if infer.source.type == 'doc.type.enum' - or infer.source.type == 'doc.resume' then + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then enums[#enums+1] = { - label = infer.source[1], - description = infer.source.comment and infer.source.comment.text, + label = def[1], + description = def.comment and def.comment.text, kind = define.CompletionItemKind.EnumMember, } end @@ -1074,8 +1071,8 @@ local function checkEqualEnumLeft(ast, text, offset, source, results) return src end end) - local infers = vm.getInfers(source, 0) - checkTypingEnum(ast, text, offset, infers, str, results) + local defs = vm.getDefs(source) + checkTypingEnum(ast, text, offset, defs, str, results) end local function checkEqualEnum(ast, text, offset, results) @@ -1247,7 +1244,30 @@ function (%s)\ end"):format(table.concat(args, ', ')) end -local function getCallEnums(source, index) +local function pushCallEnumsAndFuncs(defs) + local results = {} + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then + results[#results+1] = { + label = def[1], + description = def.comment, + kind = define.CompletionItemKind.EnumMember, + } + end + if def.type == 'doc.type.function' then + results[#results+1] = { + label = infer.viewDocFunction(def), + description = def.comment, + kind = define.CompletionItemKind.Function, + insertText = buildInsertDocFunction(def), + } + end + end + return results +end + +local function getCallEnumsAndFuncs(source, index) if source.type == 'function' and source.bindDocs then if not source.args then return @@ -1266,37 +1286,10 @@ local function getCallEnums(source, index) for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.param' and doc.param[1] == arg[1] then - local enums = {} - for _, enum in ipairs(vm.getDocEnums(doc.extends) or {}) do - enums[#enums+1] = { - label = enum[1], - description = enum.comment, - kind = define.CompletionItemKind.EnumMember, - } - end - for _, unit in ipairs(vm.getDocTypeUnits(doc.extends) or {}) do - if unit.type == 'doc.type.function' then - local text = files.getText(guide.getUri(unit)) - enums[#enums+1] = { - label = text:sub(unit.start, unit.finish), - description = doc.comment, - kind = define.CompletionItemKind.Function, - insertText = buildInsertDocFunction(unit), - } - end - end - return enums + return pushCallEnumsAndFuncs(vm.getDefs(doc.extends)) elseif doc.type == 'doc.vararg' and arg.type == '...' then - local enums = {} - for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do - enums[#enums+1] = { - label = enum[1], - description = enum.comment, - kind = define.CompletionItemKind.EnumMember, - } - end - return enums + return pushCallEnumsAndFuncs(vm.getDefs(doc.vararg)) end end end @@ -1403,12 +1396,12 @@ local function checkTableLiteralFieldByCall(ast, text, offset, call, defs, index return end for _, def in ipairs(defs) do - local func = guide.getObjectValue(def) or def + local func = searcher.getObjectValue(def) or def local param = getFuncParamByCallIndex(func, index) if not param then goto CONTINUE end - local defs = vm.getDefFields(param, 0) + local defs = vm.getDefs(param, '*') for _, field in ipairs(defs) do local name = guide.getKeyName(field) if name and not mark[name] then @@ -1431,10 +1424,10 @@ local function tryCallArg(ast, text, offset, results) if arg and arg.type == 'function' then return end - local defs = vm.getDefs(call.node, 0) + local defs = vm.getDefs(call.node) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def - local enums = getCallEnums(def, argIndex) + def = searcher.getObjectValue(def) or def + local enums = getCallEnumsAndFuncs(def, argIndex) if enums then mergeEnums(myResults, enums, arg) end @@ -1461,7 +1454,8 @@ local function tryTable(ast, text, offset, results) if source.type ~= 'table' then tbl = source.parent end - local defs = vm.getDefFields(tbl, 0) + local parent = tbl.parent + local defs = vm.getDefs(parent, '*') for _, field in ipairs(defs) do local name = guide.getKeyName(field) if name and not mark[name] then @@ -1560,7 +1554,7 @@ end local function tryLuaDocBySource(ast, offset, source, results) if source.type == 'doc.extends.name' then if source.parent.type == 'doc.class' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if doc.type == 'doc.class.name' and doc.parent ~= source.parent and matchKey(source[1], doc[1]) then @@ -1578,7 +1572,7 @@ local function tryLuaDocBySource(ast, offset, source, results) end return true elseif source.type == 'doc.type.name' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') and doc.parent ~= source.parent and matchKey(source[1], doc[1]) then @@ -1652,7 +1646,7 @@ end local function tryLuaDocByErr(ast, offset, err, docState, results) if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if doc.type == 'doc.class.name' and doc.parent ~= docState then results[#results+1] = { @@ -1662,7 +1656,7 @@ local function tryLuaDocByErr(ast, offset, err, docState, results) end end elseif err.type == 'LUADOC_MISS_TYPE_NAME' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then results[#results+1] = { label = doc[1], @@ -1735,14 +1729,14 @@ local function buildLuaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = vm.getInferType(arg) + args[#args+1] = infer.searchAndViewInfers(arg) end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = vm.getInferType(rtns[n]) + returns[n] = infer.searchAndViewInfers(rtns[n]) end end end diff --git a/script/core/definition.lua b/script/core/definition.lua index b26bb922..3ced05a2 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -1,14 +1,15 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local workspace = require 'workspace' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' +local guide = require 'parser.guide' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -20,7 +21,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else @@ -127,11 +128,11 @@ return function (uri, offset) end end - local defs = vm.getDefs(source, 0) + local defs = vm.getDefs(source) local values = {} for _, src in ipairs(defs) do - local value = guide.getObjectValue(src) - if value and value ~= src then + local value = searcher.getObjectValue(src) + if value and value ~= src and guide.isLiteral(value) then values[value] = true end end @@ -148,9 +149,6 @@ return function (uri, offset) goto CONTINUE end src = src.field or src.method or src.index or src - if src.type == 'table' and src.parent.type ~= 'return' then - goto CONTINUE - end if src.type == 'doc.class.name' and source.type ~= 'doc.type.name' and source.type ~= 'doc.extends.name' diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua index 19bb4f97..37815fb5 100644 --- a/script/core/diagnostics/ambiguity-1.lua +++ b/script/core/diagnostics/ambiguity-1.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local opMap = { diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua index 702cd904..d2e26378 100644 --- a/script/core/diagnostics/circle-doc-class.lua +++ b/script/core/diagnostics/circle-doc-class.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' +local guide = require 'parser.guide' return function (uri, callback) local state = files.getAst(uri) @@ -40,7 +40,7 @@ return function (uri, callback) end if not mark[newName] then mark[newName] = true - local docs = vm.getDocTypes(newName) + local docs = vm.getDocDefines(newName) for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' then list[#list+1] = otherDoc.parent diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index d1983c42..7828efe9 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -1,7 +1,6 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' return function (uri, callback) local state = files.getAst(uri) diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua index f23755ea..f300a61a 100644 --- a/script/core/diagnostics/code-after-break.lua +++ b/script/core/diagnostics/code-after-break.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' return function (uri, callback) local state = files.getAst(uri) diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua index 65099af8..ee245781 100644 --- a/script/core/diagnostics/count-down-loop.lua +++ b/script/core/diagnostics/count-down-loop.lua @@ -1,6 +1,6 @@ -local files = require "files" -local guide = require "core.guide" -local lang = require 'language' +local files = require "files" +local guide = require "parser.guide" +local lang = require 'language' return function (uri, callback) local state = files.getAst(uri) diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua index 60d60946..a6f8a47e 100644 --- a/script/core/diagnostics/deprecated.lua +++ b/script/core/diagnostics/deprecated.lua @@ -1,10 +1,10 @@ -local files = require 'files' -local vm = require 'vm' -local lang = require 'language' -local guide = require 'core.guide' -local config = require 'config' -local define = require 'proto.define' -local await = require 'await' +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local guide = require 'parser.guide' +local config = require 'config' +local define = require 'proto.define' +local await = require 'await' return function (uri, callback) local ast = files.getAst(uri) @@ -20,7 +20,7 @@ return function (uri, callback) return end if src.type == 'getglobal' then - local key = guide.getKeyName(src) + local key = src[1] if not key then return end @@ -34,11 +34,11 @@ return function (uri, callback) await.delay() - if not vm.isDeprecated(src, 0) then + if not vm.isDeprecated(src, true) then return end - local defs = vm.getDefs(src, 0) + local defs = vm.getDefs(src) local validVersions for _, def in ipairs(defs) do if def.bindDocs then diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua index 8c6696a9..daecb836 100644 --- a/script/core/diagnostics/duplicate-doc-class.lua +++ b/script/core/diagnostics/duplicate-doc-class.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' +local guide = require 'parser.guide' return function (uri, callback) local state = files.getAst(uri) @@ -20,7 +20,7 @@ return function (uri, callback) or doc.type == 'doc.alias' then local name = guide.getKeyName(doc) if not cache[name] then - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) cache[name] = {} for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' @@ -28,7 +28,7 @@ return function (uri, callback) cache[name][#cache[name]+1] = { start = otherDoc.start, finish = otherDoc.finish, - uri = guide.getUri(otherDoc), + uri = searcher.getUri(otherDoc), } end end diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua index 5e63d39e..d1ba9261 100644 --- a/script/core/diagnostics/duplicate-index.lua +++ b/script/core/diagnostics/duplicate-index.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' return function (uri, callback) local ast = files.getAst(uri) diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua index c1e2285a..e1883fe5 100644 --- a/script/core/diagnostics/duplicate-set-field.lua +++ b/script/core/diagnostics/duplicate-set-field.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local define = require 'proto.define' +local guide = require "parser.guide" return function (uri, callback) local ast = files.getAst(uri) @@ -30,7 +30,7 @@ return function (uri, callback) if not name then goto CONTINUE end - local value = guide.getObjectValue(nxt) + local value = searcher.getObjectValue(nxt) if not value or value.type ~= 'function' then goto CONTINUE end diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua index 690a4ca2..2024f4e3 100644 --- a/script/core/diagnostics/empty-block.lua +++ b/script/core/diagnostics/empty-block.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index de23bc76..9a0d4f35 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' -- TODO: 检查路径是否可达 diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua index a2b831f7..1d1ab9af 100644 --- a/script/core/diagnostics/init.lua +++ b/script/core/diagnostics/init.lua @@ -62,14 +62,11 @@ local function check(uri, name, results) end return function (uri, response) - local vm = require 'vm' local ast = files.getAst(uri) if not ast then return nil end - local isOpen = files.isOpen(uri) - for _, name in ipairs(diagList) do await.delay() local results = {} diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua index 9c094701..8c7ae793 100644 --- a/script/core/diagnostics/lowercase-global.lua +++ b/script/core/diagnostics/lowercase-global.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local config = require 'config' local vm = require 'vm' diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua index 0727c2fd..75681cbc 100644 --- a/script/core/diagnostics/newfield-call.lua +++ b/script/core/diagnostics/newfield-call.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua index 807f76a2..159a60c9 100644 --- a/script/core/diagnostics/newline-call.lua +++ b/script/core/diagnostics/newline-call.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua index ffaab821..23af570a 100644 --- a/script/core/diagnostics/no-implicit-any.lua +++ b/script/core/diagnostics/no-implicit-any.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -10,7 +10,7 @@ return function (uri, callback) return end - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) if source.type ~= 'local' and source.type ~= 'setlocal' and source.type ~= 'setglobal' diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua index 857d80d2..48093417 100644 --- a/script/core/diagnostics/redefined-local.lua +++ b/script/core/diagnostics/redefined-local.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) @@ -13,6 +13,9 @@ return function (uri, callback) or name == ast.ENVMode then return end + if source.tag == 'self' then + return + end local exist = guide.getLocal(source, name, source.start-1) if exist then callback { diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index c5bcd5a5..eca7fc91 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' local define = require 'proto.define' @@ -84,7 +84,7 @@ return function (uri, callback) local funcArgs = cache[func] if funcArgs == nil then funcArgs = getFuncArgs(func) or false - local refs = vm.getRefs(func, 0) + local refs = vm.getRefs(func) for _, ref in ipairs(refs) do cache[ref] = funcArgs end diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua index 0a4b1d57..e54a6e60 100644 --- a/script/core/diagnostics/trailing-space.lua +++ b/script/core/diagnostics/trailing-space.lua @@ -1,6 +1,6 @@ local files = require 'files' local lang = require 'language' -local guide = require 'core.guide' +local guide = require 'parser.guide' local function isInString(ast, offset) local result = false diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua index b2b2800c..35aebb45 100644 --- a/script/core/diagnostics/unbalanced-assignments.lua +++ b/script/core/diagnostics/unbalanced-assignments.lua @@ -1,7 +1,7 @@ local files = require 'files' local define = require 'proto.define' local lang = require 'language' -local guide = require 'core.guide' +local guide = require 'parser.guide' return function (uri, callback, code) local ast = files.getAst(uri) diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua index a91cfa7f..d79f7ea4 100644 --- a/script/core/diagnostics/undefined-doc-class.lua +++ b/script/core/diagnostics/undefined-doc-class.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -25,7 +25,7 @@ return function (uri, callback) end for _, ext in ipairs(doc.extends) do local name = ext[1] - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) if cache[name] == nil then cache[name] = false for _, otherDoc in ipairs(docs) do diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index d8a4363b..871f16e1 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -22,7 +22,7 @@ return function (uri, callback) if classCache[name] ~= nil then return classCache[name] end - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' or otherDoc.type == 'doc.alias.name' then diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua index 0bf371e5..4a97947d 100644 --- a/script/core/diagnostics/undefined-doc-param.lua +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua index 89efb8c7..c97c3fe8 100644 --- a/script/core/diagnostics/undefined-env-child.lua +++ b/script/core/diagnostics/undefined-env-child.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local vm = require 'vm' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' +local lang = require 'language' return function (uri, callback) local ast = files.getAst(uri) @@ -13,7 +13,7 @@ return function (uri, callback) if source.node.tag == '_ENV' then return end - local defs = guide.requestDefinition(source) + local defs = searcher.requestDefinition(source) if #defs > 0 then return end diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index b10c9ab0..2d357d5b 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -2,9 +2,16 @@ local files = require 'files' local vm = require 'vm' local lang = require 'language' local config = require 'config' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' +local SkipCheckClass = { + ['unknown'] = true, + ['any'] = true, + ['table'] = true, + ['nil'] = true, +} + return function (uri, callback) local ast = files.getAst(uri) if not ast then @@ -18,7 +25,7 @@ return function (uri, callback) if cache[src] == nil then tracy.ZoneBeginN('undefined-field getInfers') infers = vm.getInfers(src, 0) or false - local refs = vm.getRefs(src, 0) + local refs = vm.getRefs(src) for _, ref in ipairs(refs) do cache[ref] = infers end @@ -47,7 +54,7 @@ return function (uri, callback) elseif inferSource.type == 'doc.class.name' then addTo(allDocClass, inferSource.parent) elseif inferSource.type == 'doc.type.name' then - local docTypes = vm.getDocTypes(inferSource[1]) + local docTypes = vm.getDocDefines(inferSource[1]) for _, docType in ipairs(docTypes) do if docType.type == 'doc.class.name' then addTo(allDocClass, docType.parent) @@ -65,7 +72,7 @@ return function (uri, callback) local empty = true for _, docClass in ipairs(allDocClass) do tracy.ZoneBeginN('undefined-field getDefFields') - local refs = vm.getDefFields(docClass) + local refs = vm.getDefs(docClass, '*') tracy.ZoneEnd() for _, ref in ipairs(refs) do @@ -87,35 +94,37 @@ return function (uri, callback) end local function checkUndefinedField(src) - local fieldName = guide.getKeyName(src) - - local allDocClass = getAllDocClassFromInfer(src.node) - if (not allDocClass) or (#allDocClass == 0) then - return - end - - local fields = getAllFieldsFromAllDocClass(allDocClass) - - -- 没找到任何 field,跳过检查 - if not fields then + if #vm.getDefs(src) > 0 then return end - - if not fields[fieldName] then - local message = lang.script('DIAG_UNDEF_FIELD', fieldName) - if src.type == 'getfield' and src.field then - callback { - start = src.field.start, - finish = src.field.finish, - message = message, - } - elseif src.type == 'getmethod' and src.method then - callback { - start = src.method.start, - finish = src.method.finish, - message = message, - } + local node = src.node + if node then + local defs = vm.getDefs(node) + local ok + for _, def in ipairs(defs) do + if def.type == 'doc.class.name' + and not SkipCheckClass[def[1]] then + ok = true + break + end end + if not ok then + return + end + end + local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src)) + if src.type == 'getfield' and src.field then + callback { + start = src.field.start, + finish = src.field.finish, + message = message, + } + elseif src.type == 'getmethod' and src.method then + callback { + start = src.method.start, + finish = src.method.finish, + message = message, + } end end guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField); diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index 161d8856..825b14f1 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -1,9 +1,8 @@ -local files = require 'files' -local vm = require 'vm' -local lang = require 'language' -local config = require 'config' -local guide = require 'core.guide' -local define = require 'proto.define' +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local config = require 'config' +local guide = require 'parser.guide' local requireLike = { ['include'] = true, @@ -20,7 +19,7 @@ return function (uri, callback) -- 遍历全局变量,检查所有没有 set 模式的全局变量 guide.eachSourceType(ast.ast, 'getglobal', function (src) - local key = guide.getKeyName(src) + local key = src[1] if not key then return end @@ -30,7 +29,11 @@ return function (uri, callback) if config.config.runtime.special[key] then return end - if #vm.getGlobalSets(key) == 0 then + local node = src.node + if node.tag ~= '_ENV' then + return + end + if #vm.getDefs(src) == 0 then local message = lang.script('DIAG_UNDEF_GLOBAL', key) if requireLike[key:lower()] then message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key)) diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua index b6f92e60..41c239f9 100644 --- a/script/core/diagnostics/unused-function.lua +++ b/script/core/diagnostics/unused-function.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local vm = require 'vm' local define = require 'proto.define' local lang = require 'language' diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua index e2d5e49a..e6d998ba 100644 --- a/script/core/diagnostics/unused-label.lua +++ b/script/core/diagnostics/unused-label.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua index fde90cb8..1a77a45f 100644 --- a/script/core/diagnostics/unused-local.lua +++ b/script/core/diagnostics/unused-local.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' @@ -87,6 +87,9 @@ return function (uri, callback) or name == ast.ENVMode then return end + if source.tag == 'self' then + return + end if isToBeClosed(source) then return end diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua index ec0a05fb..74cc08e7 100644 --- a/script/core/diagnostics/unused-vararg.lua +++ b/script/core/diagnostics/unused-vararg.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua index cc87e3ca..e36ba29b 100644 --- a/script/core/document-symbol.lua +++ b/script/core/document-symbol.lua @@ -1,8 +1,8 @@ -local await = require 'await' -local files = require 'files' -local guide = require 'core.guide' -local define = require 'proto.define' -local util = require 'utility' +local await = require 'await' +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local util = require 'utility' local function buildName(source, text) if source.type == 'setmethod' diff --git a/script/core/find-source.lua b/script/core/find-source.lua index b36306b6..edbb1e2c 100644 --- a/script/core/find-source.lua +++ b/script/core/find-source.lua @@ -1,4 +1,4 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local function isValidFunctionPos(source, offset) for i = 1, #source.keyword // 2 do diff --git a/script/core/folding.lua b/script/core/folding.lua index 15678995..1bbae944 100644 --- a/script/core/folding.lua +++ b/script/core/folding.lua @@ -1,5 +1,5 @@ local files = require "files" -local guide = require "core.guide" +local searcher = require "core.searcher" local util = require 'utility' local Care = { @@ -153,7 +153,7 @@ return function (uri) local regions = {} local status = {} - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) local tp = source.type if Care[tp] then Care[tp](source, text, regions) diff --git a/script/core/generic.lua b/script/core/generic.lua new file mode 100644 index 00000000..15950974 --- /dev/null +++ b/script/core/generic.lua @@ -0,0 +1,234 @@ +local guide = require 'parser.guide' +local noder = require "core.noder" + +---@class generic.value +---@field type string +---@field closure generic.closure +---@field proto parser.guide.object +---@field parent parser.guide.object + +---@class generic.closure +---@field type string +---@field proto parser.guide.object +---@field upvalues table<string, generic.value[]> +---@field params generic.value[] +---@field returns generic.value[] + +local m = {} + +---@param closure generic.closure +---@param proto parser.guide.object +local function instantValue(closure, proto) + ---@type generic.value + local value = { + type = 'generic.value', + closure = closure, + proto = proto, + parent = proto.parent, + } + closure.values[#closure.values+1] = value + return value +end + +---递归实例化对象 +---@param proto parser.guide.object +---@return generic.value +local function createValue(closure, proto, callback, road) + if callback then + road = road or {} + end + if proto.type == 'doc.type' then + local types = {} + local hasGeneric + for i, tp in ipairs(proto.types) do + local genericValue = createValue(closure, tp, callback, road) + if genericValue then + hasGeneric = true + types[i] = genericValue + else + types[i] = tp + end + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.types = types + noder.compileNode(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.name' then + if not proto.typeGeneric then + return nil + end + local key = proto[1] + local value = instantValue(closure, proto) + if callback then + callback(road, key, proto) + end + noder.compileNode(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.function' then + local hasGeneric + local args = {} + local returns = {} + for i, arg in ipairs(proto.args) do + local value = createValue(closure, arg, callback, road) + if value then + hasGeneric = true + end + args[i] = value or arg + end + for i, rtn in ipairs(proto.returns) do + local value = createValue(closure, rtn, callback, road) + if value then + hasGeneric = true + end + returns[i] = value or rtn + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.args = args + value.returns = returns + value.isGeneric = true + noder.pushSource(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.array' then + if road then + road[#road+1] = noder.ANY_FIELD + end + local node = createValue(closure, proto.node, callback, road) + if road then + road[#road] = nil + end + if not node then + return nil + end + local value = instantValue(closure, proto) + value.node = node + return value + end + if proto.type == 'doc.type.table' then + road[#road+1] = noder.TABLE_KEY + local tkey = createValue(closure, proto.tkey, callback, road) + road[#road] = nil + + road[#road+1] = noder.ANY_FIELD + local tvalue = createValue(closure, proto.tvalue, callback, road) + road[#road] = nil + + if not tkey and not tvalue then + return nil + end + local value = instantValue(closure, proto) + value.tkey = tkey or proto.tkey + value.tvalue = tvalue or proto.tvalue + return value + end +end + +local function buildValue(road, key, proto, param, upvalues) + local paramID + if proto.literal then + local str = param.type == 'string' and param[1] + if not str then + return + end + paramID = 'dn:' .. str + else + paramID = noder.getID(param) + end + if not paramID then + return + end + local myUri = guide.getUri(param) + local myHead = noder.URI_CHAR .. myUri .. noder.URI_CHAR + paramID = myHead .. paramID + if not upvalues[key] then + upvalues[key] = {} + end + upvalues[key][#upvalues[key]+1] = paramID .. table.concat(road) +end + +-- 为所有的 param 与 return 创建副本 +---@param closure generic.closure +local function buildValues(closure) + local protoFunction = closure.proto + local upvalues = closure.upvalues + local params = closure.call.args + + if protoFunction.type == 'function' then + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.param' then + local extends = doc.extends + local index = extends.paramIndex + if index then + local param = params and params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + end + end + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + closure.returns[rtn.returnIndex] = createValue(closure, rtn) or rtn + end + end + end + end + if protoFunction.type == 'doc.type.function' then + for index, arg in ipairs(protoFunction.args) do + local extends = arg.extends + local param = params and params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + for index, rtn in ipairs(protoFunction.returns) do + closure.returns[index] = createValue(closure, rtn) or rtn + end + end +end + +---创建一个闭包 +---@param proto parser.guide.object|generic.value # 原型函数|泛型值 +---@return generic.closure +function m.createClosure(proto, call) + local protoFunction, parentClosure + if proto.type == 'function' then + protoFunction = proto + elseif proto.type == 'doc.type.function' then + protoFunction = proto + elseif proto.type == 'generic.value' then + protoFunction = proto.proto + parentClosure = proto.closure + end + ---@type generic.closure + local closure = { + type = 'generic.closure', + parent = protoFunction.parent, + proto = protoFunction, + call = call, + upvalues = parentClosure and parentClosure.upvalues or {}, + params = {}, + returns = {}, + values = {}, + } + buildValues(closure) + + if #closure.returns == 0 then + return nil + end + + noder.compileNode(noder.getNoders(proto), closure) + + return closure +end + +return m diff --git a/script/core/guide.lua b/script/core/guide2.lua index e4871060..576c0c20 100644 --- a/script/core/guide.lua +++ b/script/core/guide2.lua @@ -292,8 +292,13 @@ end ---@param obj parser.guide.object ---@return parser.guide.object function m.getRoot(obj) + local source = obj + if source._root then + return source._root + end for _ = 1, 1000 do if obj.type == 'main' then + source._root = obj return obj end local parent = obj.parent diff --git a/script/core/highlight.lua b/script/core/highlight.lua index 12ec114f..45001134 100644 --- a/script/core/highlight.lua +++ b/script/core/highlight.lua @@ -1,31 +1,18 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local files = require 'files' local vm = require 'vm' local define = require 'proto.define' local findSource = require 'core.find-source' local util = require 'utility' +local guide = require 'parser.guide' local function eachRef(source, callback) - local results = guide.requestReference(source) + local results = searcher.requestReference(source) for i = 1, #results do callback(results[i]) end end -local function eachField(source, callback) - if not source then - return - end - local isGlobal = guide.isGlobal(source) - local results = guide.requestReference(source) - for i = 1, #results do - local res = results[i] - if isGlobal == guide.isGlobal(res) then - callback(res) - end - end -end - local function eachLocal(source, callback) callback(source) if source.ref then @@ -43,21 +30,21 @@ local function find(source, uri, callback) eachLocal(source.node, callback) elseif source.type == 'field' or source.type == 'method' then - eachField(source.parent, callback) + eachRef(source.parent, callback) elseif source.type == 'getindex' or source.type == 'setindex' or source.type == 'tableindex' then - eachField(source, callback) + eachRef(source, callback) elseif source.type == 'setglobal' or source.type == 'getglobal' then - eachField(source, callback) + eachRef(source, callback) elseif source.type == 'goto' or source.type == 'label' then eachRef(source, callback) elseif source.type == 'string' and source.parent and source.parent.index == source then - eachField(source.parent, callback) + eachRef(source.parent, callback) elseif source.type == 'string' or source.type == 'boolean' or source.type == 'number' @@ -238,6 +225,16 @@ local accept = { ['nil'] = true, } +local function isLiteralValue(source) + if not guide.isLiteral(source) then + return false + end + if source.parent.index == source then + return false + end + return true +end + return function (uri, offset) local ast = files.getAst(uri) if not ast then @@ -249,10 +246,25 @@ return function (uri, offset) local source = findSource(ast, offset, accept) if source then + local isGlobal = guide.isGlobal(source) + local isLiteral = isLiteralValue(source) find(source, uri, function (target) + if not target then + return + end if target.dummy then return end + if mark[target] then + return + end + mark[target] = true + if isGlobal ~= guide.isGlobal(target) then + return + end + if isLiteral ~= isLiteralValue(target) then + return + end local kind if target.type == 'getfield' then target = target.field @@ -315,13 +327,6 @@ return function (uri, offset) else return end - if not target then - return - end - if mark[target] then - return - end - mark[target] = true results[#results+1] = { start = target.start, finish = target.finish, diff --git a/script/core/hint.lua b/script/core/hint.lua index 13d01dc7..43b8726e 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local vm = require 'vm' -local config = require 'config' +local files = require 'files' +local searcher = require 'core.searcher' +local vm = require 'vm' +local config = require 'config' local function typeHint(uri, edits, start, finish) local ast = files.getAst(uri) @@ -9,7 +9,7 @@ local function typeHint(uri, edits, start, finish) return end local mark = {} - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) if source.type ~= 'local' and source.type ~= 'setglobal' and source.type ~= 'tablefield' @@ -21,7 +21,7 @@ local function typeHint(uri, edits, start, finish) if source[1] == '_' then return end - if source.value and guide.isLiteral(source.value) then + if source.value and searcher.isLiteral(source.value) then return end if source.parent.type == 'funcargs' then @@ -84,7 +84,7 @@ local function hasLiteralArgInCall(call) return false end for _, arg in ipairs(call.args) do - if guide.isLiteral(arg) then + if searcher.isLiteral(arg) then return true end end @@ -100,14 +100,14 @@ local function paramName(uri, edits, start, finish) return end local mark = {} - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) if source.type ~= 'call' then return end if not hasLiteralArgInCall(source) then return end - local defs = vm.getDefs(source.node, 0) + local defs = vm.getDefs(source.node) if not defs then return end @@ -130,7 +130,7 @@ local function paramName(uri, edits, start, finish) table.remove(args, 1) end for i, arg in ipairs(source.args) do - if not mark[arg] and guide.isLiteral(arg) then + if not mark[arg] and searcher.isLiteral(arg) then mark[arg] = true if args[i] and args[i] ~= '' then edits[#edits+1] = { diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua index 324d28af..822be2b6 100644 --- a/script/core/hover/arg.lua +++ b/script/core/hover/arg.lua @@ -1,4 +1,5 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' +local infer = require 'core.infer' local vm = require 'vm' local function optionalArg(arg) @@ -21,7 +22,7 @@ local function asFunction(source, oop) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(vm.getInferType(parent.node)) + args[#args+1] = ('self: %s'):format(infer.searchAndViewInfers(parent.node)) end if source.args then for i = 1, #source.args do @@ -34,10 +35,12 @@ local function asFunction(source, oop) args[#args+1] = ('%s%s: %s'):format( name, optionalArg(arg) and '?' or '', - vm.getInferType(arg) + infer.searchAndViewInfers(arg) ) + elseif arg.type == '...' then + args[#args+1] = '...' else - args[#args+1] = ('%s'):format(vm.getInferType(arg)) + args[#args+1] = ('%s'):format(infer.searchAndViewInfers(arg)) end ::CONTINUE:: end @@ -61,7 +64,7 @@ local function asDocFunction(source) args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - vm.getInferType(arg.extends) + infer.searchAndViewInfers(arg.extends) ) else args[i] = ('%s%s'):format( diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index 401ca5a7..bcc3065a 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -2,11 +2,13 @@ local vm = require 'vm' local ws = require 'workspace' local furi = require 'file-uri' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local markdown = require 'provider.markdown' local config = require 'config' local lang = require 'language' local util = require 'utility' +local guide = require 'parser.guide' +local noder = require 'core.noder' local function asStringInRequire(source, literal) local rootPath = ws.path or '' @@ -124,10 +126,10 @@ local function getBindComment(source, docGroup, base) end local function tryDocClassComment(source) - for _, def in ipairs(vm.getDefs(source, 0)) do + for _, def in ipairs(vm.getDefs(source)) do if def.type == 'doc.class.name' or def.type == 'doc.alias.name' then - local class = guide.getDocState(def) + local class = noder.getDocState(def) local comment = getBindComment(class, class.bindGroup, class) if comment then return comment @@ -180,7 +182,7 @@ local function isFunction(source) if source.type == 'function' then return true end - local value = guide.getObjectValue(source) + local value = searcher.getObjectValue(source) if not value then return false end @@ -223,13 +225,14 @@ local function getBindEnums(source, docGroup) end local function tryDocFieldUpComment(source) - if source.type ~= 'doc.field' then + if source.type ~= 'doc.field.name' then return end - if not source.bindGroup then + local docField = source.parent + if not docField.bindGroup then return end - local comment = getBindComment(source, source.bindGroup, source) + local comment = getBindComment(docField, docField.bindGroup, docField) return comment end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index 0c8644ed..5dd00c43 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local getLabel = require 'core.hover.label' local getDesc = require 'core.hover.description' @@ -7,6 +7,7 @@ local util = require 'utility' local findSource = require 'core.find-source' local lang = require 'language' local markdown = require 'provider.markdown' +local infer = require 'core.infer' local function eachFunctionAndOverload(value, callback) callback(value) @@ -24,7 +25,7 @@ local function getHoverAsValue(source) local label = getLabel(source) local desc = getDesc(source) if not desc then - local values = vm.getDefs(source, 0) + local values = vm.getDefs(source) for _, def in ipairs(values) do desc = getDesc(def) if desc then @@ -40,7 +41,7 @@ local function getHoverAsValue(source) end local function getHoverAsFunction(source) - local values = vm.getDefs(source, 0) + local values = vm.getDefs(source) local desc = getDesc(source) local labels = {} local defs = 0 @@ -48,7 +49,7 @@ local function getHoverAsFunction(source) local other = 0 local mark = {} for _, def in ipairs(values) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def.type == 'function' or def.type == 'doc.type.function' then eachFunctionAndOverload(def, function (value) @@ -123,7 +124,7 @@ local function getHover(source) if source.type == 'doc.type.name' then return getHoverAsDocName(source) end - local isFunction = vm.hasInferType(source, 'function', 0) + local isFunction = infer.hasType(source, 'function') if isFunction then return getHoverAsFunction(source) else diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index d93b14e3..032f19c0 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -2,9 +2,10 @@ local buildName = require 'core.hover.name' local buildArg = require 'core.hover.arg' local buildReturn = require 'core.hover.return' local buildTable = require 'core.hover.table' +local infer = require 'core.infer' local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local config = require 'config' local files = require 'files' @@ -31,29 +32,28 @@ local function asDocFunction(source) end local function asDocTypeName(source) - for _, doc in ipairs(vm.getDocTypes(source[1])) do + local defs = searcher.requestDefinition(source) + for _, doc in ipairs(defs) do if doc.type == 'doc.class.name' then - return 'class ' .. source[1] + return 'class ' .. doc[1] end if doc.type == 'doc.alias.name' then local extends = doc.parent.extends - return lang.script('HOVER_EXTENDS', vm.getInferType(extends)) + return lang.script('HOVER_EXTENDS', infer.searchAndViewInfers(extends)) end end end local function asValue(source, title) local name = buildName(source) - local infers = vm.getInfers(source, 0) - local type = vm.getInferType(source, 0) - local class = vm.getClass(source, 0) - local literal = vm.getInferLiteral(source, 0) + local type = infer.searchAndViewInfers(source) + local literal = infer.searchAndViewLiterals(source) local cont - if not vm.hasInferType(source, 'string', 0) + if not infer.hasType(source, 'string') and not type:find('%[%]$') and not type:find('%w%<') then - if #vm.getFields(source, 0) > 0 - or vm.hasInferType(source, 'table', 0) then + if #vm.getRefs(source, '*') > 0 + or infer.hasType(source, 'table') then cont = buildTable(source) end end @@ -66,11 +66,7 @@ local function asValue(source, title) or type == 'nil') then type = nil end - if class then - pack[#pack+1] = class - else - pack[#pack+1] = type - end + pack[#pack+1] = type if literal then pack[#pack+1] = '=' pack[#pack+1] = literal @@ -123,30 +119,21 @@ local function asField(source) return asValue(source, 'field') end -local function asDocField(source) - local name = source.field[1] +local function asDocFieldName(source) + local name = source[1] + local docField = source.parent local class - for _, doc in ipairs(source.bindGroup) do + for _, doc in ipairs(docField.bindGroup) do if doc.type == 'doc.class' then class = doc break end end - local infers = {} - for _, infer in ipairs(vm.getInfers(source.extends) or {}) do - infers[#infers+1] = infer - end + local view = infer.searchAndViewInfers(docField.extends) if not class then - return ('field ?.%s: %s'):format( - name, - guide.viewInferType(infers) - ) - end - return ('field %s.%s: %s'):format( - class.class[1], - name, - guide.viewInferType(infers) - ) + return ('field ?.%s: %s'):format(name, view) + end + return ('field %s.%s: %s'):format(class.class[1], name, view) end local function asString(source) @@ -177,7 +164,7 @@ local function asNumber(source) if type(num) ~= 'number' then return nil end - local uri = guide.getUri(source) + local uri = searcher.getUri(source) local text = files.getText(uri) if not text then return nil @@ -215,7 +202,7 @@ return function (source, oop) return asDocFunction(source) elseif source.type == 'doc.type.name' then return asDocTypeName(source) - elseif source.type == 'doc.field' then - return asDocField(source) + elseif source.type == 'doc.field.name' then + return asDocFieldName(source) end end diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua index d583f1e1..d2b9d30b 100644 --- a/script/core/hover/name.lua +++ b/script/core/hover/name.lua @@ -1,4 +1,6 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' +local infer = require 'core.infer' +local guide = require 'parser.guide' local vm = require 'vm' local buildName @@ -19,7 +21,7 @@ end local function asField(source, oop) local class if source.node.type ~= 'getglobal' then - class = vm.getClass(source.node, 0) + class = infer.getClass(source.node) end local node = class or guide.getKeyName(source.node) or '?' local method = guide.getKeyName(source) diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index c3e9656d..0f0d85e0 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,12 +1,4 @@ -local guide = require 'core.guide' -local vm = require 'vm' - -local function mergeTypes(returns) - if type(returns) == 'string' then - return returns - end - return guide.mergeTypes(returns) -end +local infer = require 'core.infer' local function getReturnDualByDoc(source) local docs = source.bindDocs @@ -55,24 +47,20 @@ local function asFunction(source) local returns = {} for i, rtn in ipairs(dual) do local line = {} - local types = {} + local infers = {} if i == 1 then line[#line+1] = ' -> ' else line[#line+1] = ('% 3d. '):format(i) end for n = 1, #rtn do - local values = vm.getInfers(rtn[n]) - for _, value in ipairs(values) do - if value.type then - for tp in value.type:gmatch '[^|]+' do - types[tp] = true - end - end + local values = infer.searchInfers(rtn[n]) + for tp in pairs(values) do + infers[tp] = true end end - if next(types) or rtn[1] then - local tp = mergeTypes(types) or 'any' + if next(infers) or rtn[1] then + local tp = infer.viewInfers(infers) if rtn[1].name then line[#line+1] = ('%s%s: %s'):format( rtn[1].name[1], @@ -103,7 +91,7 @@ local function asDocFunction(source) local returns = {} for i, rtn in ipairs(source.returns) do local rtnText = ('%s%s'):format( - vm.getInferType(rtn), + infer.searchAndViewInfers(rtn), rtn.optional and '?' or '' ) if i == 1 then diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index edb7751b..159453e6 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -1,26 +1,12 @@ local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local config = require 'config' local lang = require 'language' +local infer = require 'core.infer' -local function getKey(src) - local key = vm.getKeyName(src) - if not key or #key <= 0 then - if not src.index then - return '[any]' - end - local class = vm.getClass(src.index) - if class then - return ('[%s]'):format(class) - end - local tp = vm.getInferType(src.index) - if tp then - return ('[%s]'):format(tp) - end - return '[any]' - end - if guide.getKeyType(src) == 'string' then +local function formatKey(key) + if type(key) == 'string' then if key:match '^[%a_][%w_]*$' then return key else @@ -30,104 +16,16 @@ local function getKey(src) return ('[%s]'):format(key) end -local function getFieldFull(src) - local value = guide.getObjectValue(src) or src - local tp = vm.getInferType(value, 0) - --local class = vm.getClass(src) - local literal = vm.getInferLiteral(value) - if type(literal) == 'string' and #literal >= 50 then - literal = literal:sub(1, 47) .. '...' - end - return tp, literal -end - -local function getFieldFast(src) - if src.bindDocs then - return getFieldFull(src) - end - local value = guide.getObjectValue(src) or src - if not value then - return 'any' - end - if value.type == 'boolean' then - return value.type, util.viewLiteral(value[1]) - end - if value.type == 'number' - or value.type == 'integer' then - if math.tointeger(value[1]) then - if config.config.runtime.version == 'Lua 5.3' - or config.config.runtime.version == 'Lua 5.4' then - return 'integer', util.viewLiteral(value[1]) - end - end - return value.type, util.viewLiteral(value[1]) - end - if value.type == 'table' - or value.type == 'function' then - return value.type - end - if value.type == 'string' then - local literal = value[1] - if type(literal) == 'string' and #literal >= 50 then - literal = literal:sub(1, 47) .. '...' - end - return value.type, util.viewLiteral(literal) - end - if value.type == 'doc.field' then - return vm.getInferType(value) - end -end - -local function getField(src, timeUp, mark, key) - if src.type == 'table' - or src.type == 'function' then - return nil - end - if src.parent then - if src.type == 'string' - or src.type == 'boolean' - or src.type == 'number' - or src.type == 'integer' then - if src.parent.type == 'tableindex' - or src.parent.type == 'setindex' - or src.parent.type == 'getindex' then - if src.parent.index == src then - src = src.parent - end - end - end - end - local tp, literal - tp, literal = getFieldFast(src) - if tp then - return tp, literal - end - if timeUp or mark[key] then - return nil - end - mark[key] = true - tp, literal = getFieldFull(src) - if tp then - return tp, literal - end - return nil -end - -local function buildAsHash(classes, literals, reachMax) - local keys = {} - for k in pairs(classes) do - keys[#keys+1] = k - end - table.sort(keys) +local function buildAsHash(keys, inferMap, literalMap, reachMax) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local class = classes[key] - local literal = literals[key] - if literal then - lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + local inferView = inferMap[key] + local literalView = literalMap[key] + if literalView then + lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView) else - lines[#lines+1] = (' %s: %s,'):format(key, class) + lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView) end end if reachMax then @@ -137,23 +35,19 @@ local function buildAsHash(classes, literals, reachMax) return table.concat(lines, '\n') end -local function buildAsConst(classes, literals, reachMax) - local keys = {} - for k in pairs(classes) do - keys[#keys+1] = k - end +local function buildAsConst(keys, inferMap, literalMap, reachMax) table.sort(keys, function (a, b) - return tonumber(literals[a]) < tonumber(literals[b]) + return tonumber(literalMap[a]) < tonumber(literalMap[b]) end) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local class = classes[key] - local literal = literals[key] - if literal then - lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + local inferView = inferMap[key] + local literalView = literalMap[key] + if literalView then + lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView) else - lines[#lines+1] = (' %s: %s,'):format(key, class) + lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView) end end if reachMax then @@ -163,111 +57,79 @@ local function buildAsConst(classes, literals, reachMax) return table.concat(lines, '\n') end -local function mergeLiteral(literals) - local results = {} +local typeSorter = { + ['string'] = 1, + ['number'] = 2, + ['boolean'] = 3, +} + +local function getKeyMap(fields) + local keys = {} local mark = {} - for _, value in ipairs(literals) do - if not mark[value] then - mark[value] = true - results[#results+1] = value + for _, field in ipairs(fields) do + local key = vm.getKeyName(field) + local tp = vm.getKeyType(field) + if tp == 'number' then + key = tonumber(key) + elseif tp == 'boolean' then + key = key == 'true' end - end - if #results == 0 then - return nil - end - table.sort(results) - return table.concat(results, '|') -end - -local function mergeTypes(types) - local results = {} - local mark = { - -- 讲道理table的keyvalue不会是nil - ['nil'] = true, - } - for _, tv in ipairs(types) do - for tp in tv:gmatch '[^|]+' do - if not mark[tp] then - mark[tp] = true - results[tp] = true - end + if key and not mark[key] then + mark[key] = true + keys[#keys+1] = key end end - return guide.mergeTypes(results) -end - -local function clearClasses(classes) - classes['[nil]'] = nil - classes['[any]'] = nil - classes['[string]'] = nil + table.sort(keys, function (a, b) + local ta = typeSorter[type(a)] + local tb = typeSorter[type(b)] + if ta == tb then + return tostring(a) < tostring(b) + else + return ta < tb + end + end) + return keys end return function (source) - if config.config.hover.previewFields <= 0 then + local maxFields = config.config.hover.previewFields + if maxFields <= 0 then return 'table' end - local literals = {} - local classes = {} - local clock = os.clock() - local timeUp - local mark = {} - local fields = vm.getFields(source, 0) - local keyCount = 0 - local reachMax - for _, src in ipairs(fields) do - local key = getKey(src) - if not key then - goto CONTINUE - end - if not classes[key] then - classes[key] = {} - keyCount = keyCount + 1 - end - if not literals[key] then - literals[key] = {} - end - if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then - timeUp = true - end - local class, literal = getField(src, timeUp, mark, key) - if literal == 'nil' then - literal = nil - end - classes[key][#classes[key]+1] = class - literals[key][#literals[key]+1] = literal - if keyCount >= config.config.hover.previewFields then - reachMax = true - break - end - ::CONTINUE:: - end - - clearClasses(classes) - for key, class in pairs(classes) do - literals[key] = mergeLiteral(literals[key]) - classes[key] = mergeTypes(class) - end + local fields = vm.getRefs(source, '*') + local keys = getKeyMap(fields) - if not next(classes) then + if #keys == 0 then return '{}' end - local intValue = true - for key, class in pairs(classes) do - if class ~= 'integer' or not tonumber(literals[key]) then - intValue = false - break + local inferMap = {} + local literalMap = {} + + local reachMax = maxFields < #keys + + local isConsts = true + for i = 1, math.min(maxFields, #keys) do + local key = keys[i] + inferMap[key] = infer.searchAndViewInfers(source, key) + literalMap[key] = infer.searchAndViewLiterals(source, key) + if not tonumber(literalMap[key]) then + isConsts = false end end + local result - if intValue then - result = buildAsConst(classes, literals, reachMax) + + if isConsts then + result = buildAsConst(keys, inferMap, literalMap, reachMax) else - result = buildAsHash(classes, literals, reachMax) - end - if timeUp then - result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + result = buildAsHash(keys, inferMap, literalMap, reachMax) end + + --if timeUp then + -- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + --end + return result end diff --git a/script/core/infer.lua b/script/core/infer.lua new file mode 100644 index 00000000..77236811 --- /dev/null +++ b/script/core/infer.lua @@ -0,0 +1,634 @@ +local searcher = require 'core.searcher' +local config = require 'config' +local noder = require 'core.noder' +local util = require 'utility' + +local STRING_OR_TABLE = {'STRING_OR_TABLE'} +local BE_RETURN = {'BE_RETURN'} +local BE_CONNACT = {'BE_CONNACT'} +local CLASS = {'CLASS'} +local TABLE = {'TABLE'} + +local TypeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['true'] = 101, + ['false'] = 102, + ['nil'] = 999, +} + +local m = {} + +local function mergeTable(a, b) + if not b then + return + end + for v in pairs(b) do + a[v] = true + end +end + +local function searchInferOfUnary(value, infers) + local op = value.op.type + if op == 'not' then + infers['boolean'] = true + return + end + if op == '#' then + infers['integer'] = true + return + end + if op == '-' then + if m.hasType(value[1], 'integer') then + infers['integer'] = true + else + infers['number'] = true + end + return + end + if op == '~' then + infers['integer'] = true + return + end +end + +local function searchInferOfBinary(value, infers) + local op = value.op.type + if op == 'and' then + if m.isTrue(value[1]) then + mergeTable(infers, m.searchInfers(value[2])) + else + mergeTable(infers, m.searchInfers(value[1])) + end + return + end + if op == 'or' then + if m.isTrue(value[1]) then + mergeTable(infers, m.searchInfers(value[1])) + else + mergeTable(infers, m.searchInfers(value[2])) + end + return + end + if op == '==' + or op == '~=' + or op == '<' + or op == '>' + or op == '<=' + or op == '>=' then + infers['boolean'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '&' + or op == '|' then + infers['integer'] = true + return + end + if op == '..' then + infers['string'] = true + return + end + if op == '^' + or op == '/' then + infers['number'] = true + return + end + if op == '+' + or op == '-' + or op == '*' + or op == '%' + or op == '//' then + if m.hasType(value[1], 'integer') + and m.hasType(value[2], 'integer') then + infers['integer'] = true + else + infers['number'] = true + end + return + end +end + +local function searchInferOfValue(value, infers) + if value.type == 'string' then + infers['string'] = true + return true + end + if value.type == 'boolean' then + infers['boolean'] = true + return true + end + if value.type == 'table' then + if value.array then + local node = m.searchAndViewInfers(value.array) + local infer = node .. '[]' + infers[infer] = true + else + infers['table'] = true + end + return true + end + if value.type == 'number' then + if math.type(value[1]) == 'integer' then + infers['integer'] = true + else + infers['number'] = true + end + return true + end + if value.type == 'nil' then + infers['nil'] = true + return true + end + if value.type == 'function' then + infers['function'] = true + return true + end + if value.type == 'unary' then + searchInferOfUnary(value, infers) + return true + end + if value.type == 'binary' then + searchInferOfBinary(value, infers) + return true + end + return false +end + +local function searchLiteralOfValue(value, literals) + if value.type == 'string' + or value.type == 'boolean' + or value.type == 'number' + or value.type == 'integer' then + local v = value[1] + if v ~= nil then + literals[v] = true + end + return + end + if value.type == 'unary' then + local op = value.op.type + if op == '-' then + local subLiterals = m.searchLiterals(value[1]) + if subLiterals then + for subLiteral in pairs(subLiterals) do + local num = tonumber(subLiteral) + if num then + literals[-num] = true + end + end + end + end + if op == '~' then + local subLiterals = m.searchLiterals(value[1]) + if subLiterals then + for subLiteral in pairs(subLiterals) do + local num = math.tointeger(subLiteral) + if num then + literals[~num] = true + end + end + end + end + end + return +end + +local function bindClassOrType(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' + or doc.type == 'doc.type' then + return true + end + end + return false +end + +local function cleanInfers(infers) + local version = config.config.runtime.version + local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4' + infers['unknown'] = nil + if infers['any'] and infers['nil'] then + infers['nil'] = nil + end + if infers['number'] then + enableInteger = false + end + if not enableInteger and infers['integer'] then + infers['integer'] = nil + infers['number'] = true + end + -- stringlib 就是 string + if infers['stringlib'] and infers['string'] then + infers['stringlib'] = nil + end + -- 如果是通过 .. 来推测的,且结果里没有 number 与 integer,则推测为string + if infers[BE_CONNACT] then + infers[BE_CONNACT] = nil + if not infers['number'] and not infers['integer'] then + infers['string'] = true + end + end + -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型 + if infers[STRING_OR_TABLE] then + infers[STRING_OR_TABLE] = nil + if not infers['table'] and not infers['string'] then + infers['table'] = true + infers['string'] = true + end + end + -- 如果有doc标记,则先移除table类型 + if infers[CLASS] then + infers[CLASS] = nil + infers['table'] = nil + end + -- 用doc标记的table,加入table类型 + if infers[TABLE] then + infers[TABLE] = nil + infers['table'] = true + end + if infers[BE_RETURN] then + infers[BE_RETURN] = nil + infers['nil'] = nil + end + infers['any'] = nil +end + +---合并对象的推断类型 +---@param infers string[] +---@return string +function m.viewInfers(infers) + if infers[0] then + return infers[0] + end + -- 如果有显性的 any ,则直接显示为 any + if infers['any'] then + infers[0] = 'any' + return 'any' + end + local result = {} + local count = 0 + for infer in pairs(infers) do + count = count + 1 + result[count] = infer + end + -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any + if count == 0 then + infers[0] = 'any' + return 'any' + end + table.sort(result, function (a, b) + local sa = TypeSort[a] or 100 + local sb = TypeSort[b] or 100 + if sa == sb then + return a < b + else + return sa < sb + end + end) + infers[0] = table.concat(result, '|') + return infers[0] +end + +---合并对象的值 +---@param literals string[] +---@return string +function m.viewLiterals(literals) + local result = {} + local count = 0 + for infer in pairs(literals) do + count = count + 1 + result[count] = util.viewLiteral(infer) + end + if count == 0 then + return nil + end + table.sort(result) + local view = table.concat(result, '|') + return view +end + +function m.viewDocName(doc) + if not doc then + return nil + end + if doc.type == 'doc.type' then + local list = {} + for _, tp in ipairs(doc.types) do + list[#list+1] = m.getDocName(tp) + end + for _, enum in ipairs(doc.enums) do + list[#list+1] = m.getDocName(enum) + end + return table.concat(list, '|') + end + return m.getDocName(doc) +end + +function m.getDocName(doc) + if not doc then + return nil + end + if doc.type == 'doc.class.name' + or doc.type == 'doc.type.name' then + local name = doc[1] or '?' + if doc.typeGeneric then + return '<' .. name .. '>' + else + return name + end + end + if doc.type == 'doc.type.array' then + local nodeName = m.viewDocName(doc.node) or '?' + return nodeName .. '[]' + end + if doc.type == 'doc.type.table' then + local key = m.viewDocName(doc.tkey) or '?' + local value = m.viewDocName(doc.tvalue) or '?' + return ('table<%s, %s>'):format(key, value) + end + if doc.type == 'doc.type.function' then + return 'function' + end + if doc.type == 'doc.type.enum' + or doc.type == 'doc.resume' then + local value = doc[1] or '?' + return value + end +end + +function m.viewDocFunction(doc) + if doc.type ~= 'doc.type.function' then + return '' + end + local args = {} + for i, arg in ipairs(doc.args) do + args[i] = ('%s: %s'):format(arg.name[1], m.viewDocName(arg.extends)) + end + local label = ('fun(%s)'):format(table.concat(args, ', ')) + if #doc.returns > 0 then + local returns = {} + for i, rtn in ipairs(doc.returns) do + returns[i] = m.viewDocName(rtn) + end + label = ('%s:%s'):format(label, table.concat(returns)) + end + return label +end + +---显示对象的推断类型 +---@param source parser.guide.object +---@return string +local function searchInfer(source, infers) + if bindClassOrType(source) then + return + end + if searchInferOfValue(source, infers) then + return + end + local value = searcher.getObjectValue(source) + if value then + searchInferOfValue(value, infers) + return + end + -- check LuaDoc + local docName = m.getDocName(source) + if docName then + infers[docName] = true + if docName ~= 'unknown' then + infers[CLASS] = true + end + if docName == 'table' then + infers[TABLE] = true + end + end + if source.parent.type == 'unary' then + local op = source.parent.op.type + -- # XX -> string | table + if op == '#' then + infers[STRING_OR_TABLE] = true + return + end + if op == '-' then + infers['number'] = true + return + end + if op == '~' then + infers['integer'] = true + return + end + return + end + if source.parent.type == 'binary' then + local op = source.parent.op.type + if op == '+' + or op == '-' + or op == '*' + or op == '/' + or op == '//' + or op == '^' + or op == '%' then + infers['number'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '|' + or op == '&' then + infers['integer'] = true + return + end + if op == '..' then + infers[BE_CONNACT] = true + return + end + end + -- X.a -> table + if source.next and source.next.node == source then + if source.next.type == 'setfield' + or source.next.type == 'setindex' + or source.next.type == 'setmethod' + or source.next.type == 'getfield' + or source.next.type == 'getindex' then + infers['table'] = true + end + if source.next.type == 'getmethod' then + infers[STRING_OR_TABLE] = true + end + end + -- return XX + if source.parent.type == 'return' then + infers[BE_RETURN] = true + end +end + +local function searchLiteral(source, literals) + local value = searcher.getObjectValue(source) + if value then + searchLiteralOfValue(value, literals) + return + end +end + +---搜索对象的推断类型 +---@param source parser.guide.object +---@param field? string +---@return string[] +function m.searchInfers(source, field) + if not source then + return nil + end + local defs = searcher.requestDefinition(source, field) + local infers = {} + local mark = {} + if not field then + mark[source] = true + searchInfer(source, infers) + local id = noder.getID(source) + if id then + local node = noder.getNodeByID(source, id) + if node and node.sources then + for _, src in ipairs(node.sources) do + if not mark[src] then + mark[src] = true + searchInfer(src, infers) + end + end + end + end + end + if source.type == 'field' or source.type == 'method' then + mark[source.parent] = true + searchInfer(source.parent, infers) + end + for _, def in ipairs(defs) do + if not mark[def] then + mark[def] = true + searchInfer(def, infers) + end + end + if source.docParam then + local docType = source.docParam.extends + if docType.type == 'doc.type' then + for _, def in ipairs(docType.types) do + if def.typeGeneric and not mark[def] then + mark[def] = true + searchInfer(def, infers) + end + end + end + end + if source.type == 'doc.type' then + if source.type == 'doc.type' then + for _, def in ipairs(source.types) do + if def.typeGeneric and not mark[def] then + mark[def] = true + searchInfer(def, infers) + end + end + end + end + cleanInfers(infers) + return infers +end + +---搜索对象的字面量值 +---@param source parser.guide.object +---@param field? string +---@return table +function m.searchLiterals(source, field) + local defs = searcher.requestDefinition(source, field) + local literals = {} + local mark = {} + if not field then + mark[source] = true + searchLiteral(source, literals) + end + for _, def in ipairs(defs) do + if not mark[def] then + mark[def] = true + searchLiteral(def, literals) + end + end + return literals +end + +---搜索并显示推断值 +---@param source parser.guide.object +---@param field? string +---@return string +function m.searchAndViewLiterals(source, field) + if not source then + return nil + end + local literals = m.searchLiterals(source, field) + local view = m.viewLiterals(literals) + return view +end + +---判断对象的推断值是否是 true +---@param source parser.guide.object +function m.isTrue(source) + if not source then + return false + end + local literals = m.searchLiterals(source) + for literal in pairs(literals) do + if literal ~= false then + return true + end + end + return false +end + +---判断对象的推断类型是否包含某个类型 +function m.hasType(source, tp) + local infers = m.searchInfers(source) + return infers[tp] or false +end + +---搜索并显示推断类型 +---@param source parser.guide.object +---@param field? string +---@return string +function m.searchAndViewInfers(source, field) + if not source then + return 'any' + end + local infers = m.searchInfers(source, field) + local view = m.viewInfers(infers) + return view +end + +---搜索并显示推断的class +---@param source parser.guide.object +---@return string? +function m.getClass(source) + if not source then + return nil + end + local infers = {} + local defs = searcher.requestDefinition(source) + for _, def in ipairs(defs) do + if def.type == 'doc.class.name' then + infers[def[1]] = true + end + end + local view = m.viewInfers(infers) + if view == 'any' then + return nil + end + return view +end + +return m diff --git a/script/core/keyword.lua b/script/core/keyword.lua index 71ea4969..73892f18 100644 --- a/script/core/keyword.lua +++ b/script/core/keyword.lua @@ -1,6 +1,6 @@ local define = require 'proto.define' -local guide = require 'core.guide' local files = require 'files' +local guide = require 'parser.guide' local keyWordMap = { {'do', function (info, results) diff --git a/script/core/noder.lua b/script/core/noder.lua new file mode 100644 index 00000000..c3679612 --- /dev/null +++ b/script/core/noder.lua @@ -0,0 +1,926 @@ +local util = require 'utility' +local guide = require 'parser.guide' + +local LastIDCache = {} +local FirstIDCache = {} +local SPLIT_CHAR = '\x1F' +local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$' +local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*' +local ANY_FIELD_CHAR = '*' +local RETURN_INDEX = SPLIT_CHAR .. '#' +local PARAM_INDEX = SPLIT_CHAR .. '&' +local TABLE_KEY = SPLIT_CHAR .. '<' +local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR +local URI_CHAR = '@' +local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(.*)' + +---@class node +-- 当前节点的id +---@field id string +-- 使用该ID的单元 +---@field sources parser.guide.object[] +-- 前进的关联ID +---@field forward string[] +-- 后退的关联ID +---@field backward string[] +-- 函数调用参数信息(用于泛型) +---@field call parser.guide.object + +---@alias noders table<string, node[]> + +---创建source的链接信息 +---@param noders noders +---@param id string +---@return node +local function getNode(noders, id) + if not noders[id] then + noders[id] = { + id = id, + } + end + return noders[id] +end + +---获取语法树单元的key +---@param source parser.guide.object +---@return string? key +---@return parser.guide.object? node +local function getKey(source) + if source.type == 'local' then + return tostring(source.start), nil + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return tostring(source.node.start), nil + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + local node = source.node + if node.tag == '_ENV' then + return ('%q'):format(source[1] or ''), nil + else + return ('%q'):format(source[1] or ''), node + end + elseif source.type == 'getfield' + or source.type == 'setfield' then + return ('%q'):format(source.field and source.field[1] or ''), source.node + elseif source.type == 'tablefield' then + return ('%q'):format(source.field and source.field[1] or ''), source.parent + elseif source.type == 'getmethod' + or source.type == 'setmethod' then + return ('%q'):format(source.method and source.method[1] or ''), source.node + elseif source.type == 'setindex' + or source.type == 'getindex' then + local index = source.index + if not index then + return ANY_FIELD_CHAR, source.node + end + if index.type == 'string' + or index.type == 'boolean' + or index.type == 'number' then + return ('%q'):format(index[1] or ''), source.node + elseif index.type ~= 'function' + and index.type ~= 'table' then + return ANY_FIELD_CHAR, source.node + end + elseif source.type == 'tableindex' then + local index = source.index + if not index then + return ANY_FIELD_CHAR, source.parent + end + if index.type == 'string' + or index.type == 'boolean' + or index.type == 'number' then + return ('%q'):format(index[1] or ''), source.parent + elseif index.type ~= 'function' + and index.type ~= 'table' then + return ANY_FIELD_CHAR, source.parent + end + elseif source.type == 'table' then + return source.start, nil + elseif source.type == 'label' then + return source.start, nil + elseif source.type == 'goto' then + if source.node then + return source.node.start, nil + end + return nil, nil + elseif source.type == 'function' then + return source.start, nil + elseif source.type == 'string' then + return '', nil + elseif source.type == 'integer' + or source.type == 'number' + or source.type == 'boolean' + or source.type == 'nil' then + return source.start, nil + elseif source.type == '...' then + return source.start, nil + elseif source.type == 'varargs' then + if source.node then + return source.node.start, nil + end + elseif source.type == 'select' then + return ('%d%s%d'):format(source.start, RETURN_INDEX, source.sindex) + elseif source.type == 'call' then + local node = source.node + if node.special == 'rawget' + or node.special == 'rawset' then + if not source.args then + return nil, nil + end + local tbl, key = source.args[1], source.args[2] + if not tbl or not key then + return nil, nil + end + if key.type == 'string' then + return ('%q'):format(key[1] or ''), tbl + else + return '', tbl + end + end + return source.finish, nil + elseif source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.see.name' then + local name = source[1] + return name, nil + elseif source.type == 'doc.type.name' then + local name = source[1] + if source.typeGeneric then + return source.typeGeneric[name][1].start, nil + else + return name, nil + end + elseif source.type == 'doc.class' + or source.type == 'doc.type' + or source.type == 'doc.param' + or source.type == 'doc.vararg' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.table' + or source.type == 'doc.type.array' + or source.type == 'doc.type.function' then + return source.start, nil + elseif source.type == 'doc.see.field' then + return ('%q'):format(source[1]), source.parent.name + elseif source.type == 'generic.closure' then + return source.call.start, nil + elseif source.type == 'generic.value' then + return ('%s|%s'):format( + source.closure.call.start, + getKey(source.proto) + ) + end + return nil, nil +end + +local function checkMode(source) + if source.type == 'table' then + return 't:' + end + if source.type == 'select' then + return 's:' + end + if source.type == 'function' then + return 'f:' + end + if source.type == 'string' then + return 'str:' + end + if source.type == 'number' + or source.type == 'integer' + or source.type == 'boolean' + or source.type == 'nil' then + return 'i:' + end + if source.type == 'call' then + return 'c:' + end + if source.type == '...' + or source.type == 'varargs' then + return 'va:' + end + if source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' then + if source.typeGeneric then + return 'dg:' + end + return 'dn:' + end + if source.type == 'doc.field.name' then + return 'dfn:' + end + if source.type == 'doc.see.name' then + return 'dsn:' + end + if source.type == 'doc.class' then + return 'dc:' + end + if source.type == 'doc.type' then + return 'dt:' + end + if source.type == 'doc.param' then + return 'dp:' + end + if source.type == 'doc.type.function' then + return 'dfun:' + end + if source.type == 'doc.type.table' then + return 'dtable:' + end + if source.type == 'doc.type.array' then + return 'darray:' + end + if source.type == 'doc.vararg' then + return 'dv:' + end + if source.type == 'doc.type.enum' + or source.type == 'doc.resume' then + return 'de:' + end + if source.type == 'generic.closure' then + return 'gc:' + end + if source.type == 'generic.value' then + local id = 'gv:' + if guide.getUri(source.closure.call) ~= guide.getUri(source.proto) then + id = id .. URI_CHAR .. guide.getUri(source.closure.call) + end + return id + end + if guide.isGlobal(source) then + return 'g:' + end + if source.type == 'getlocal' + or source.type == 'setlocal' then + source = source.node + end + if source.parent.type == 'funcargs' then + return 'p:' + end + return 'l:' +end + +local IDList = {} +---获取语法树单元的字符串ID +---@param source parser.guide.object +---@return string? id +local function getID(source) + if not source then + return nil + end + if source._id ~= nil then + return source._id or nil + end + if source.type == 'field' + or source.type == 'method' then + source._id = false + return nil + end + local current = source + local index = 0 + while true do + if current.type == 'paren' then + current = current.exp + goto CONTINUE + end + local id, node = getKey(current) + if not id then + break + end + index = index + 1 + IDList[index] = id + if not node then + break + end + current = node + if current.special == '_G' then + for i = index, 2, -1 do + if IDList[i] == '"_G"' then + IDList[i] = nil + end + end + break + end + ::CONTINUE:: + end + if index == 0 then + source._id = false + return nil + end + for i = index + 1, #IDList do + IDList[i] = nil + end + local mode = checkMode(current) + if not mode then + source._id = false + return nil + end + util.revertTable(IDList) + local id = mode .. table.concat(IDList, SPLIT_CHAR) + source._id = id + return id +end + +---添加关联的前进ID +---@param noders noders +---@param id string +---@param forwardID string +local function pushForward(noders, id, forwardID, tag) + if not id + or not forwardID + or forwardID == '' + or id == forwardID then + return + end + local node = getNode(noders, id) + if not node.forward then + node.forward = {} + end + if node.forward[forwardID] ~= nil then + return + end + node.forward[forwardID] = tag or false + node.forward[#node.forward+1] = forwardID +end + +---添加关联的后退ID +---@param noders noders +---@param id string +---@param backwardID string +local function pushBackward(noders, id, backwardID, tag) + if not id + or not backwardID + or backwardID == '' + or id == backwardID then + return + end + local node = getNode(noders, id) + if not node.backward then + node.backward = {} + end + if node.backward[backwardID] ~= nil then + return + end + node.backward[backwardID] = tag or false + node.backward[#node.backward+1] = backwardID +end + +local m = {} + +m.SPLIT_CHAR = SPLIT_CHAR +m.RETURN_INDEX = RETURN_INDEX +m.PARAM_INDEX = PARAM_INDEX +m.TABLE_KEY = TABLE_KEY +m.ANY_FIELD = ANY_FIELD +m.URI_CHAR = URI_CHAR + +--- 寻找doc的主体 +---@param obj parser.guide.object +---@return parser.guide.object +local function getDocStateWithoutCrossFunction(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return obj + end + if parent.type == 'doc' then + return obj + end + if parent.type == 'doc.type.function' then + return nil + end + obj = parent + end + error('guide.getDocState overstack') +end + +---添加关联单元 +---@param noders noders +---@param source parser.guide.object +function m.pushSource(noders, source) + local id = m.getID(source) + if not id then + return + end + local node = getNode(noders, id) + if not node.sources then + node.sources = {} + end + node.sources[#node.sources+1] = source +end + +---@param noders noders +---@param source parser.guide.object +---@return parser.guide.object[] +function m.compileNode(noders, source) + local id = getID(source) + local value = source.value + if value then + local valueID = getID(value) + if valueID then + -- x = y : x -> y + pushForward(noders, id, valueID, 'set') + -- 参数禁止反向查找赋值 + if valueID:sub(1, 2) ~= 'p:' then + pushBackward(noders, valueID, id, 'set') + end + end + end + -- self -> mt:xx + if source.type == 'local' and source[1] == 'self' then + local func = guide.getParentFunction(source) + if func.isGeneric then + return + end + local setmethod = func.parent + -- guess `self` + if setmethod and ( setmethod.type == 'setmethod' + or setmethod.type == 'setfield' + or setmethod.type == 'setindex') then + pushForward(noders, id, getID(setmethod.node), 'method') + pushBackward(noders, getID(setmethod.node), id, 'method') + end + end + -- 分解 @type + if source.type == 'doc.type' then + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) + end + end + for _, enumUnit in ipairs(source.enums) do + pushForward(noders, id, getID(enumUnit)) + end + for _, resumeUnit in ipairs(source.resumes) do + pushForward(noders, id, getID(resumeUnit)) + end + for _, typeUnit in ipairs(source.types) do + local unitID = getID(typeUnit) + pushForward(noders, id, unitID) + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushBackward(noders, unitID, getID(src)) + end + end + end + end + -- 分解 @alias + if source.type == 'doc.alias' then + pushForward(noders, getID(source.alias), getID(source.extends)) + end + -- 分解 @class + if source.type == 'doc.class' then + pushForward(noders, id, getID(source.class)) + pushForward(noders, getID(source.class), id) + if source.extends then + for _, ext in ipairs(source.extends) do + pushBackward(noders, id, getID(ext)) + end + end + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) + end + end + for _, field in ipairs(source.fields) do + local key = field.field[1] + if key then + local keyID = ('%s%s%q'):format( + id, + SPLIT_CHAR, + key + ) + pushForward(noders, keyID, getID(field.field)) + pushForward(noders, getID(field.field), keyID) + pushForward(noders, keyID, getID(field.extends)) + pushBackward(noders, getID(field.extends), keyID) + end + end + end + if source.type == 'doc.param' then + pushForward(noders, id, getID(source.extends)) + for _, src in ipairs(source.bindSources) do + if src.type == 'local' and src.parent.type == 'in' then + pushForward(noders, getID(src), id) + end + end + end + if source.type == 'doc.vararg' then + pushForward(noders, getID(source), getID(source.vararg)) + end + if source.type == 'doc.see' then + local nameID = getID(source.name) + local classID = nameID:gsub('^dsn:', 'dn:') + pushForward(noders, nameID, classID) + if source.field then + local fieldID = getID(source.field) + local fieldClassID = fieldID:gsub('^dsn:', 'dn:') + pushForward(noders, fieldID, fieldClassID) + end + end + if source.type == 'call' then + local node = source.node + local nodeID = getID(node) + if not nodeID then + return + end + getNode(noders, id).call = source + -- 将 call 映射到 node#1 上 + local callID = ('%s%s%s'):format( + nodeID, + RETURN_INDEX, + 1 + ) + pushForward(noders, id, callID) + -- 将setmetatable映射到 param1 以及 param2.__index 上 + if node.special == 'setmetatable' then + local tblID = getID(source.args and source.args[1]) + local metaID = getID(source.args and source.args[2]) + local indexID + if metaID then + indexID = ('%s%s%q'):format( + metaID, + SPLIT_CHAR, + '__index' + ) + end + pushForward(noders, id, callID) + pushBackward(noders, callID, id) + pushForward(noders, callID, tblID) + pushForward(noders, callID, indexID) + pushBackward(noders, tblID, callID) + --pushBackward(noders, indexID, callID) + end + if node.special == 'require' then + local arg1 = source.args and source.args[1] + if arg1 and arg1.type == 'string' then + getNode(noders, callID).require = arg1[1] + end + end + end + if source.type == 'select' then + if source.vararg.type == 'call' then + local call = source.vararg + local node = call.node + local nodeID = getID(node) + if not nodeID then + return + end + -- 将call的返回值接收映射到函数返回值上 + local callXID = ('%s%s%s'):format( + nodeID, + RETURN_INDEX, + source.sindex + ) + pushForward(noders, id, callXID) + pushBackward(noders, callXID, id) + getNode(noders, id).call = call + if node.special == 'pcall' + or node.special == 'xpcall' then + local index = source.sindex - 1 + if index <= 0 then + return + end + local funcID = call.args and getID(call.args[1]) + if not funcID then + return + end + local funcXID = ('%s%s%s'):format( + funcID, + RETURN_INDEX, + index + ) + pushForward(noders, id, funcXID) + pushBackward(noders, funcXID, id) + end + end + if source.vararg.type == 'varargs' then + pushForward(noders, id, getID(source.vararg)) + end + end + if source.type == 'doc.type.function' then + if source.returns then + for index, rtn in ipairs(source.returns) do + local returnID = ('%s%s%s'):format( + id, + RETURN_INDEX, + index + ) + pushForward(noders, returnID, getID(rtn)) + end + end + -- @type fun(x: T):T 的情况 + local docType = getDocStateWithoutCrossFunction(source) + if docType and docType.type == 'doc.type' then + guide.eachSourceType(source, 'doc.type.name', function (typeName) + if typeName.typeGeneric then + source.isGeneric = true + return false + end + end) + end + end + if source.type == 'doc.type.table' then + if source.tkey then + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, getID(source.tkey)) + end + if source.tvalue then + local valueID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, valueID, getID(source.tvalue)) + end + end + if source.type == 'doc.type.array' then + if source.node then + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source.node)) + end + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, 'dn:integer') + end + -- 将函数的返回值映射到具体的返回值上 + if source.type == 'function' then + local hasDocReturn = {} + -- 检查 luadoc + if source.bindDocs then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + local fullID = ('%s%s%s'):format( + id, + RETURN_INDEX, + rtn.returnIndex + ) + pushForward(noders, fullID, getID(rtn)) + hasDocReturn[rtn.returnIndex] = true + end + end + if doc.type == 'doc.param' then + local paramName = doc.param[1] + if source.docParamMap then + local paramIndex = source.docParamMap[paramName] + local param = source.args[paramIndex] + if param then + pushForward(noders, getID(param), getID(doc)) + param.docParam = doc + end + end + end + if doc.type == 'doc.vararg' then + for _, param in ipairs(source.args) do + if param.type == '...' then + pushForward(noders, getID(param), getID(doc)) + end + end + end + if doc.type == 'doc.generic' then + source.isGeneric = true + end + if doc.type == 'doc.overload' then + pushForward(noders, id, getID(doc.overload)) + end + end + end + -- 检查实体返回值 + if source.returns then + local returns = {} + for _, rtn in ipairs(source.returns) do + for index, rtnObj in ipairs(rtn) do + if not hasDocReturn[index] then + if not returns[index] then + returns[index] = {} + end + returns[index][#returns[index]+1] = rtnObj + end + end + end + for index, rtnObjs in ipairs(returns) do + local returnID = ('%s%s%s'):format( + id, + RETURN_INDEX, + index + ) + for _, rtnObj in ipairs(rtnObjs) do + pushForward(noders, returnID, getID(rtnObj)) + if rtnObj.type == 'function' + or rtnObj.type == 'call' then + pushBackward(noders, getID(rtnObj), returnID) + end + end + end + end + end + if source.type == 'table' then + if #source == 1 and source[1].type == 'varargs' then + source.array = source[1] + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source[1])) + end + end + if source.type == 'main' then + if source.returns then + for _, rtn in ipairs(source.returns) do + local rtnObj = rtn[1] + if rtnObj then + pushForward(noders, 'mainreturn', getID(rtnObj)) + pushBackward(noders, getID(rtnObj), 'mainreturn') + end + end + end + end + if source.type == 'generic.closure' then + for i, rtn in ipairs(source.returns) do + local closureID = ('%s%s%s'):format( + id, + RETURN_INDEX, + i + ) + local returnID = getID(rtn) + pushForward(noders, closureID, returnID) + end + end + if source.type == 'generic.value' then + local proto = source.proto + local closure = source.closure + local upvalues = closure.upvalues + if proto.type == 'doc.type.name' then + local key = proto[1] + if upvalues[key] then + for _, paramID in ipairs(upvalues[key]) do + pushForward(noders, id, paramID) + pushBackward(noders, paramID, id) + end + end + end + if proto.type == 'doc.type' then + for _, tp in ipairs(source.types) do + pushForward(noders, id, getID(tp)) + pushBackward(noders, getID(tp), id) + end + end + if proto.type == 'doc.type.array' then + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source.node)) + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, 'dn:integer') + end + if proto.type == 'doc.type.table' then + if source.tkey then + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, getID(source.tkey)) + end + if source.tvalue then + local valueID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, valueID, getID(source.tvalue)) + end + end + end +end + +---根据ID来获取所有的node +---@param root parser.guide.object +---@param id string +---@return node? +function m.getNodeByID(root, id) + root = guide.getRoot(root) + local noders = root._noders + if not noders then + return nil + end + return noders[id] +end + +---根据ID来获取第一个节点的ID +---@param id string +---@return string +function m.getFirstID(id) + if FirstIDCache[id] then + return FirstIDCache[id] or nil + end + local firstID, count = id:match(FIRST_REGEX) + if count == 0 then + FirstIDCache[id] = false + return nil + end + FirstIDCache[id] = firstID + return firstID +end + +---根据ID来获取上个节点的ID +---@param id string +---@return string +function m.getLastID(id) + if LastIDCache[id] then + return LastIDCache[id] or nil + end + local lastID, count = id:gsub(LAST_REGEX, '') + if count == 0 then + LastIDCache[id] = false + return nil + end + LastIDCache[id] = lastID + return lastID +end + +---把形如 `@file:\\\XXXXX@gv:1|1`拆分成uri与id +---@param id string +---@return uri? string +---@return string id +function m.getUriAndID(id) + local uri, newID = id:match(URI_REGEX) + return uri, newID +end + +---获取source的ID +---@param source parser.guide.object +---@return string +function m.getID(source) + return getID(source) +end + +---获取source的key +---@param source parser.guide.object +---@return string +function m.getKey(source) + return getKey(source) +end + +---清除临时id(用于泛型的临时对象) +---@param root parser.guide.object +---@param id string +function m.removeID(root, id) + root = guide.getRoot(root) + local noders = root._noders + noders[id] = nil +end + +---寻找doc的主体 +---@param doc parser.guide.object +function m.getDocState(doc) + return getDocStateWithoutCrossFunction(doc) +end + +---获取对象的noders +---@param source parser.guide.object +---@return noders +function m.getNoders(source) + local root = guide.getRoot(source) + if not root._noders then + root._noders = {} + end + return root._noders +end + +---编译整个文件的node +---@param source parser.guide.object +---@return table +function m.compileNodes(source) + local root = guide.getRoot(source) + local noders = m.getNoders(source) + if next(noders) then + return noders + end + guide.eachSource(root, function (src) + m.pushSource(noders, src) + m.compileNode(noders, src) + end) + -- Special rule: ('').XX -> stringlib.XX + pushBackward(noders, 'str:', 'dn:stringlib') + pushBackward(noders, 'dn:string', 'dn:stringlib') + return noders +end + +return m diff --git a/script/core/reference.lua b/script/core/reference.lua index 7620b09e..c3f3b349 100644 --- a/script/core/reference.lua +++ b/script/core/reference.lua @@ -1,4 +1,5 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' @@ -6,8 +7,8 @@ local findSource = require 'core.find-source' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -19,7 +20,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else @@ -64,11 +65,23 @@ return function (uri, offset) local metaSource = vm.isMetaFile(uri) + local refs = vm.getRefs(source) + local values = {} + for _, src in ipairs(refs) do + local value = searcher.getObjectValue(src) + if value and value ~= src and guide.isLiteral(value) then + values[value] = true + end + end + local results = {} - for _, src in ipairs(vm.getRefs(source, 5)) do + for _, src in ipairs(refs) do if src.dummy then goto CONTINUE end + if values[src] then + goto CONTINUE + end local root = guide.getRoot(src) if not root then goto CONTINUE diff --git a/script/core/rename.lua b/script/core/rename.lua index da82b0a6..6b67d4be 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -1,11 +1,11 @@ local files = require 'files' local vm = require 'vm' -local guide = require 'core.guide' local proto = require 'proto' local define = require 'proto.define' local util = require 'utility' local findSource = require 'core.find-source' -local ws = require 'workspace' +local guide = require 'parser.guide' +local noder = require 'core.noder' local Forcing @@ -185,7 +185,7 @@ local function renameField(source, newname, callback) end callback(source, source.start, source.finish, newname) elseif parent.type == 'setmethod' then - local uri = guide.getUri(source) + local uri = guide.getUri(source) local text = files.getText(uri) local func = parent.value -- function mt:name () end --> mt['newname'] = function (self) end @@ -292,14 +292,14 @@ local function ofField(source, newname, callback) else node = source.node end - for _, src in ipairs(vm.getFields(node, 5)) do + for _, src in ipairs(vm.getRefs(node, '*')) do ofFieldThen(key, src, newname, callback) end end local function ofGlobal(source, newname, callback) local key = guide.getKeyName(source) - for _, src in ipairs(vm.getRefs(source, 0)) do + for _, src in ipairs(vm.getRefs(source)) do ofFieldThen(key, src, newname, callback) end end @@ -308,24 +308,27 @@ local function ofLabel(source, newname, callback) if not isValidName(newname) and not askForcing(newname)then return false end - for _, src in ipairs(vm.getRefs(source, 0)) do + for _, src in ipairs(vm.getRefs(source)) do callback(src, src.start, src.finish, newname) end end local function ofDocTypeName(source, newname, callback) - for _, doc in ipairs(vm.getDocTypes(source[1])) do + local oldname = source[1] + for _, doc in ipairs(vm.getRefs(source)) do if doc.type == 'doc.class.name' or doc.type == 'doc.type.name' or doc.type == 'doc.alias.name' then - callback(doc, doc.start, doc.finish, newname) + if oldname == doc[1] then + callback(doc, doc.start, doc.finish, newname) + end end end end local function ofDocParamName(source, newname, callback) callback(source, source.start, source.finish, newname) - local doc = guide.getDocState(source) + local doc = noder.getDocState(source) if doc.bindSources then for _, src in ipairs(doc.bindSources) do if src.type == 'local' diff --git a/script/core/searcher.lua b/script/core/searcher.lua new file mode 100644 index 00000000..11e00378 --- /dev/null +++ b/script/core/searcher.lua @@ -0,0 +1,728 @@ +local noder = require 'core.noder' +local guide = require 'parser.guide' +local files = require 'files' +local generic = require 'core.generic' +local ws = require 'workspace' +local vm = require 'vm.vm' + +local NONE = {'NONE'} +local LAST = {'LAST'} + +local ignoredIDs = { + ['dn:unknown'] = true, + ['dn:nil'] = true, + ['dn:any'] = true, + ['dn:boolean'] = true, + ['dn:string'] = true, + ['dn:table'] = true, + ['dn:number'] = true, + ['dn:integer'] = true, + ['dn:userdata'] = true, + ['dn:lightuserdata'] = true, + ['dn:function'] = true, + ['dn:thread'] = true, +} + +local m = {} + +---@alias guide.searchmode '"ref"'|'"def"' + +---添加结果 +---@param status guide.status +---@param mode guide.searchmode +---@param source parser.guide.object +---@param force boolean +function m.pushResult(status, mode, source, force) + if not source then + return + end + local results = status.results + if results[source] then + return + end + results[source] = true + if force then + results[#results+1] = source + return + end + local parent = source.parent + if mode == 'def' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'label' + or source.type == 'setfield' + or source.type == 'setmethod' + or source.type == 'setindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if noder.getID(source) ~= status.id then + results[#results+1] = source + end + end + elseif mode == 'ref' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'label' + or source.type == 'goto' + or source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' + or source.type == 'setindex' + or source.type == 'getindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' + or source.type == 'nil' + or source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if noder.getID(source) ~= status.id then + results[#results+1] = source + end + end + end +end + +---获取uri +---@param obj parser.guide.object +---@return uri +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = guide.getRoot(obj) + if root then + return root.uri + end + return '' +end + +---@param obj parser.guide.object +---@return parser.guide.object? +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent and obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args and obj.args[3] + else + return obj + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +local function crossSearch(status, uri, expect, mode) + m.searchRefsByID(status, uri, expect, mode) +end + +local function getLock(status, uri, expect, mode) + local slock = status.lock + local ulock = slock[uri] + if not ulock then + ulock = {} + slock[uri] = ulock + end + local mlock = ulock[mode] + if not mlock then + mlock = {} + ulock[mode] = mlock + end + if mlock[expect] then + return false + end + mlock[expect] = true + return true +end + +function m.searchRefsByID(status, uri, expect, mode) + local ast = files.getAst(uri) + if not ast then + return + end + if not getLock(status, uri, expect, mode) then + return + end + local root = ast.ast + local searchStep + noder.compileNodes(root) + + status.id = expect + + local callStack = status.callStack + + local mark = {} + + local function search(id, field) + local firstID = noder.getFirstID(id) + if ignoredIDs[firstID] and (field or firstID ~= id) then + return + end + local cmark = mark[id] + if not cmark then + cmark = {} + mark[id] = cmark + end + log.debug('search:', id, field) + if field then + if cmark[field] then + return + end + cmark[field] = true + searchStep(id, field) + cmark[field] = nil + else + if cmark[NONE] then + return + end + cmark[NONE] = true + searchStep(id, nil) + cmark[NONE] = nil + end + log.debug('pop:', id, field) + end + + local function checkLastID(id, field) + local cmark = mark[id] + if not cmark then + cmark = {} + mark[id] = cmark + end + if cmark[LAST] then + return + end + local lastID = noder.getLastID(id) + if not lastID then + return + end + local newField = id:sub(#lastID + 1) + if field then + newField = newField .. field + end + cmark[LAST] = true + search(lastID, newField) + cmark[LAST] = nil + return lastID + end + + local function searchID(id, field) + if not id then + return + end + if field then + id = id .. field + end + search(id, nil) + end + + local function isCallID(field) + if not field then + return false + end + if field:sub(1, 2) == noder.RETURN_INDEX then + return true + end + return false + end + + local function findLastCall() + for i = #callStack, 1, -1 do + local call = callStack[i] + if call then + -- 标记此处的call失效,等待在堆栈平衡时弹出 + callStack[i] = false + return call + end + end + return nil + end + + local genericCallArgs = {} + local closureCache = {} + local function checkGeneric(source, field) + if not source.isGeneric then + return + end + if not isCallID(field) then + return + end + local call = findLastCall() + if not call then + return + end + + if call.args then + for _, arg in ipairs(call.args) do + genericCallArgs[arg] = true + end + end + + local cacheID = noder.getID(source) .. noder.getID(call) + local closure = closureCache[cacheID] + if closure == false then + return + end + if not closure then + closure = generic.createClosure(source, call) + closureCache[cacheID] = closure or false + if not closure then + return + end + end + local id = noder.getID(closure) + searchID(id, field) + end + + local function checkENV(source, field) + if not field then + return + end + if source.special ~= '_G' then + return + end + local newID = 'g:' .. field:sub(2) + searchID(newID) + end + + local forwardTag = {} + local backwardTag = {} + local function checkForward(id, node, field) + for _, forwardID in ipairs(node.forward) do + local tag = node.forward[forwardID] + if tag then + if backwardTag[tag] and backwardTag[tag] > 0 then + goto CONTINUE + end + forwardTag[tag] = (forwardTag[tag] or 0) + 1 + end + local targetUri, targetID = noder.getUriAndID(forwardID) + if targetUri and not files.eq(targetUri, uri) then + crossSearch(status, targetUri, targetID .. (field or ''), mode) + else + searchID(targetID or forwardID, field) + end + if tag then + forwardTag[tag] = forwardTag[tag] - 1 + end + ::CONTINUE:: + end + end + + local function checkBackward(id, node, field) + if mode ~= 'ref' and not field then + return + end + for _, backwardID in ipairs(node.backward) do + local tag = node.backward[backwardID] + if tag then + if forwardTag[tag] and forwardTag[tag] > 0 then + goto CONTINUE + end + backwardTag[tag] = (backwardTag[tag] or 0) + 1 + end + local targetUri, targetID = noder.getUriAndID(backwardID) + if targetUri and not files.eq(targetUri, uri) then + crossSearch(status, targetUri, targetID .. (field or ''), mode) + else + searchID(targetID or backwardID, field) + end + if tag then + backwardTag[tag] = backwardTag[tag] - 1 + end + ::CONTINUE:: + end + end + + local function checkRequire(requireName, field) + local tid = 'mainreturn' .. (field or '') + local uris = ws.findUrisByRequirePath(requireName) + for _, ruri in ipairs(uris) do + if not files.eq(uri, ruri) then + crossSearch(status, ruri, tid, mode) + end + end + end + + local function checkGlobal(id, node, field) + if id:sub(1, 2) ~= 'g:' then + return + end + local firstID = noder.getFirstID(id) + if status.crossed[firstID] then + return + end + status.crossed[firstID] = true + local tid = id .. (field or '') + for guri in files.eachFile() do + if not files.eq(uri, guri) then + crossSearch(status, guri, tid, mode) + end + end + end + + local function checkClass(id, node, field) + if id:sub(1, 3) ~= 'dn:' then + return + end + local firstID = noder.getFirstID(id) + if status.crossed[firstID] then + return + end + status.crossed[firstID] = true + local tid = id .. (field or '') + for guri in files.eachFile() do + if not files.eq(uri, guri) then + crossSearch(status, guri, tid, mode) + end + end + end + + local function checkMainReturn(id, node, field) + if id ~= 'mainreturn' then + return + end + if mode ~= 'ref' and not field then + return + end + local calls = vm.getLinksTo(uri) + for _, call in ipairs(calls) do + local turi = guide.getUri(call) + if not files.eq(turi, uri) then + local tid = noder.getID(call) .. (field or '') + crossSearch(status, turi, tid, mode) + end + end + end + + local function searchNode(id, node, field) + if node.call then + callStack[#callStack+1] = node.call + end + if field == nil and node.sources then + for _, source in ipairs(node.sources) do + local force = genericCallArgs[source] + m.pushResult(status, mode, source, force) + end + end + if node.forward then + checkForward(id, node, field) + end + if node.backward then + checkBackward(id, node, field) + end + + if node.sources then + checkGeneric(node.sources[1], field) + checkENV(node.sources[1], field) + end + + if node.require then + checkRequire(node.require, field) + end + + checkMainReturn(id, node, field) + + if node.call then + callStack[#callStack] = nil + end + end + + local function checkCrossUri(id, field) + local targetUri, newID = noder.getUriAndID(id) + if not targetUri then + return false + end + crossSearch(status, targetUri, newID .. (field or ''), mode) + return true + end + + local stepCount = 0 + function searchStep(id, field) + stepCount = stepCount + 1 + if stepCount > 1000 then + error('too large') + end + local node = noder.getNodeByID(root, id) + if node then + searchNode(id, node, field) + end + checkGlobal(id, node, field) + checkClass(id, node, field) + local lastID = checkLastID(id, field) + if not lastID then + return + end + local originField = id:sub(#lastID + 1) + if originField == noder.TABLE_KEY then + return + end + local anyFieldID = lastID .. noder.ANY_FIELD + local anyFieldNode = noder.getNodeByID(root, anyFieldID) + if anyFieldNode then + searchNode(anyFieldID, anyFieldNode, field) + end + end + + search(expect) + + --清除来自泛型的临时对象 + for _, closure in pairs(closureCache) do + noder.removeID(root, noder.getID(closure)) + if closure then + for _, value in ipairs(closure.values) do + noder.removeID(root, noder.getID(value)) + end + end + end +end + +local function prepareSearch(source) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local root = guide.getRoot(source) + noder.compileNodes(root) + local uri = guide.getUri(source) + local id = noder.getID(source) + return uri, id +end + +local function getField(status, source, mode) + if source.type == 'table' then + for _, field in ipairs(source) do + m.pushResult(status, mode, field) + end + return + end + if source.type == 'doc.class.name' then + local class = source.parent + for _, field in ipairs(class.fields) do + m.pushResult(status, mode, field.field) + end + return + end + local field = source.next + if field then + if field.type == 'getmethod' + or field.type == 'setmethod' + or field.type == 'getfield' + or field.type == 'setfield' + or field.type == 'getindex' + or field.type == 'setindex' then + m.pushResult(status, mode, field) + end + return + end +end + +local function searchAllGlobalByUri(status, mode, uri, fullID) + local ast = files.getAst(uri) + if not ast then + return + end + local root = ast.ast + noder.compileNodes(root) + local noders = noder.getNoders(root) + if fullID then + for id, node in pairs(noders) do + if node.sources + and id == fullID then + for _, source in ipairs(node.sources) do + m.pushResult(status, mode, source) + end + end + end + else + for id, node in pairs(noders) do + if node.sources + and id:sub(1, 2) == 'g:' + and not id:find(noder.SPLIT_CHAR) then + for _, source in ipairs(node.sources) do + m.pushResult(status, mode, source) + end + end + end + end +end + +local function searchAllGlobals(status, mode, fullID) + for uri in files.eachFile() do + searchAllGlobalByUri(status, mode, uri, fullID) + end +end + +---搜索对象的引用 +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +function m.searchRefs(status, source, mode) + local uri, id = prepareSearch(source) + if not id then + return + end + log.debug('searchRefs:', id) + m.searchRefsByID(status, uri, id, mode) +end + +function m.findGlobals(uri, mode, name) + local status = m.status() + + if name then + local fullID = ('g:%q'):format(name) + searchAllGlobalByUri(status, mode, uri, fullID) + else + searchAllGlobalByUri(status, mode, uri) + end + + return status.results +end + +---搜索对象的field +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +---@param field string +function m.searchFields(status, source, mode, field) + local uri, id = prepareSearch(source) + if not id then + return + end + log.debug('searchFields:', id, field) + if field == '*' then + if source.special == '_G' then + searchAllGlobals(status, mode) + else + local newStatus = m.status(status) + m.searchRefsByID(newStatus, uri, id, mode) + for _, def in ipairs(newStatus.results) do + getField(status, def, mode) + end + end + else + if source.special == '_G' then + local fullID = ('g:%q'):format(field) + m.searchRefsByID(status, uri, fullID, mode) + else + local fullID = ('%s%s%q'):format(id, noder.SPLIT_CHAR, field) + m.searchRefsByID(status, uri, fullID, mode) + end + end +end + +---@class guide.status +---搜索结果 +---@field results parser.guide.object[] + +---创建搜索状态 +---@param parentStatus guide.status +---@return guide.status +function m.status(parentStatus) + local status = { + --mark = parentStatus and parentStatus.mark or {}, + callStack = {}, + crossed = {}, + lock = {}, + results = {}, + } + return status +end + +--- 请求对象的引用 +---@param obj parser.guide.object +---@param field? string +---@return parser.guide.object[] +function m.requestReference(obj, field) + local status = m.status() + + if field then + m.searchFields(status, obj, 'ref', field) + else + m.searchRefs(status, obj, 'ref') + end + + return status.results +end + +--- 请求对象的定义 +---@param obj parser.guide.object +---@param field? string +---@return parser.guide.object[] +function m.requestDefinition(obj, field) + local status = m.status() + + if field then + m.searchFields(status, obj, 'def', field) + else + m.searchRefs(status, obj, 'def') + end + + return status.results +end + +return m diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index f8feaa09..5e9ee9b1 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local await = require 'await' local define = require 'proto.define' local vm = require 'vm' @@ -221,7 +221,7 @@ return function (uri, start, finish) local results = {} local count = 0 - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) local method = Care[source.type] if not method then return diff --git a/script/core/signature.lua b/script/core/signature.lua index eb740784..915310c0 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -1,8 +1,9 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local hoverLabel = require 'core.hover.label' local hoverDesc = require 'core.hover.description' +local guide = require 'parser.guide' local function findNearCall(uri, ast, pos) local text = files.getText(uri) @@ -96,10 +97,10 @@ local function makeSignatures(call, pos) index = 1 end local signs = {} - local defs = vm.getDefs(node, 0) + local defs = vm.getDefs(node) local mark = {} for _, src in ipairs(defs) do - src = guide.getObjectValue(src) or src + src = searcher.getObjectValue(src) or src if src.type == 'function' or src.type == 'doc.type.function' then if not mark[src] then diff --git a/script/core/type-formatting.lua b/script/core/type-formatting.lua index c2290ef3..49a721e5 100644 --- a/script/core/type-formatting.lua +++ b/script/core/type-formatting.lua @@ -1,6 +1,6 @@ local files = require 'files' local lookBackward = require 'core.look-backward' -local guide = require 'core.guide' +local guide = require "parser.guide" local function insertIndentation(uri, offset, edits) local lines = files.getLines(uri) diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua index ae420d32..2df23a4d 100644 --- a/script/core/workspace-symbol.lua +++ b/script/core/workspace-symbol.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local matchKey = require 'core.matchkey' local define = require 'proto.define' local await = require 'await' @@ -52,7 +52,7 @@ local function searchFile(uri, key, results) return end - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) buildSource(uri, source, key, results) end) end diff --git a/script/files.lua b/script/files.lua index 9cc6b549..bb143250 100644 --- a/script/files.lua +++ b/script/files.lua @@ -9,7 +9,7 @@ local await = require 'await' local timer = require 'timer' local plugin = require 'plugin' local util = require 'utility' -local guide = require 'core.guide' +local guide = require 'parser.guide' local smerger = require 'string-merger' local progress = require "progress" @@ -345,6 +345,7 @@ function m.getAllUris() i = i + 1 files[i] = uri end + table.sort(files) end return m._pairsCache end diff --git a/script/parser/ast.lua b/script/parser/ast.lua index 45d77631..40b5788e 100644 --- a/script/parser/ast.lua +++ b/script/parser/ast.lua @@ -110,7 +110,7 @@ local function getSelect(vararg, index) start = vararg.start, finish = vararg.finish, vararg = vararg, - index = index, + sindex = index, } end @@ -1460,8 +1460,14 @@ local Defs = { local values if func then local call = createCall(exp, func.finish + 1, exp.finish) + if #exp == 0 then + exp[1] = getSelect(func, 2) + exp[2] = getSelect(func, 3) + exp[3] = getSelect(func, 4) + end call.node = func - call.start = func.start + call.start = inA + call.finish = doB - 1 func.next = call func.iterator = true values = { call } diff --git a/script/parser/compile.lua b/script/parser/compile.lua index a7e0dc1f..21be406d 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -125,6 +125,7 @@ local vmMap = { vararg.ref = {} end vararg.ref[#vararg.ref+1] = obj + obj.node = vararg end end end, @@ -150,8 +151,8 @@ local vmMap = { local value = obj.value local localself = { type = 'local', - start = 0, - finish = 0, + start = value.start, + finish = value.finish, method = obj, effect = obj.finish, tag = 'self', diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 2369e84f..8d2708cf 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -1,34 +1,8 @@ -local util = require 'utility' local error = error local type = type -local next = next -local tostring = tostring -local print = print -local ipairs = ipairs -local tableInsert = table.insert -local tableUnpack = table.unpack -local tableRemove = table.remove -local tableMove = table.move -local tableSort = table.sort -local tableConcat = table.concat -local mathType = math.type -local pairs = pairs -local setmetatable = setmetatable -local assert = assert -local select = select -local osClock = os.clock -local tonumber = tonumber -local tointeger = math.tointeger -local DEVELOP = _G.DEVELOP -local log = log -local _G = _G ---@class parser.guide.object -local function logWarn(...) - log.warn(...) -end - ---@class guide ---@field debugMode boolean local m = {} @@ -91,7 +65,7 @@ m.childMap = { ['doc'] = {'#'}, ['doc.class'] = {'class', '#extends', 'comment'}, - ['doc.type'] = {'#types', '#enums', 'name', 'comment'}, + ['doc.type'] = {'#types', '#enums', '#resumes', 'name', 'comment'}, ['doc.alias'] = {'alias', 'extends', 'comment'}, ['doc.param'] = {'param', 'extends', 'comment'}, ['doc.return'] = {'#returns', 'comment'}, @@ -100,9 +74,9 @@ m.childMap = { ['doc.generic.object'] = {'generic', 'extends', 'comment'}, ['doc.vararg'] = {'vararg', 'comment'}, ['doc.type.array'] = {'node'}, - ['doc.type.table'] = {'node', 'key', 'value', 'comment'}, + ['doc.type.table'] = {'tkey', 'tvalue', 'comment'}, ['doc.type.function'] = {'#args', '#returns', 'comment'}, - ['doc.type.typeliteral'] = {'node'}, + ['doc.type.literal'] = {'node'}, ['doc.type.arg'] = {'extends'}, ['doc.overload'] = {'overload', 'comment'}, ['doc.see'] = {'name', 'field'}, @@ -123,19 +97,31 @@ m.actionMap = { ['funcargs'] = {'#'}, } -local TypeSort = { - ['boolean'] = 1, - ['string'] = 2, - ['integer'] = 3, - ['number'] = 4, - ['table'] = 5, - ['function'] = 6, - ['true'] = 101, - ['false'] = 102, - ['nil'] = 999, -} +local inf = 1 / 0 +local nan = 0 / 0 + +local function isInteger(n) + if math.type then + return math.type(n) == 'integer' + else + return type(n) == 'number' and n % 1 == 0 + end +end -local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) +local function formatNumber(n) + if n == inf + or n == -inf + or n == nan + or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则 + return ('%q'):format(n) + end + if isInteger(n) then + return tostring(n) + end + local str = ('%.10f'):format(n) + str = str:gsub('%.?0*$', '') + return str +end --- 是否是字面量 ---@param obj parser.guide.object @@ -182,23 +168,6 @@ function m.getParentFunction(obj) return nil end ---- 寻找父的table类型 doc.type.table ----@param obj parser.guide.object ----@return parser.guide.object -function m.getParentDocTypeTable(obj) - for _ = 1, 1000 do - local parent = obj.parent - if not parent then - return nil - end - if parent.type == 'doc.type.table' then - return obj - end - obj = parent - end - error('guide.getParentDocTypeTable overstack') -end - --- 寻找所在区块 ---@param obj parser.guide.object ---@return parser.guide.object @@ -293,10 +262,19 @@ end ---@param obj parser.guide.object ---@return parser.guide.object function m.getRoot(obj) + local source = obj + if source._root then + return source._root + end for _ = 1, 1000 do if obj.type == 'main' then + source._root = obj return obj end + if obj._root then + source._root = obj._root + return source._root + end local parent = obj.parent if not parent then return nil @@ -501,8 +479,8 @@ function m.addChilds(list, obj, map) for i = 1, #keys do local key = keys[i] if key == '#' then - for i = 1, #obj do - list[#list+1] = obj[i] + for j = 1, #obj do + list[#list+1] = obj[j] end elseif obj[key] then list[#list+1] = obj[key] @@ -510,8 +488,8 @@ function m.addChilds(list, obj, map) and key:sub(1, 1) == '#' then key = key:sub(2) if obj[key] then - for i = 1, #obj[key] do - list[#list+1] = obj[key][i] + for j = 1, #obj[key] do + list[#list+1] = obj[key][j] end end end @@ -613,9 +591,16 @@ function m.eachSource(ast, callback) index = index + 1 if not mark[obj] then mark[obj] = true - callback(obj) + local res = callback(obj) + if res == true then + goto CONTINUE + end + if res == false then + return + end m.addChilds(list, obj, m.childMap) end + ::CONTINUE:: end end @@ -718,4 +703,288 @@ function m.lineData(lines, row) return lines[row] end +function m.isSet(source) + local tp = source.type + if tp == 'setglobal' + or tp == 'local' + or tp == 'setlocal' + or tp == 'setfield' + or tp == 'setmethod' + or tp == 'setindex' + or tp == 'tablefield' + or tp == 'tableindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawset' then + return true + end + end + return false +end + +function m.isGet(source) + local tp = source.type + if tp == 'getglobal' + or tp == 'getlocal' + or tp == 'getfield' + or tp == 'getmethod' + or tp == 'getindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawget' then + return true + end + end + return false +end + +function m.getSpecial(source) + if not source then + return nil + end + return source.special +end + +function m.getKeyNameOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return obj[1] + elseif tp == 'string' then + local s = obj[1] + if s then + return s + end + elseif tp == 'number' then + local n = obj[1] + if n then + return ('%s'):format(formatNumber(obj[1])) + end + elseif tp == 'boolean' then + local b = obj[1] + if b then + return tostring(b) + end + end +end + +function m.getKeyName(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return obj[1] + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return obj[1] + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + if obj.field then + return obj.field[1] + end + elseif tp == 'getmethod' + or tp == 'setmethod' then + if obj.method then + return obj.method[1] + end + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyNameOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' + or tp == 'doc.see.field' then + return obj[1] + elseif tp == 'doc.class' then + return obj.class[1] + elseif tp == 'doc.alias' then + return obj.alias[1] + elseif tp == 'doc.field' then + return obj.field[1] + elseif tp == 'doc.field.name' then + return obj[1] + elseif tp == 'dummy' then + return obj[1] + end + return m.getKeyNameOfLiteral(obj) +end + +function m.getKeyTypeOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return 'string' + elseif tp == 'string' then + return 'string' + elseif tp == 'number' then + return 'number' + elseif tp == 'boolean' then + return 'boolean' + end +end + +function m.getKeyType(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return 'string' + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return 'local' + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + return 'string' + elseif tp == 'getmethod' + or tp == 'setmethod' then + return 'string' + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyTypeOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' + or tp == 'doc.see.field' then + return 'string' + elseif tp == 'doc.class' then + return 'string' + elseif tp == 'doc.alias' then + return 'string' + elseif tp == 'doc.field' then + return 'string' + elseif tp == 'dummy' then + return 'string' + end + if tp == 'doc.field.name' then + return 'string' + end + return m.getKeyTypeOfLiteral(obj) +end + +--- 测试 a 到 b 的路径(不经过函数,不考虑 goto), +--- 每个路径是一个 block 。 +--- +--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>` +--- +--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>` +--- +--- 否则返回 `false` +--- +--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。 +---@param a table +---@param b table +---@return string|boolean mode +---@return table pathA? +---@return table pathB? +function m.getPath(a, b, sameFunction) + --- 首先测试双方在同一个函数内 + if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then + return false + end + local mode + local objA + local objB + if a.finish < b.start then + mode = 'before' + objA = a + objB = b + elseif a.start > b.finish then + mode = 'after' + objA = b + objB = a + else + return 'equal', {}, {} + end + local pathA = {} + local pathB = {} + for _ = 1, 1000 do + objA = m.getParentBlock(objA) + pathA[#pathA+1] = objA + if (not sameFunction and objA.type == 'function') or objA.type == 'main' then + break + end + end + for _ = 1, 1000 do + objB = m.getParentBlock(objB) + pathB[#pathB+1] = objB + if (not sameFunction and objA.type == 'function') or objB.type == 'main' then + break + end + end + -- pathA: {1, 2, 3, 4, 5} + -- pathB: {5, 6, 2, 3} + local top = #pathB + local start + for i = #pathA, 1, -1 do + local currentBlock = pathA[i] + if currentBlock == pathB[top] then + start = i + break + end + end + if not start then + return nil + end + -- pathA: { 1, 2, 3} + -- pathB: {5, 6, 2, 3} + local extra = 0 + local align = top - start + for i = start, 1, -1 do + local currentA = pathA[i] + local currentB = pathB[i+align] + if currentA ~= currentB then + extra = i + break + end + end + -- pathA: {1} + local resultA = {} + for i = extra, 1, -1 do + resultA[#resultA+1] = pathA[i] + end + -- pathB: {5, 6} + local resultB = {} + for i = extra + align, 1, -1 do + resultB[#resultB+1] = pathB[i] + end + return mode, resultA, resultB +end + +---是否是全局变量(包括 _G.XXX 形式) +---@param source parser.guide.object +---@return boolean +function m.isGlobal(source) + if source.type == 'setglobal' + or source.type == 'getglobal' then + if source.node and source.node.tag == '_ENV' then + return true + end + end + if source.type == 'field' then + source = source.parent + end + if source.special == '_G' then + return true + end + return false +end + return m diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index ae8e3f34..335c8f24 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -1,7 +1,7 @@ local m = require 'lpeglabel' local re = require 'parser.relabel' local lines = require 'parser.lines' -local guide = require 'core.guide' +local guide = require 'parser.guide' local grammar = require 'parser.grammar' local TokenTypes, TokenStarts, TokenFinishs, TokenContents @@ -194,6 +194,7 @@ local function parseClass(parent) local result = { type = 'doc.class', parent = parent, + fields = {}, } result.class = parseName('doc.class.name', result) if not result.class then @@ -300,8 +301,8 @@ local function parseTypeUnitTable(parent, node) node.parent = result; result.finish = getFinish() - result.key = key - result.value = value + result.tkey = key + result.tvalue = value return result end @@ -425,9 +426,10 @@ local function parseTypeUnit(parent, content) return result end -local function parseResume() +local function parseResume(parent) local result = { - type = 'doc.resume' + type = 'doc.resume', + parent = parent, } if checkToken('symbol', '>', 1) then @@ -456,7 +458,6 @@ local function parseResume() return result end -local LastType function parseType(parent) local result = { type = 'doc.type', @@ -484,13 +485,7 @@ function parseType(parent) break end -- TypeLiteral,指代类型的字面值。比如,对于类 Cat 来说,它的 TypeLiteral 是 "Cat" - typeLiteral = { - type = 'doc.type.typeliteral', - parent = result, - start = getStart(), - finish = nil, - node = nil, - } + typeLiteral = true end if tp == 'name' then @@ -501,10 +496,7 @@ function parseType(parent) end if typeLiteral then nextToken() - typeLiteral.finish = getFinish() - typeLiteral.node = typeUnit - typeUnit.parent = typeLiteral - typeUnit = typeLiteral + typeUnit.literal = true end result.types[#result.types+1] = typeUnit if not result.start then @@ -566,7 +558,7 @@ function parseType(parent) row = row + i + 1 local finishPos = nextComm.text:find('#', 3) or #nextComm.text parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1) - local resume = parseResume() + local resume = parseResume(result) if resume then if comments then resume.comment = table.concat(comments, '\n') @@ -1122,17 +1114,25 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish) end local src = sources[index] if src.start < start then - left = index + left = index + 1 else right = index end end - for i = index - 1, max do + + -- 从前往后进行绑定 + for i = index, max do local src = sources[i] if src then if src.start > finish then break end + -- 遇到table后中断,处理以下情况: + -- ---@type AAA + -- local t = {x = 1, y = 2} + if src.type == 'table' then + break + end if src.start >= start then src.bindDocs = binded bindSources[#bindSources+1] = src @@ -1152,21 +1152,22 @@ local function bindParamAndReturnIndex(binded) if not func then return end - if not func.args then - return - end - local paramIndex = 0 - local paramMap = {} - for _, param in ipairs(func.args) do - paramIndex = paramIndex + 1 - if param[1] then - paramMap[param[1]] = paramIndex + local paramMap + if func.args then + local paramIndex = 0 + paramMap = {} + for _, param in ipairs(func.args) do + paramIndex = paramIndex + 1 + if param[1] then + paramMap[param[1]] = paramIndex + end end + func.docParamMap = paramMap end local returnIndex = 0 for _, doc in ipairs(binded) do if doc.type == 'doc.param' then - if doc.extends then + if paramMap and doc.extends then doc.extends.paramIndex = paramMap[doc.param[1]] end elseif doc.type == 'doc.return' then @@ -1178,6 +1179,24 @@ local function bindParamAndReturnIndex(binded) end end +local function bindClassAndFields(binded) + local class + for _, doc in ipairs(binded) do + if doc.type == 'doc.class' then + -- 多个class连续写在一起,只有最后一个class可以绑定source + if class then + class.bindSources = nil + end + class = doc + elseif doc.type == 'doc.field' then + if class then + class.fields[#class.fields+1] = doc + doc.class = class + end + end + end +end + local function bindDoc(sources, lns, binded) if not binded then return @@ -1200,6 +1219,7 @@ local function bindDoc(sources, lns, binded) bindDocsBetween(sources, binded, bindSources, nstart, nfinish) end bindParamAndReturnIndex(binded) + bindClassAndFields(binded) end local function bindDocs(state) @@ -1214,6 +1234,7 @@ local function bindDocs(state) or src.type == 'tablefield' or src.type == 'tableindex' or src.type == 'function' + or src.type == 'table' or src.type == '...' then sources[#sources+1] = src end diff --git a/script/vm/eachDef.lua b/script/vm/eachDef.lua index d72c8f01..6f7af295 100644 --- a/script/vm/eachDef.lua +++ b/script/vm/eachDef.lua @@ -1,49 +1,7 @@ ---@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local files = require 'files' -local util = require 'utility' -local await = require 'await' -local config = require 'config' +local vm = require 'vm.vm' +local searcher = require 'core.searcher' -local function getDefs(source, deep) - local results = {} - local lock = vm.lock('eachDef', source) - if not lock then - return results - end - - await.delay() - - deep = config.config.intelliSense.searchDepth + (deep or 0) - - local clock = os.clock() - local myResults, count = guide.requestDefinition(source, vm.interface, deep) - if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestDefinition', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) - end - vm.mergeResults(results, myResults) - - lock() - - return results -end - -function vm.getDefs(source, deep) - deep = deep or -999 - if guide.isGlobal(source) then - local key = guide.getKeyName(source) - if not key then - return {} - end - return vm.getGlobalSets(key) - else - local cache = vm.getCache('eachDef')[source] - if not cache or cache.deep < deep then - cache = getDefs(source, deep) - cache.deep = deep - vm.getCache('eachDef')[source] = cache - end - return cache - end +function vm.getDefs(source, field) + return searcher.requestDefinition(source, field) end diff --git a/script/vm/eachField.lua b/script/vm/eachField.lua deleted file mode 100644 index 59f35f0c..00000000 --- a/script/vm/eachField.lua +++ /dev/null @@ -1,109 +0,0 @@ ----@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local await = require 'await' -local config = require 'config' - -local function getFields(source, deep, filterKey) - local unlock = vm.lock('eachField', source) - if not unlock then - return {} - end - - while source.type == 'paren' do - source = source.exp - if not source then - return {} - end - end - deep = config.config.intelliSense.searchDepth + (deep or 0) - - await.delay() - local results = guide.requestFields(source, vm.interface, deep, filterKey) - - unlock() - return results -end - -local function getDefFields(source, deep, filterKey) - local unlock = vm.lock('eachDefField', source) - if not unlock then - return {} - end - - while source.type == 'paren' do - source = source.exp - if not source then - return {} - end - end - deep = config.config.intelliSense.searchDepth + (deep or 0) - - await.delay() - local results = guide.requestDefFields(source, vm.interface, deep, filterKey) - - unlock() - return results -end - -local function getFieldsBySource(source, deep, filterKey) - deep = deep or -999 - local cache = vm.getCache('eachField')[source] - if not cache or cache.deep < deep then - cache = getFields(source, deep, filterKey) - cache.deep = deep - if not filterKey then - vm.getCache('eachField')[source] = cache - end - end - return cache -end - -local function getDefFieldsBySource(source, deep, filterKey) - deep = deep or -999 - local cache = vm.getCache('eachDefField')[source] - if not cache or cache.deep < deep then - cache = getDefFields(source, deep, filterKey) - cache.deep = deep - if not filterKey then - vm.getCache('eachDefField')[source] = cache - end - end - return cache -end - -function vm.getFields(source, deep) - if source.special == '_G' then - return vm.getGlobals '*' - end - if guide.isGlobal(source) then - local name = guide.getKeyName(source) - if not name then - return {} - end - local cache = vm.getCache('eachFieldOfGlobal')[name] - or getFieldsBySource(source, deep) - vm.getCache('eachFieldOfGlobal')[name] = cache - return cache - else - return getFieldsBySource(source, deep) - end -end - -function vm.getDefFields(source, deep) - if source.special == '_G' then - return vm.getGlobalSets '*' - end - if guide.isGlobal(source) then - local name = guide.getKeyName(source) - if not name then - return {} - end - local cache = vm.getCache('eachDefFieldOfGlobal')[name] - or getDefFieldsBySource(source, deep) - vm.getCache('eachDefFieldOfGlobal')[name] = cache - return cache - else - return getDefFieldsBySource(source, deep) - end -end diff --git a/script/vm/eachRef.lua b/script/vm/eachRef.lua index 9d0f061c..5aca198e 100644 --- a/script/vm/eachRef.lua +++ b/script/vm/eachRef.lua @@ -1,48 +1,7 @@ ---@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local util = require 'utility' -local await = require 'await' -local config = require 'config' +local vm = require 'vm.vm' +local searcher = require 'core.searcher' -local function getRefs(source, deep) - local results = {} - local lock = vm.lock('eachRef', source) - if not lock then - return results - end - - await.delay() - - deep = config.config.intelliSense.searchDepth + (deep or 0) - - local clock = os.clock() - local myResults, count = guide.requestReference(source, vm.interface, deep) - if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestReference', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) - end - vm.mergeResults(results, myResults) - - lock() - - return results -end - -function vm.getRefs(source, deep) - deep = deep or -999 - if guide.isGlobal(source) then - local key = guide.getKeyName(source) - if not key then - return {} - end - return vm.getGlobals(key) - else - local cache = vm.getCache('eachRef')[source] - if not cache or cache.deep < deep then - cache = getRefs(source, deep) - cache.deep = deep - vm.getCache('eachRef')[source] = cache - end - return cache - end +function vm.getRefs(source, field) + return searcher.requestReference(source, field) end diff --git a/script/vm/getClass.lua b/script/vm/getClass.lua deleted file mode 100644 index 5c68e0bb..00000000 --- a/script/vm/getClass.lua +++ /dev/null @@ -1,64 +0,0 @@ ----@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' - -local function lookUpDocClass(source) - local infers = vm.getInfers(source, 0) - for _, infer in ipairs(infers) do - if infer.source.type == 'doc.class' - or infer.source.type == 'doc.type' then - return guide.viewInferType(infers) - end - end - return nil -end - -local function getClass(source, classes, depth, deep) - local docClass = lookUpDocClass(source) - if docClass then - classes[docClass] = true - return - end - if depth > 3 then - return - end - local value = guide.getObjectValue(source) or source - if not deep then - if value and value.type == 'string' then - classes[value[1]] = true - end - else - for _, src in ipairs(vm.getDefFields(value)) do - local key = vm.getKeyName(src) - if not key then - goto CONTINUE - end - local lkey = key:lower() - if lkey == 'type' - or lkey == '__name' - or lkey == 'name' - or lkey == 'class' then - local value = guide.getObjectValue(src) - if value and value.type == 'string' then - classes[value[1]] = true - end - end - ::CONTINUE:: - end - end - if next(classes) then - return - end - vm.eachMeta(source, function (mt) - getClass(mt, classes, depth + 1, deep) - end) -end - -function vm.getClass(source, deep) - local classes = {} - getClass(source, classes, 1, deep) - if not next(classes) then - return nil - end - return guide.mergeTypes(classes) -end diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua index cfa9326f..dbb8b4fd 100644 --- a/script/vm/getDocs.lua +++ b/script/vm/getDocs.lua @@ -1,148 +1,51 @@ -local files = require 'files' -local util = require 'utility' -local guide = require 'core.guide' +local files = require 'files' +local guide = require 'parser.guide' ---@type vm -local vm = require 'vm.vm' -local config = require 'config' +local vm = require 'vm.vm' +local config = require 'config' +local searcher = require 'core.searcher' -local function getTypesOfFile(uri) - local types = {} - local ast = files.getAst(uri) - if not ast or not ast.ast.docs then - return types - end - guide.eachSource(ast.ast.docs, function (src) - if src.type == 'doc.type.name' - or src.type == 'doc.class.name' - or src.type == 'doc.extends.name' - or src.type == 'doc.alias.name' then - if src.type == 'doc.type.name' then - if guide.getParentDocTypeTable(src) then - return - end +local function getDocDefinesInAst(results, root, name) + for _, doc in ipairs(root.docs) do + if doc.type == 'doc.class' then + if not name or name == doc.class[1] then + results[#results+1] = doc.class end - local name = src[1] - if name then - if not types[name] then - types[name] = {} - end - types[name][#types[name]+1] = src + elseif doc.type == 'doc.alias' then + if not name or name == doc.alias[1] then + results[#results+1] = doc.alias end end - end) - return types + end end -local function getDocTypes(name) +---获取class与alias +---@param name? string +---@return parser.guide.object[] +function vm.getDocDefines(name) local results = {} - if name == 'any' - or name == 'nil' then - return results - end for uri in files.eachFile() do - local cache = files.getCache(uri) - cache.classes = cache.classes or getTypesOfFile(uri) - if name == '*' then - for _, sources in util.sortPairs(cache.classes) do - for _, source in ipairs(sources) do - results[#results+1] = source - end - end - else - if cache.classes[name] then - for _, source in ipairs(cache.classes[name]) do - results[#results+1] = source - end - end - end + local ast = files.getAst(uri) + getDocDefinesInAst(results, ast.ast, name) end return results end -function vm.getDocEnums(doc, mark, results) +function vm.getDocEnums(doc) if not doc then return nil end - mark = mark or {} - if mark[doc] then - return nil - end - mark[doc] = true - results = results or {} - for _, enum in ipairs(doc.enums) do - results[#results+1] = enum - end - for _, resume in ipairs(doc.resumes) do - results[#results+1] = resume - end - for _, unit in ipairs(doc.types) do - if unit.type == 'doc.type.name' then - for _, other in ipairs(vm.getDocTypes(unit[1])) do - if other.type == 'doc.alias.name' then - vm.getDocEnums(other.parent.extends, mark, results) - end - end - end - end - return results -end + local defs = searcher.requestDefinition(doc) + local results = {} -function vm.getDocTypeUnits(doc, mark, results) - if not doc then - return nil - end - mark = mark or {} - if mark[doc] then - return nil - end - mark[doc] = true - results = results or {} - for _, enum in ipairs(doc.enums) do - results[#results+1] = enum - end - for _, resume in ipairs(doc.resumes) do - results[#results+1] = resume - end - for _, unit in ipairs(doc.types) do - if unit.type == 'doc.type.name' then - for _, other in ipairs(vm.getDocTypes(unit[1])) do - if other.type == 'doc.alias.name' then - vm.getDocTypeUnits(other.parent.extends, mark, results) - elseif other.type == 'doc.class.name' then - results[#results+1] = other - end - end - else - results[#results+1] = unit + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then + results[#results+1] = def end end - return results -end - -function vm.getDocTypes(name) - local cache = vm.getCache('getDocTypes')[name] - if cache ~= nil then - return cache - end - cache = getDocTypes(name) - vm.getCache('getDocTypes')[name] = cache - return cache -end -function vm.getDocClass(name) - local cache = vm.getCache('getDocClass')[name] - if cache ~= nil then - return cache - end - cache = {} - local results = getDocTypes(name) - for _, doc in ipairs(results) do - if doc.type == 'doc.class.name' then - cache[#cache+1] = doc - end - end - vm.getCache('getDocClass')[name] = cache - return cache + return results end function vm.isMetaFile(uri) @@ -224,7 +127,7 @@ end function vm.isDeprecated(value, deep) if deep then - local defs = vm.getDefs(value, 0) + local defs = vm.getDefs(value) if #defs == 0 then return false end diff --git a/script/vm/getGlobals.lua b/script/vm/getGlobals.lua index 2752ce09..bea192ef 100644 --- a/script/vm/getGlobals.lua +++ b/script/vm/getGlobals.lua @@ -1,5 +1,6 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local await = require "await" +local searcher = require "core.searcher" ---@type vm local vm = require 'vm.vm' local files = require 'files' @@ -17,12 +18,8 @@ local function getGlobalsOfFile(uri) end local globals = {} cache.globals = globals - local ast = files.getAst(uri) - if not ast then - return globals - end tracy.ZoneBeginN 'getGlobalsOfFile' - local results = guide.findGlobals(ast.ast) + local results = searcher.findGlobals(uri) local subscribe = ws.getCache 'globalSubscribe' subscribe[uri] = {} local mark = {} @@ -34,7 +31,7 @@ local function getGlobalsOfFile(uri) goto CONTINUE end mark[res] = true - local name = guide.getSimpleName(res) + local name = guide.getKeyName(res) if name then if not globals[name] then globals[name] = {} @@ -59,12 +56,8 @@ local function getGlobalSetsOfFile(uri) end local globals = {} cache.globalSets = globals - local ast = files.getAst(uri) - if not ast then - return globals - end tracy.ZoneBeginN 'getGlobalSetsOfFile' - local results = guide.findGlobals(ast.ast) + local results = searcher.findGlobals(uri, 'def') local subscribe = ws.getCache 'globalSetsSubscribe' subscribe[uri] = {} local mark = {} @@ -76,16 +69,14 @@ local function getGlobalSetsOfFile(uri) goto CONTINUE end mark[res] = true - if vm.isSet(res) then - local name = guide.getSimpleName(res) - if name then - if not globals[name] then - globals[name] = {} - subscribe[uri][#subscribe[uri]+1] = name - end - globals[name][#globals[name]+1] = res - globals['*'][#globals['*']+1] = res + local name = guide.getKeyName(res) + if name then + if not globals[name] then + globals[name] = {} + subscribe[uri][#subscribe[uri]+1] = name end + globals[name][#globals[name]+1] = res + globals['*'][#globals['*']+1] = res end ::CONTINUE:: end @@ -265,7 +256,7 @@ files.watch(function (ev, uri) end needUpdateGlobals[uri] = true elseif ev == 'create' then - getGlobalsOfFile(uri) - getGlobalSetsOfFile(uri) + --getGlobalsOfFile(uri) + --getGlobalSetsOfFile(uri) end end) diff --git a/script/vm/getInfer.lua b/script/vm/getInfer.lua deleted file mode 100644 index 5447ca23..00000000 --- a/script/vm/getInfer.lua +++ /dev/null @@ -1,104 +0,0 @@ ----@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local util = require 'utility' -local await = require 'await' -local config = require 'config' - -NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) - ---- 是否包含某种类型 -function vm.hasType(source, type, deep) - local defs = vm.getDefs(source, deep) - for i = 1, #defs do - local def = defs[i] - local value = guide.getObjectValue(def) or def - if value.type == type then - return true - end - end - return false -end - ---- 是否包含某种类型 -function vm.hasInferType(source, type, deep) - local infers = vm.getInfers(source, deep) - for i = 1, #infers do - local infer = infers[i] - if infer.type == type then - return true - end - end - return false -end - -function vm.getInferType(source, deep) - local infers = vm.getInfers(source, deep) - return guide.viewInferType(infers) -end - -function vm.getInferLiteral(source, deep) - local infers = vm.getInfers(source, deep) - local literals = {} - local mark = {} - for _, infer in ipairs(infers) do - local value = infer.value - if value and not mark[value] then - mark[value] = true - literals[#literals+1] = util.viewLiteral(value) - end - end - if #literals == 0 then - return nil - end - table.sort(literals) - return table.concat(literals, '|') -end - -local function getInfers(source, deep) - local results = {} - local lock = vm.lock('getInfers', source) - if not lock then - return results - end - - deep = config.config.intelliSense.searchDepth + (deep or 0) - - await.delay() - - local clock = os.clock() - local myResults, count = guide.requestInfer(source, vm.interface, deep) - if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestInfer', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) - end - vm.mergeResults(results, myResults) - - lock() - - return results -end - -local function getInfersBySource(source, deep) - deep = deep or -999 - local cache = vm.getCache('getInfers')[source] - if not cache or cache.deep < deep then - cache = getInfers(source, deep) - cache.deep = deep - vm.getCache('getInfers')[source] = cache - end - return cache -end - ---- 获取对象的值 ---- 会尝试穿透函数调用 -function vm.getInfers(source, deep) - if guide.isGlobal(source) then - local name = guide.getKeyName(source) - local cache = vm.getCache('getInfersOfGlobal')[name] - or getInfersBySource(source, deep) - vm.getCache('getInfersOfGlobal')[name] = cache - return cache - else - return getInfersBySource(source, deep) - end -end diff --git a/script/vm/getLibrary.lua b/script/vm/getLibrary.lua index b52f7240..a3c8feb0 100644 --- a/script/vm/getLibrary.lua +++ b/script/vm/getLibrary.lua @@ -1,8 +1,11 @@ ---@type vm local vm = require 'vm.vm' -function vm.getLibraryName(source, deep) - local defs = vm.getDefs(source, deep) +function vm.getLibraryName(source) + if source.special then + return source.special + end + local defs = vm.getDefs(source) for _, def in ipairs(defs) do if def.special then return def.special diff --git a/script/vm/getLinks.lua b/script/vm/getLinks.lua index 91a5f1a0..51a18d58 100644 --- a/script/vm/getLinks.lua +++ b/script/vm/getLinks.lua @@ -1,5 +1,4 @@ -local guide = require 'core.guide' ----@type vm +local guide = require 'parser.guide' local vm = require 'vm.vm' local files = require 'files' @@ -33,11 +32,17 @@ local function getFileLinks(uri) return links end +local function getFileLinksOrCache(uri) + local cache = files.getCache(uri) + cache.links = cache.links or getFileLinks(uri) + return cache.links +end + local function getLinksTo(uri) uri = files.asKey(uri) local links = {} for u in files.eachFile() do - local ls = vm.getFileLinks(u) + local ls = getFileLinksOrCache(u) if ls[uri] then for _, l in ipairs(ls[uri]) do links[#links+1] = l @@ -56,9 +61,3 @@ function vm.getLinksTo(uri) vm.getCache('getLinksTo')[uri] = cache return cache end - -function vm.getFileLinks(uri) - local cache = files.getCache(uri) - cache.links = cache.links or getFileLinks(uri) - return cache.links -end diff --git a/script/vm/getMeta.lua b/script/vm/getMeta.lua deleted file mode 100644 index 44d1874a..00000000 --- a/script/vm/getMeta.lua +++ /dev/null @@ -1,53 +0,0 @@ ----@type vm -local vm = require 'vm.vm' - -local function eachMetaOfArg1(source, callback) - local node, index = vm.getArgInfo(source) - local special = vm.getSpecial(node) - if special == 'setmetatable' and index == 1 then - local mt = node.next.args[2] - if mt then - callback(mt) - end - end -end - -local function eachMetaOfRecv(source, callback) - if not source or source.type ~= 'select' then - return - end - if source.index ~= 1 then - return - end - local call = source.vararg - if not call or call.type ~= 'call' then - return - end - local special = vm.getSpecial(call.node) - if special ~= 'setmetatable' then - return - end - local mt = call.args[2] - if mt then - callback(mt) - end -end - -function vm.eachMetaValue(source, callback) - vm.eachMeta(source, function (mt) - for _, src in ipairs(vm.getDefFields(mt)) do - if vm.getKeyName(src) == '__index' then - if src.value then - for _, valueSrc in ipairs(vm.getDefFields(src.value)) do - callback(valueSrc) - end - end - end - end - end) -end - -function vm.eachMeta(source, callback) - eachMetaOfArg1(source, callback) - eachMetaOfRecv(source.value, callback) -end diff --git a/script/vm/guideInterface.lua b/script/vm/guideInterface.lua index ae060481..e59fc6e3 100644 --- a/script/vm/guideInterface.lua +++ b/script/vm/guideInterface.lua @@ -2,7 +2,7 @@ local vm = require 'vm.vm' local files = require 'files' local ws = require 'workspace' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local await = require 'await' local config = require 'config' @@ -27,7 +27,7 @@ function m.require(args, index) return nil end local results = {} - local myUri = guide.getUri(args[1]) + local myUri = searcher.getUri(args[1]) local uris = ws.findUrisByRequirePath(reqName) for _, uri in ipairs(uris) do if not files.eq(myUri, uri) then @@ -47,7 +47,7 @@ function m.dofile(args, index) return end local results = {} - local myUri = guide.getUri(args[1]) + local myUri = searcher.getUri(args[1]) local uris = ws.findUrisByFilePath(reqName) for _, uri in ipairs(uris) do if not files.eq(myUri, uri) then @@ -87,9 +87,9 @@ function vm.interface.global(name, onlyDef) end end -function vm.interface.docType(name) +function vm.interface.doc(name, type) await.delay() - return vm.getDocTypes(name) + return vm.getDocNames(name, type) end function vm.interface.link(uri) diff --git a/script/vm/init.lua b/script/vm/init.lua index b9e8e147..c38f01d5 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -2,10 +2,6 @@ local vm = require 'vm.vm' require 'vm.getGlobals' require 'vm.getDocs' require 'vm.getLibrary' -require 'vm.getInfer' -require 'vm.getClass' -require 'vm.getMeta' -require 'vm.eachField' require 'vm.eachDef' require 'vm.eachRef' require 'vm.getLinks' diff --git a/script/vm/vm.lua b/script/vm/vm.lua index 0248ad8c..ebd0102b 100644 --- a/script/vm/vm.lua +++ b/script/vm/vm.lua @@ -1,18 +1,14 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local util = require 'utility' local files = require 'files' local timer = require 'timer' local setmetatable = setmetatable -local assert = assert -local require = require -local type = type local running = coroutine.running local ipairs = ipairs local log = log local xpcall = xpcall local mathHuge = math.huge -local collectgarbage = collectgarbage _ENV = nil @@ -63,10 +59,6 @@ function m.getArgInfo(source) return nil end -function m.getSpecial(source) - return guide.getSpecial(source) -end - function m.getKeyName(source) if not source then return nil @@ -65,22 +65,22 @@ local function testAll() test 'references' test 'definition' test 'type_inference' + test 'hover' + test 'completion' + test 'crossfile' test 'diagnostics' test 'highlight' test 'rename' - test 'hover' - test 'completion' test 'signature' test 'document_symbol' test 'code_action' test 'type_formatting' - test 'crossfile' --test 'other' end local function main() debug.setcstacklimit(1000) - require 'core.guide'.debugMode = true + require 'core.searcher'.debugMode = true require 'language' 'zh-cn' require 'utility'.enableCloseFunction() diff --git a/test/basic/init.lua b/test/basic/init.lua index a3a11f62..1b698493 100644 --- a/test/basic/init.lua +++ b/test/basic/init.lua @@ -1,219 +1,2 @@ -local files = require 'files' -local tm = require 'text-merger' - -local function TEST(source) - return function (expect) - return function (changes) - files.removeAll() - files.setText('', source) - local text = tm('', changes) - assert(text == expect) - end - end -end - -TEST [[ - - -function Test(self) - -end -]][[ - - -function Test(self) - -end - -asser]]{ - [1] = { - range = { - ["end"] = { - character = 0, - line = 5, - }, - start = { - character = 0, - line = 5, - }, - }, - rangeLength = 0, - text = "\ -", - }, - [2] = { - range = { - ["end"] = { - character = 0, - line = 6, - }, - start = { - character = 0, - line = 6, - }, - }, - rangeLength = 0, - text = "a", - }, - [3] = { - range = { - ["end"] = { - character = 1, - line = 6, - }, - start = { - character = 1, - line = 6, - }, - }, - rangeLength = 0, - text = "s", - }, - [4] = { - range = { - ["end"] = { - character = 2, - line = 6, - }, - start = { - character = 2, - line = 6, - }, - }, - rangeLength = 0, - text = "s", - }, - [5] = { - range = { - ["end"] = { - character = 3, - line = 6, - }, - start = { - character = 3, - line = 6, - }, - }, - rangeLength = 0, - text = "e", - }, - [6] = { - range = { - ["end"] = { - character = 4, - line = 6, - }, - start = { - character = 4, - line = 6, - }, - }, - rangeLength = 0, - text = "r", - }, -} - -TEST [[ -local mt = {} - -function mt['xxx']() - - - -end -]] [[ -local mt = {} - -function mt['xxx']() - -end -]] { - [1] = { - range = { - ["end"] = { - character = 4, - line = 5, - }, - start = { - character = 4, - line = 3, - }, - }, - rangeLength = 8, - text = "", - }, -} - -TEST [[ -local mt = {} - -function mt['xxx']() - -end -]] [[ -local mt = {} - -function mt['xxx']() - p -end -]] { - [1] = { - range = { - ["end"] = { - character = 4, - line = 3, - }, - start = { - character = 4, - line = 3, - }, - }, - rangeLength = 0, - text = "p", - }, -} - -TEST [[ -print(12345) -]] [[ -print(123 -45) -]] { - [1] = { - range = { - ["end"] = { - character = 9, - line = 0, - }, - start = { - character = 9, - line = 0, - }, - }, - rangeLength = 0, - text = "\ -", - }, -} - -TEST [[ -print(123 -45) -]] [[ -print(12345) -]] { - [1] = { - range = { - ["end"] = { - character = 0, - line = 1, - }, - start = { - character = 9, - line = 0, - }, - }, - rangeLength = 2, - text = "", - }, -} +require 'basic.textmerger' +require 'basic.noder' diff --git a/test/basic/linker.txt b/test/basic/linker.txt new file mode 100644 index 00000000..ea3ba180 --- /dev/null +++ b/test/basic/linker.txt @@ -0,0 +1,141 @@ +ast -> linkers = { + ['g|"X"|"Y"|"Z"'] = {src1, src2, src3}, + ['g|"X"|"Y"'] = {src4, src5, src6}, + ['g|"X"'] = {src7, src8, src9}, + ['l|7'] = {src10}, + ['l|7|"x"'] = {src11}, + ['l|11|"k"'] = {src12}, +} + +```lua +x.y.<?z?> = <!f!> + +<?g?> = x.y.z + +t.<!z!> = 1 +x.y = t + +x = { + y = { + <!z!> = 1 + } +} +``` + +expect: 'l|x|y|z' +forward: 'l|x|y|z' -> f +backward: 'l|x|y|z' -> g +last: 'l|x|y' + 'z' + +expect: 'l|x|y' + '|z' +forward: 'l|t' + '|z' -> 'l|t|z' -> t.z +backward: nil +last: 'l|x' + '|y|z' + +expect: 'l|x' + '|y|z' +forward: 'l|0' + '|y|z' -> 'l|0|y|z' +backward: nil +last: nil + +expect: 'l|0|y|z' +forward: nil +backward: nil +last: 'l|0|y' + '|z' + +expect: 'l|0|y' + '|z' +forward: 'l|1'+ '|z' -> 'l|1|z' -> field z +backward: nil +last: 'l|0' + '|y|z' + + +```lua +a = { + b = { + <?c?> = 1, + } +} + +print(a.b.<!c!>) +``` + +expect: 't|3|c' +forward: nil +backward: nil +last: 't|3' + '|c' + +expect: 't|3' + '|c' +forward: nil +backward: 't|2|b' + '|c' +last: nil + +expect: 't|2|b|c' +forward: nil +backward: 't|2|b' + '|c' +last: nil + +```lua +---@return <?A?> +local function f() +end + +local <!x!> = f() +``` + +'d|A' +'f|1|#1' +'f|1' + '|#1' +'l|1' + '|#1' +'s|1' + '|#1' + +```lua +---@generic T +---@param a T +---@return T +local function f(a) end + +local <?c?> + +local <!v!> = f(c) +``` + +'l1' +'l2|@1' +'f|1|@1' +'f|1|#1' + +``` +---@generic T +---@param p T +---@return T +local function f(p) end + +local <?r?> = f(<!k!>) +``` + +l:r +s:1#1 call +l:f#1 call +f:1#1 call -> f:1&T = l:k +l:f@1 --> 从保存的call信息里找到 f:1&T = l:k +l:k + + + +``` +---@generic T, V +---@param p T +---@return fun(V):T, V +local function f(p) end + +local f2 = f(<!k!>) +local <?r?> = f2() +``` + +l:r +s:2|#1 call1 +l:f2|#1 call1 +f:2|#1 call1 +s:1#1|#1 call2 +f:1#1|#1 call2 -> f:1&T = l:k +dfun:1|#1 +dn:V -> f:1&T = l:k diff --git a/test/basic/noder.lua b/test/basic/noder.lua new file mode 100644 index 00000000..3e5e9f25 --- /dev/null +++ b/test/basic/noder.lua @@ -0,0 +1,146 @@ +local noder = require 'core.noder' +local files = require 'files' +local util = require 'utility' +local guide = require 'parser.guide' + +local function getSource(pos) + local ast = files.getAst('') + return guide.eachSourceContain(ast.ast, pos, function (source) + if source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' + or source.type == 'tablefield' + or source.type == 'setindex' + or source.type == 'getindex' + or source.type == 'tableindex' + or source.type == 'label' + or source.type == 'goto' then + return source + end + end) +end + +local CARE = {} +local function TEST(script) + return function (expect) + files.removeAll() + local start = script:find('<?', 1, true) + local finish = script:find('?>', 1, true) + local pos = (start + finish) // 2 + 1 + local newScript = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ') + files.setText('', newScript) + local source = getSource(pos) + assert(source) + noder.compileNodes(source) + local result = { + id = noder.getID(source), + } + + expect['id'] = expect['id']:gsub('|', '\x1F') + + for key in pairs(CARE) do + assert(result[key] == expect[key]) + end + end +end + +CARE['id'] = true +TEST [[ +local <?x?> +]] { + id = 'l:9', +} + +TEST [[ +local x +print(<?x?>) +]] { + id = 'l:7', +} + +TEST [[ +local x +<?x?> = 1 +]] { + id = 'l:7', +} + +TEST [[ +print(<?X?>) +]] { + id = 'g:"X"', +} + +TEST [[ +print(<?X?>) +]] { + id = 'g:"X"', +} + +TEST [[ +local x +print(x.y.<?z?>) +]] { + id = 'l:7|"y"|"z"', +} + +TEST [[ +local x +function x:<?f?>() end +]] { + id = 'l:7|"f"', +} + +TEST [[ +print(X.Y.<?Z?>) +]] { + id = 'g:"X"|"Y"|"Z"', +} + +TEST [[ +function x:<?f?>() end +]] { + id = 'g:"x"|"f"', +} + +TEST [[ +{ + <?x?> = 1, +} +]] { + id = 't:1|"x"', +} + +TEST [[ +return <?X?> +]] { + id = 'g:"X"', +} + +TEST [[ +function f() + return <?X?> +end +]] { + id = 'g:"X"', +} + +TEST [[ +::<?label?>:: +goto label +]] { + id = 'l:5', +} + +TEST [[ +::label:: +goto <?label?> +]] { + id = 'l:3', +} diff --git a/test/basic/textmerger.lua b/test/basic/textmerger.lua new file mode 100644 index 00000000..a3a11f62 --- /dev/null +++ b/test/basic/textmerger.lua @@ -0,0 +1,219 @@ +local files = require 'files' +local tm = require 'text-merger' + +local function TEST(source) + return function (expect) + return function (changes) + files.removeAll() + files.setText('', source) + local text = tm('', changes) + assert(text == expect) + end + end +end + +TEST [[ + + +function Test(self) + +end +]][[ + + +function Test(self) + +end + +asser]]{ + [1] = { + range = { + ["end"] = { + character = 0, + line = 5, + }, + start = { + character = 0, + line = 5, + }, + }, + rangeLength = 0, + text = "\ +", + }, + [2] = { + range = { + ["end"] = { + character = 0, + line = 6, + }, + start = { + character = 0, + line = 6, + }, + }, + rangeLength = 0, + text = "a", + }, + [3] = { + range = { + ["end"] = { + character = 1, + line = 6, + }, + start = { + character = 1, + line = 6, + }, + }, + rangeLength = 0, + text = "s", + }, + [4] = { + range = { + ["end"] = { + character = 2, + line = 6, + }, + start = { + character = 2, + line = 6, + }, + }, + rangeLength = 0, + text = "s", + }, + [5] = { + range = { + ["end"] = { + character = 3, + line = 6, + }, + start = { + character = 3, + line = 6, + }, + }, + rangeLength = 0, + text = "e", + }, + [6] = { + range = { + ["end"] = { + character = 4, + line = 6, + }, + start = { + character = 4, + line = 6, + }, + }, + rangeLength = 0, + text = "r", + }, +} + +TEST [[ +local mt = {} + +function mt['xxx']() + + + +end +]] [[ +local mt = {} + +function mt['xxx']() + +end +]] { + [1] = { + range = { + ["end"] = { + character = 4, + line = 5, + }, + start = { + character = 4, + line = 3, + }, + }, + rangeLength = 8, + text = "", + }, +} + +TEST [[ +local mt = {} + +function mt['xxx']() + +end +]] [[ +local mt = {} + +function mt['xxx']() + p +end +]] { + [1] = { + range = { + ["end"] = { + character = 4, + line = 3, + }, + start = { + character = 4, + line = 3, + }, + }, + rangeLength = 0, + text = "p", + }, +} + +TEST [[ +print(12345) +]] [[ +print(123 +45) +]] { + [1] = { + range = { + ["end"] = { + character = 9, + line = 0, + }, + start = { + character = 9, + line = 0, + }, + }, + rangeLength = 0, + text = "\ +", + }, +} + +TEST [[ +print(123 +45) +]] [[ +print(12345) +]] { + [1] = { + range = { + ["end"] = { + character = 0, + line = 1, + }, + start = { + character = 9, + line = 0, + }, + }, + rangeLength = 2, + text = "", + }, +} diff --git a/test/crossfile/hover.lua b/test/crossfile/hover.lua index c27cd3dd..e81494ff 100644 --- a/test/crossfile/hover.lua +++ b/test/crossfile/hover.lua @@ -202,20 +202,20 @@ TEST { path = 'a.lua', content = [[ t = { - [{}] = 1, + [1] = 1, } ]], }, { path = 'b.lua', content = [[ - <?t?>[{}] = 2 + <?t?>[1] = 2 ]] }, hover = { label = [[ global t: { - [table]: integer = 1|2, + [1]: integer = 1|2, }]], name = 't', }, @@ -226,20 +226,20 @@ TEST { path = 'a.lua', content = [[ t = { - [{}] = 1, + [1] = 1, } ]], }, { path = 'a.lua', content = [[ - <?t?>[{}] = 2 + <?t?>[1] = 2 ]] }, hover = { label = [[ global t: { - [table]: integer = 2, + [1]: integer = 2, }]], name = 't', }, @@ -729,7 +729,7 @@ food.secondField = 2 ]] }, hover = { - label = 'field Food.firstField: integer = 0', + label = 'field Food.firstField: number = 0', name = 'food.firstField', }} diff --git a/test/definition/init.lua b/test/definition/init.lua index 6e6d0a9a..85bcd5d5 100644 --- a/test/definition/init.lua +++ b/test/definition/init.lua @@ -36,6 +36,7 @@ end function TEST(script) files.removeAll() + script = script:gsub('\n', '\r\n') local target = catch_target(script) local start = script:find('<?', 1, true) local finish = script:find('?>', 1, true) @@ -51,8 +52,14 @@ function TEST(script) positions[i] = { result.target.start, result.target.finish } end end + if not founded(target, positions) then + core('', pos) + end assert(founded(target, positions)) else + if #target ~= 0 then + core('', pos) + end assert(#target == 0) end end @@ -65,6 +72,6 @@ require 'definition.table' require 'definition.method' require 'definition.label' require 'definition.call' -require 'definition.bug' require 'definition.special' +require 'definition.bug' require 'definition.luadoc' diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua index ff54546b..5531e2e3 100644 --- a/test/definition/luadoc.lua +++ b/test/definition/luadoc.lua @@ -87,6 +87,11 @@ TEST [[ ]] TEST [[ +---@type <!fun():void!> +local <?<!f!>?> +]] + +TEST [[ ---@param f <!fun():void!> function t(<?<!f!>?>) end ]] @@ -97,7 +102,7 @@ function f(<?...?>) end ]] TEST [[ ----@overload fun(y: boolean) +---@overload <!fun(y: boolean)!> ---@param x number ---@param y boolean ---@param z string @@ -108,7 +113,7 @@ print(<?f?>) TEST [[ local function f() - return 1 + return <!1!> end ---@class Class @@ -204,6 +209,23 @@ TEST [[ ]] TEST [[ +---@return <!fun()!> +local function f() end + +local <?<!r!>?> = f() +]] + +TEST [[ +---@generic T +---@param p T +---@return T +local function f(p) end + +local <!k!> +local <?<!r!>?> = f(<!k!>) +]] + +TEST [[ ---@class Foo local Foo = {} function Foo:<!bar1!>() end @@ -260,6 +282,26 @@ print(v1.<?bar1?>) ]] TEST [[ +---@class A +local <!t!> + +---@type A[] +local b + +local <?<!c!>?> = b[1] +]] + +TEST [[ +---@class A +local <!t!> + +---@type table<number, A> +local b + +local <?<!c!>?> = b[1] +]] + +TEST [[ ---@class Foo local Foo = {} function Foo:<!bar1!>() end @@ -299,13 +341,223 @@ print(v1[1].<?bar1?>) --]] TEST [[ +---@type fun():<!fun()!> +local f + +local <?<!f2!>?> = f() +]] + +TEST [[ +---@generic T +---@type fun(x: T):T +local f + +local <?<!v2!>?> = f(<!{}!>) +]] + +TEST [[ +---@generic T +---@param x T +---@return fun():T +local function f(x) end + +local v1 = f(<!{}!>) +local <?<!v2!>?> = v1() +]] + +TEST [[ +---@generic T +---@type fun(x: T):fun():T +local f + +local v1 = f(<!{}!>) +local <?<!v2!>?> = v1() +]] + +TEST [[ +---@generic V +---@return fun(x: V):V +local function f(x) end + +local v1 = f() +local <?<!v2!>?> = v1(<!{}!>) +]] + +TEST [[ +---@generic V +---@param x V[] +---@return V +local function f(x) end + +---@class A +local <!a!> + +---@type A[] +local b + +local <?<!c!>?> = f(b) +]] + +TEST [[ +---@generic V +---@param x table<number, V> +---@return V +local function f(x) end + +---@class A +local <!a!> + +---@type table<number, A> +local b + +local <?<!c!>?> = f(b) +]] + +TEST [[ +---@generic V +---@param x V[] +---@return V +local function f(x) end + +---@class A +local <!a!> + +---@type table<number, A> +local b + +local <?<!c!>?> = f(b) +]] + +TEST [[ +---@generic V +---@param x table<number, V> +---@return V +local function f(x) end + +---@class A +local <!a!> + +---@type A[] +local b + +local <?<!c!>?> = f(b) +]] + +TEST [[ +---@generic K +---@param x table<K, number> +---@return K +local function f(x) end + +---@class A +local <!a!> + +---@type table<A, number> +local b + +local <?<!c!>?> = f(b) +]] + +TEST [[ +---@generic V +---@return fun(t: V[]):V +local function f() end + +---@class A +local <!a!> + +---@type A[] +local b + +local f2 = f() + +local <?<!c!>?> = f2(b) +]] + +TEST [[ +---@generic T, V +---@param t T +---@return fun(t: V[]):V +---@return T +local function f(t) end + +---@class A +local <!a!> + +---@type A[] +local b + +local f2, c = f(b) + +local <?<!d!>?> = f2(c) +]] + +TEST [[ +---@class C +local <!v1!> + +---@generic V, T +---@param t T +---@return fun(t: V): V +---@return T +local function iterator(t) end + +for <!v!> in iterator(<!v1!>) do + print(<?v?>) +end +]] + +TEST [[ +---@class C +local <!v!> + +---@type C +local <!v1!> + +---@generic V, T +---@param t T +---@return fun(t: V): V +---@return T +local function iterator(t) end + +for <!v!> in iterator(<!v1!>) do + print(<?v?>) +end +]] + +TEST [[ +---@class C +local <!v!> + +---@type C[] +local v1 + +---@generic V, T +---@param t T +---@return fun(t: V[]): V +---@return T +local function iterator(t) end + +for <!v!> in iterator(v1) do + print(<?v?>) +end +]] + +TEST [[ ---@class Foo local Foo = {} function Foo:<!bar1!>() end ---@type table<number, Foo> local v1 -local ipairs = ipairs + +---@generic T: table, V +---@param t T +---@return fun(table: V[], i?: integer):integer, V +---@return T +---@return integer i +local function ipairs(t) end + for i, v in ipairs(v1) do print(v.<?bar1?>) end @@ -318,6 +570,35 @@ function Foo:<!bar1!>() end ---@type table<Foo, Foo> local v1 + +---@generic T: table, K, V +---@param t T +---@return fun(table: table<K, V>, index: K):K, V +---@return T +---@return nil +local function pairs(t) end + +for k, v in pairs(v1) do + print(k.bar1) + print(v.<?bar1?>) +end +]] + +TEST [[ +---@class Foo +local Foo = {} +function Foo:<!bar1!>() end + +---@type table<Foo, Foo> +local v1 + +---@generic T: table, K, V +---@param t T +---@return fun(table: table<K, V>, index: K):K, V +---@return T +---@return nil +local function pairs(t) end + for k, v in pairs(v1) do print(k.<?bar1?>) print(v.bar1) @@ -329,6 +610,13 @@ TEST [[ local Foo = {} function Foo:<!bar1!>() end +---@generic T: table, V +---@param t T +---@return fun(table: V[], i?: integer):integer, V +---@return T +---@return integer i +local function ipairs(t) end + ---@type table<number, table<number, Foo>> local v1 for i, v in ipairs(v1) do diff --git a/test/diagnostics/init.lua b/test/diagnostics/init.lua index d4bffdb5..0f5880ae 100644 --- a/test/diagnostics/init.lua +++ b/test/diagnostics/init.lua @@ -79,7 +79,8 @@ local <!x!> ]] TEST [[ -local x <close> = print +local y +local x <close> = y ]] TEST [[ @@ -135,11 +136,11 @@ end ) TEST [[ +local print, _G print(<!x!>) print(<!log!>) print(<!X!>) print(<!Log!>) -print(_VERSION) print(<!y!>) print(Z) print(_G) diff --git a/test/full/example.lua b/test/full/example.lua index 1eb66060..8633318a 100644 --- a/test/full/example.lua +++ b/test/full/example.lua @@ -5,6 +5,7 @@ local diag = require 'core.diagnostics' local config = require 'config' local fs = require 'bee.filesystem' local luadoc = require "parser.luadoc" +local noder = require 'core.noder' -- 临时 local function testIfExit(path) @@ -19,6 +20,7 @@ local function testIfExit(path) local parseClock = 0 local compileClock = 0 local luadocClock = 0 + local noderClock = 0 local total for i = 1, max do vm = TEST(buf) @@ -26,21 +28,26 @@ local function testIfExit(path) luadoc(nil, vm) local luadocPassed = os.clock() - luadocStart local passed = os.clock() - clock - parseClock = parseClock + vm.parseClock + local noderStart = os.clock() + noder.compileNodes(vm.ast) + local noderPassed = os.clock() - noderStart + parseClock = parseClock + vm.parseClock compileClock = compileClock + vm.compileClock luadocClock = luadocClock + luadocPassed + noderClock = noderClock + noderPassed if passed >= 1.0 or i == max then need = passed / i total = i break end end - print(('基准编译测试[%s]单次耗时:%.10f(解析:%.10f, 编译:%.10f, LuaDoc: %.10f)'):format( + print(('基准编译测试[%s]单次耗时:%.10f(解析:%.10f, 编译:%.10f, LuaDoc: %.10f, Noder: %.10f)'):format( path:filename():string(), need, parseClock / total, compileClock / total, - luadocClock / total + luadocClock / total, + noderClock / total )) local clock = os.clock() diff --git a/test/hover/init.lua b/test/hover/init.lua index d0e50036..2c68fef5 100644 --- a/test/hover/init.lua +++ b/test/hover/init.lua @@ -54,39 +54,39 @@ obj:<?init?>(1, '测试') function mt:init(a: any, b: any, c: any) ]] -TEST [[ -local mt = {} -mt.__index = mt -mt.type = 'Class' - -function mt:init(a, b, c) - return -end - -local obj = setmetatable({}, mt) - -obj:<?init?>(1, '测试') -]] -[[ -function Class:init(a: any, b: any, c: any) -]] - -TEST [[ -local mt = {} -mt.__index = mt -mt.__name = 'Class' - -function mt:init(a, b, c) - return -end - -local obj = setmetatable({}, mt) +--TEST [[ +--local mt = {} +--mt.__index = mt +--mt.type = 'Class' +-- +--function mt:init(a, b, c) +-- return +--end +-- +--local obj = setmetatable({}, mt) +-- +--obj:<?init?>(1, '测试') +--]] +--[[ +--function Class:init(a: any, b: any, c: any) +--]] -obj:<?init?>(1, '测试') -]] -[[ -function Class:init(a: any, b: any, c: any) -]] +--TEST [[ +--local mt = {} +--mt.__index = mt +--mt.__name = 'Class' +-- +--function mt:init(a, b, c) +-- return +--end +-- +--local obj = setmetatable({}, mt) +-- +--obj:<?init?>(1, '测试') +--]] +--[[ +--function Class:init(a: any, b: any, c: any) +--]] TEST [[ local mt = {} @@ -170,55 +170,55 @@ local <?obj?> = {} ]] "local obj: {}" -TEST [[ -local mt = {} -mt.__name = 'class' - -local <?obj?> = setmetatable({}, mt) -]] -"local obj: class {}" - -TEST [[ -local mt = {} -mt.name = 'class' -mt.__index = mt - -local <?obj?> = setmetatable({}, mt) -]] -[[ -local obj: class { - __index: table, - name: string = "class", -} -]] - -TEST [[ -local mt = {} -mt.TYPE = 'class' -mt.__index = mt +--TEST [[ +--local mt = {} +--mt.__name = 'class' +-- +--local <?obj?> = setmetatable({}, mt) +--]] +--"local obj: class {}" -local <?obj?> = setmetatable({}, mt) -]] -[[ -local obj: class { - TYPE: string = "class", - __index: table, -} -]] +--TEST [[ +--local mt = {} +--mt.name = 'class' +--mt.__index = mt +-- +--local <?obj?> = setmetatable({}, mt) +--]] +--[[ +--local obj: class { +-- __index: table, +-- name: string = "class", +--} +--]] -TEST [[ -local mt = {} -mt.Class = 'class' -mt.__index = mt +--TEST [[ +--local mt = {} +--mt.TYPE = 'class' +--mt.__index = mt +-- +--local <?obj?> = setmetatable({}, mt) +--]] +--[[ +--local obj: class { +-- TYPE: string = "class", +-- __index: table, +--} +--]] -local <?obj?> = setmetatable({}, mt) -]] -[[ -local obj: class { - Class: string = "class", - __index: table, -} -]] +--TEST [[ +--local mt = {} +--mt.Class = 'class' +--mt.__index = mt +-- +--local <?obj?> = setmetatable({}, mt) +--]] +--[[ +--local obj: class { +-- Class: string = "class", +-- __index: table, +--} +--]] -- TODO 支持自定义的函数库 --TEST[[ @@ -422,8 +422,6 @@ local t: { [1]: integer = 2, [true]: integer = 3, [5.5]: integer = 4, - [table]: integer = 5, - [function]: integer = 6, b: integer = 7, ["012"]: integer = 8, } @@ -438,9 +436,7 @@ local any = collectgarbage() t[any] = any ]] [[ -local t: { - [number]: integer = 1, -} +local t: {} ]] TEST[[ @@ -492,7 +488,7 @@ local <?self?> = setmetatable({ }, mt) ]] [[ -local self: obj { +local self: { __index: table, __name: string = "obj", id: integer = 1, @@ -860,15 +856,15 @@ print(<?x?>) local x <close>: integer = 1 ]] -TEST [[ -local function <?a?>(b) - return (b.c and a(b.c) or b) -end -]] -[[ -function a(b: table) - -> table -]] +--TEST [[ +--local function <?a?>(b) +-- return (b.c and a(b.c) or b) +--end +--]] +--[[ +--function a(b: table) +-- -> table +--]] TEST [[ local <?t?> = { @@ -927,7 +923,7 @@ field x: Class ]] TEST[[ ----@type Class +---@class Class local <?x?> = class() ]] [[ @@ -935,7 +931,7 @@ local x: Class ]] TEST[[ ----@type Class +---@class Class <?x?> = class() ]] [[ @@ -943,16 +939,10 @@ global x: Class ]] TEST[[ -local t = { - ---@type Class - <?x?> = class() -} -]] -[[ -field x: Class -]] +---@class A +---@class B +---@class C -TEST[[ ---@type A|B|C local <?x?> = class() ]] @@ -994,7 +984,7 @@ function f(t) end ]] [[ -local t: Class {} +local t: Class ]] TEST [[ @@ -1020,6 +1010,10 @@ local v: Class ]] TEST [[ +---@class A +---@class B +---@class C + ---@return A|B ---@return C local function <?f?>() @@ -1078,6 +1072,8 @@ function f(x: number, y: boolean) ]] TEST [[ +---@class Class + ---@vararg Class local function f(...) local _, <?x?> = ... @@ -1089,6 +1085,21 @@ local x: Class ]] TEST [[ +---@class Class + +---@vararg Class +local function f(...) + local t = {...} + local <?v?> = t[1] +end +]] +[[ +local v: Class +]] + +TEST [[ +---@class Class + ---@vararg Class local function f(...) local <?t?> = {...} @@ -1164,23 +1175,29 @@ local x: table<ClassA, ClassB> ]] --TEST [[ +-----@class ClassA +-----@class ClassB +-- -----@type table<ClassA, ClassB> --local t --for _, <?x?> in pairs(t) do --end --]] --[[ ---local x: *ClassB +--local x: ClassB --]] --TEST [[ +-----@class ClassA +-----@class ClassB +-- -----@type table<ClassA, ClassB> --local t --for <?k?>, v in pairs(t) do --end --]] --[[ ---local k: *ClassA +--local k: ClassA --]] TEST [[ @@ -1202,6 +1219,8 @@ local r: boolean ]] TEST [[ +---@class void + ---@param f fun():void function t(<?f?>) end ]] @@ -1492,6 +1511,12 @@ TEST [[ ---@field x string local t +---@generic T +---@param v T +---@param message any +---@return T +local function assert(v, message) end + local <?v?> = assert(t) ]] [[ diff --git a/test/references/all.lua b/test/references/all.lua new file mode 100644 index 00000000..a9442ae1 --- /dev/null +++ b/test/references/all.lua @@ -0,0 +1,213 @@ +local core = require 'core.reference' +local files = require 'files' + +local function catch_target(script) + local list = {} + local cur = 1 + while true do + local start, finish = script:find('<[!?].-[!?]>', cur) + if not start then + break + end + list[#list+1] = { start + 2, finish - 2 } + cur = finish + 1 + end + return list +end + +local function founded(targets, results) + if #targets ~= #results then + return false + end + for _, target in ipairs(targets) do + for _, result in ipairs(results) do + if target[1] == result[1] and target[2] == result[2] then + goto NEXT + end + end + do return false end + ::NEXT:: + end + return true +end + +function TEST(script) + files.removeAll() + local expect = catch_target(script) + local start = script:find('<[?~]') + local finish = script:find('[?~]>') + local pos = (start + finish) // 2 + 1 + local new_script = script:gsub('<[!?~]', ' '):gsub('[!?~]>', ' ') + files.setText('', new_script) + + local results = core('', pos) + if results then + local positions = {} + for i, result in ipairs(results) do + positions[i] = { result.target.start, result.target.finish } + end + assert(founded(expect, positions)) + else + assert(#expect == 0) + end +end + +TEST [[ +---@class A +local a = {} +a.<?x?> = 1 + +---@return A +local function f() end + +local b = f() +return b.<!x!> +]] + +TEST [[ +---@class A +local a = {} +a.<?x?> = 1 + +---@return table +---@return A +local function f() end + +local a, b = f() +return a.x, b.<!x!> +]] + +TEST [[ +local <?mt?> = {} +function <!mt!>:x() + <!self!>:x() +end +]] + +TEST [[ +local mt = {} +function mt:<?x?>() + self:<!x!>() +end +]] + +TEST [[ +---@class Dog +local mt = {} +function mt:<?eat?>() +end + +---@class Master +local mt2 = {} +function mt2:init() + ---@type Dog + local foo = self:doSomething() + ---@type Dog + self.dog = getDog() +end +function mt2:feed() + self.dog:<!eat!>() +end +function mt2:doSomething() +end +]] + +-- 泛型的反向搜索 +TEST [[ +---@class Dog +local <?Dog?> = {} + +---@generic T +---@param type1 T +---@return T +function foobar(type1) +end + +local <!v1!> = foobar(<!Dog!>) +]] + +TEST [[ +---@class Dog +local Dog = {} +function Dog:<?eat?>() +end + +---@generic T +---@param type1 T +---@return T +function foobar(type1) + return {} +end + +local v1 = foobar(Dog) +v1:<!eat!>() +]] + +TEST [[ +---@class Dog +local Dog = {} +function Dog:<?eat?>() +end + +---@class Master +local Master = {} + +---@generic T +---@param type1 string +---@param type2 T +---@return T +function Master:foobar(type1, type2) + return {} +end + +local v1 = Master:foobar("", Dog) +v1.<!eat!>() +]] + +TEST [[ +---@class A +local <?A?> + +---@generic T +---@param self T +---@return T +function m.f(self) end + +local <!b!> = m.f(<!A!>) +]] + +TEST [[ +---@class A +local <?A?> + +---@generic T +---@param self T +---@return T +function m:f() end + +local <!b!> = m.f(<!A!>) +]] + +TEST [[ +---@class A +local <?A?> + +---@generic T +---@param self T +---@return T +function <!A!>.f(self) end + +local <!b!> = <!A!>:f() +]] + +TEST [[ +---@class A +local <?A?> + +---@generic T +---@param self T +---@return T +function <!A!>:f() end + +local <!b!> = <!A!>:f() +]] diff --git a/test/references/init.lua b/test/references/init.lua index c4e5018a..e90cb2a8 100644 --- a/test/references/init.lua +++ b/test/references/init.lua @@ -1,4 +1,4 @@ -local core = require 'core.reference' +local core = require 'core.reference' local files = require 'files' local function catch_target(script) @@ -33,7 +33,7 @@ end function TEST(script) files.removeAll() - local target = catch_target(script) + local expect = catch_target(script) local start = script:find('<[?~]') local finish = script:find('[?~]>') local pos = (start + finish) // 2 + 1 @@ -46,9 +46,9 @@ function TEST(script) for i, result in ipairs(results) do positions[i] = { result.target.start, result.target.finish } end - assert(founded(target, positions)) + assert(founded(expect, positions)) else - assert(#target == 0) + assert(#expect == 0) end end @@ -96,6 +96,16 @@ local <?a?> = 1 ]] TEST [[ +local <!a!> +local <?b?> = <!a!> +]] + +TEST [[ +local <?a?> +local <!b!> = <!a!> +]] + +TEST [[ local t = { <!a!> = 1 } @@ -166,7 +176,7 @@ local y = f()() TEST [[ local t = {} t.<?x?> = 1 -t[a.b.x] = 1 +t[<!a.b.c!>] = 1 ]] TEST [[ @@ -208,13 +218,6 @@ end ]] TEST [[ -local <?mt?> = {} -function <!mt!>:x() - <!self!>:x() -end -]] - -TEST [[ local mt = {} function mt:<!x!>() self:<?x?>() @@ -222,13 +225,6 @@ end ]] TEST [[ -local mt = {} -function mt:<?x?>() - self:<!x!>() -end -]] - -TEST [[ a.<!b!>.c = 1 print(a.<?b?>.c) ]] @@ -252,7 +248,7 @@ a.<!t!> = <?f?> ]] TEST [[ -<!t!>.f = <?t?> +<!t!>.<!f!> = <?t?> ]] TEST [[ @@ -302,135 +298,3 @@ TEST [[ ---@return <?xxx?> function f() end ]] - -TEST [[ ----@class Dog -local mt = {} -function mt:<?eat?>() -end - ----@class Master -local mt2 = {} -function mt2:init() - ---@type Dog - local foo = self:doSomething() - ---@type Dog - self.dog = getDog() -end -function mt2:feed() - self.dog:<!eat!>() -end -function mt2:doSomething() -end -]] - -TEST [[ ----@class A -local a = {} -a.<?x?> = 1 - ----@return A -local function f() end - -local b = f() -return b.<!x!> -]] - -TEST [[ ----@class A -local a = {} -a.<?x?> = 1 - ----@return table ----@return A -local function f() end - -local a, b = f() -return a.x, b.<!x!> -]] - -TEST [[ ----@class Dog -local Dog = {} -function Dog:<?eat?>() -end - ----@generic T ----@param type1 T ----@return T -function foobar(type1) - return {} -end - -local v1 = foobar(Dog) -v1:<!eat!>() -]] - -TEST [[ ----@class Dog -local Dog = {} -function Dog:<?eat?>() -end - ----@class Master -local Master = {} - ----@generic T ----@param type1 string ----@param type2 T ----@return T -function Master:foobar(type1, type2) - return {} -end - -local v1 = Master:foobar("", Dog) -v1.<!eat!>() -]] - -TEST [[ ----@class A -local <?A?> - ----@generic T ----@param self T ----@return T -function m.f(self) end - -local <!b!> = m.f(<!A!>) -]] - -TEST [[ ----@class A -local <?A?> - ----@generic T ----@param self T ----@return T -function m:f() end - -local <!b!> = m.f(<!A!>) -]] - -TEST [[ ----@class A -local <?A?> - ----@generic T ----@param self T ----@return T -function <!A!>.f(self) end - -local <!b!> = <!A!>:f() -]] - -TEST [[ ----@class A -local <?A?> - ----@generic T ----@param self T ----@return T -function <!A!>:f() end - -local <!b!> = <!A!>:f() -]] diff --git a/test/rename/init.lua b/test/rename/init.lua index 88f83269..4b10756e 100644 --- a/test/rename/init.lua +++ b/test/rename/init.lua @@ -18,7 +18,7 @@ end function TEST(oldName, newName) return function (oldScript) - return function (newScript) + return function (expectScript) files.removeAll() files.setText('', oldScript) local pos = oldScript:find('[^%w_]'..oldName..'[^%w_]') @@ -29,7 +29,7 @@ function TEST(oldName, newName) if positions then script = replace(script, positions) end - assert(script == newScript) + assert(script == expectScript) end end end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 02355a94..4c8817f2 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1,7 +1,7 @@ local files = require 'files' -local config = require 'config' local vm = require 'vm' -local guide = require 'core.guide' +local guide = require 'parser.guide' +local infer = require 'core.infer' rawset(_G, 'TEST', true) @@ -30,7 +30,10 @@ function TEST(wanted) files.setText('', newScript) local source = getSource(pos) assert(source) - local result = vm.getInferType(source, 0) + local result = infer.searchAndViewInfers(source) + if wanted ~= result then + infer.searchAndViewInfers(source) + end assert(wanted == result) end end @@ -140,22 +143,44 @@ TEST 'number' [[ ]] TEST 'tablelib' [[ +---@class tablelib +table = {} + <?table?>() ]] TEST 'string' [[ +_VERSION = 'Lua 5.4' + <?x?> = _VERSION ]] TEST 'function' [[ +---@class stringlib +local string + +string.sub = function () end + return ('x').<?sub?> ]] TEST 'function' [[ +---@class stringlib +local string + +string.sub = function () end + <?x?> = ('x').sub ]] TEST 'function' [[ +---@class stringlib +local string + +string.sub = function () end + +_VERSION = 'Lua 5.4' + <?x?> = _VERSION.sub ]] @@ -200,8 +225,11 @@ end _, <?y?> = pcall(x) ]] -TEST 'oslib' [[ -local <?os?> = require 'os' +TEST 'integer' [[ +local function x() + return 1 +end +_, <?y?> = xpcall(x) ]] TEST 'string|table' [[ @@ -218,9 +246,14 @@ local function f(<?a?>, b) end ]] -TEST 'string' [[ +TEST 'A' [[ +---@class A + +---@return A +local function f2() end + local function f() - return string.sub() + return f2() end local <?x?> = f() @@ -238,14 +271,6 @@ local <?x?> = f() --setmetatable(<?b?>) --]] -TEST 'function' [[ -string.<?sub?>() -]] - -TEST 'function' [[ -(''):<?sub?>() -]] - -- 不根据对方函数内的使用情况来推测 TEST 'any' [[ local function x(a) @@ -270,12 +295,6 @@ end local _, _, _, <?b?>, _ = x(nil, true, 1, 'yy') ]] --- TODO 暂不支持这些特殊情况,之后用其他语法定义 ---TEST 'integer' [[ ---for <?i?> in ipairs(t) do ---end ---]] - TEST 'any' [[ local <?x?> = next() ]] @@ -297,16 +316,23 @@ local <?x?> ]] TEST 'string' [[ +---@class string + ---@type string local <?x?> ]] TEST 'string[]' [[ +---@class string + ---@type string[] local <?x?> ]] TEST 'string|table' [[ +---@class string +---@class table + ---@type string | table local <?x?> ]] @@ -322,6 +348,9 @@ local <?x?> ]] TEST 'table<string, number>' [[ +---@class string +---@class number + ---@type table<string, number> local <?x?> ]] @@ -331,12 +360,16 @@ self.<?t?>[#self.t+1] = {} ]] TEST 'string' [[ +---@class string + ---@type string[] local x local <?y?> = x[1] ]] TEST 'string' [[ +---@class string + ---@return string[] local function f() end local x = f() @@ -387,6 +420,15 @@ print(t.<?a?>) ]] TEST 'integer' [[ +---@class integer + +---@generic T: table, V +---@param t T +---@return fun(table: V[], i?: integer):integer, V +---@return T +---@return integer i +local function ipairs() end + for <?i?> in ipairs() do end ]] @@ -404,6 +446,8 @@ local k, v = next(<?t?>) ]] TEST 'string' [[ +---@class string + ---@generic K, V ---@param t table<K, V> ---@return K @@ -416,6 +460,8 @@ local <?k?>, v = next(t) ]] TEST 'boolean' [[ +---@class boolean + ---@generic K, V ---@param t table<K, V> ---@return K @@ -436,6 +482,8 @@ local <?r?> = f(true) ]] TEST 'string' [[ +---@class string + ---@generic K, V ---@type fun(arg: table<K, V>):K, V local f @@ -447,6 +495,8 @@ local <?k?>, v = f(t) ]] TEST 'boolean' [[ +---@class boolean + ---@generic K, V ---@type fun(arg: table<K, V>):K, V local f @@ -472,6 +522,8 @@ local <?r?> = f() ]] TEST 'string' [[ +---@class string + ---@generic K, V ---@return fun(arg: table<K, V>):K, V local function f() end @@ -485,6 +537,8 @@ local <?k?>, v = f2(t) ]] TEST 'string' [[ +---@class string + ---@generic T: table, K, V ---@param t T ---@return fun(table: table<K, V>, index: K):K, V @@ -502,11 +556,12 @@ end ]] TEST 'boolean' [[ +---@class boolean + ---@generic T: table, K, V ---@param t T ----@return fun(table: table<K, V>, index: K):K, V +---@return fun(table: table<K, V>, index?: K):K, V ---@return T ----@return nil local function pairs(t) end local f = pairs(t) @@ -519,11 +574,12 @@ end ]] TEST 'string' [[ +---@class string + ---@generic T: table, K, V ---@param t T ----@return fun(table: table<K, V>, index: K):K, V +---@return fun(table: table<K, V>, index?: K):K, V ---@return T ----@return nil local function pairs(t) end ---@type table<string, boolean> @@ -534,6 +590,8 @@ end ]] TEST 'boolean' [[ +---@class boolean + ---@generic T: table, K, V ---@param t T ---@return fun(table: table<K, V>, index: K):K, V @@ -549,6 +607,8 @@ end ]] TEST 'boolean' [[ +---@class boolean + ---@generic T: table, V ---@param t T ---@return fun(table: V[], i?: integer):integer, V @@ -564,6 +624,8 @@ end ]] TEST 'boolean' [[ +---@class boolean + ---@generic T: table, K, V ---@param t T ---@return fun(table: table<K, V>, index: K):K, V @@ -579,11 +641,12 @@ end ]] TEST 'integer' [[ +---@class integer + ---@generic T: table, K, V ---@param t T ----@return fun(table: table<K, V>, index: K):K, V +---@return fun(table: table<K, V>, index?: K):K, V ---@return T ----@return nil local function pairs(t) end ---@type boolean[] |