diff options
Diffstat (limited to 'script')
87 files changed, 3825 insertions, 1360 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua index bae3df81..3fd58c81 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -1,14 +1,13 @@ -local files = require 'files' -local lang = require 'language' -local define = require 'proto.define' -local guide = require 'core.guide' -local util = require 'utility' -local sp = require 'bee.subprocess' -local vm = require 'vm' +local files = require 'files' +local lang = require 'language' +local util = require 'utility' +local sp = require 'bee.subprocess' +local vm = require 'vm' +local guide = require "parser.guide" local function checkDisableByLuaDocExits(uri, row, mode, code) local lines = files.getLines(uri) - local ast = files.getAst(uri) + local ast = files.getState(uri) local text = files.getOriginText(uri) local line = lines[row] if ast.ast.docs and line then @@ -44,7 +43,7 @@ end local function checkDisableByLuaDocInsert(uri, row, mode, code) local lines = files.getLines(uri) - local ast = files.getAst(uri) + local ast = files.getState(uri) local text = files.getOriginText(uri) -- 先看看上一行是不是已经有了 -- 没有的话就插入一行 @@ -135,7 +134,7 @@ local function changeVersion(uri, version, results) end local function solveUndefinedGlobal(uri, diag, results) - local ast = files.getAst(uri) + local ast = files.getState(uri) local offset = files.offsetOfWord(uri, diag.range.start) guide.eachSourceContain(ast.ast, offset, function (source) if source.type ~= 'getglobal' then @@ -154,7 +153,7 @@ local function solveUndefinedGlobal(uri, diag, results) end local function solveLowercaseGlobal(uri, diag, results) - local ast = files.getAst(uri) + local ast = files.getState(uri) local offset = files.offsetOfWord(uri, diag.range.start) guide.eachSourceContain(ast.ast, offset, function (source) if source.type ~= 'setglobal' then @@ -167,7 +166,7 @@ local function solveLowercaseGlobal(uri, diag, results) end local function findSyntax(uri, diag) - local ast = files.getAst(uri) + local ast = files.getState(uri) for _, err in ipairs(ast.errs) do if err.type:lower():gsub('_', '-') == diag.code then local range = files.range(uri, err.start, err.finish) @@ -351,7 +350,7 @@ local function checkQuickFix(results, uri, start, diagnostics) end local function checkSwapParams(results, uri, start, finish) - local ast = files.getAst(uri) + local ast = files.getState(uri) local text = files.getText(uri) if not ast then return @@ -540,7 +539,7 @@ local function checkJsonToLua(results, uri, start, finish) end return function (uri, start, finish, diagnostics) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end diff --git a/script/core/collector.lua b/script/core/collector.lua new file mode 100644 index 00000000..763d145b --- /dev/null +++ b/script/core/collector.lua @@ -0,0 +1,71 @@ +local collect = {} +local subscribed = {} + +local m = {} + +--- 订阅一个名字 +---@param uri uri +---@param name string +---@param value any +function m.subscribe(uri, name, value) + -- 订阅部分 + local uriSubscribed = subscribed[uri] + if not uriSubscribed then + uriSubscribed = {} + subscribed[uri] = uriSubscribed + end + uriSubscribed[name] = true + -- 收集部分 + local nameCollect = collect[name] + if not nameCollect then + nameCollect = {} + collect[name] = nameCollect + end + if value == nil then + value = true + end + nameCollect[uri] = value +end + +--- 丢弃掉某个 uri 中收集的所有信息 +---@param uri uri +function m.dropUri(uri) + local uriSubscribed = subscribed[uri] + if not uriSubscribed then + return + end + subscribed[uri] = nil + for name in pairs(uriSubscribed) do + collect[name][uri] = nil + end +end + +--- 是否包含某个名字的订阅 +---@param name string +---@return boolean +function m.has(name) + local nameCollect = collect[name] + if not nameCollect then + return false + end + if next(nameCollect) == nil then + return false + end + return true +end + +--- 迭代某个名字的订阅 +---@param name string +function m.each(name) + local nameCollect = collect[name] + if not nameCollect then + return function () end + end + local uri, value + return function () + uri, value = next(nameCollect, uri) + return value + end +end + +return m diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua index 527af8d5..6fb9669f 100644 --- a/script/core/command/removeSpace.lua +++ b/script/core/command/removeSpace.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' local function isInString(ast, offset) return guide.eachSourceContain(ast.ast, offset, function (source) @@ -16,17 +16,17 @@ return function (data) local uri = data.uri local lines = files.getLines(uri) local text = files.getText(uri) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not lines then return end local textEdit = {} for i = 1, #lines do - local line = guide.lineContent(lines, text, i, true) + local line = searcher.lineContent(lines, text, i, true) local pos = line:find '[ \t]+$' if pos then - local start, finish = guide.lineRange(lines, i, true) + local start, finish = searcher.lineRange(lines, i, true) start = start + pos - 1 if isInString(ast, start) then goto NEXT_LINE diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua index 995a2109..348c2646 100644 --- a/script/core/command/solve.lua +++ b/script/core/command/solve.lua @@ -1,8 +1,7 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' local opMap = { ['+'] = true, @@ -29,7 +28,7 @@ local literalMap = { return function (data) local uri = data.uri local text = files.getText(uri) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/completion.lua b/script/core/completion.lua index e3980eca..d261b302 100644 --- a/script/core/completion.lua +++ b/script/core/completion.lua @@ -1,19 +1,14 @@ local define = require 'proto.define' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local matchKey = require 'core.matchkey' local vm = require 'vm' -local getLabel = require 'core.hover.label' local getName = require 'core.hover.name' local getArg = require 'core.hover.arg' -local getReturn = require 'core.hover.return' -local getDesc = require 'core.hover.description' local getHover = require 'core.hover' local config = require 'config' local util = require 'utility' local markdown = require 'provider.markdown' -local findSource = require 'core.find-source' -local await = require 'await' local parser = require 'parser' local keyWordMap = require 'core.keyword' local workspace = require 'workspace' @@ -21,6 +16,8 @@ local furi = require 'file-uri' local rpath = require 'workspace.require-path' local lang = require 'language' local lookBackward = require 'core.look-backward' +local guide = require 'parser.guide' +local infer = require 'core.infer' local DiagnosticModes = { 'disable-next-line', @@ -135,8 +132,11 @@ local function buildFunctionSnip(source, value, oop) end local function buildDetail(source) - local types = vm.getInferType(source, 0) - local literals = vm.getInferLiteral(source, 0) + if source.type == 'dummy' then + return + end + local types = infer.searchAndViewInfers(source) + local literals = infer.searchAndViewLiterals(source) if literals then return types .. ' = ' .. literals else @@ -149,9 +149,9 @@ local function getSnip(source) if context <= 0 then return nil end - local defs = vm.getRefs(source, 0) + local defs = vm.getRefs(source) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def ~= source and def.type == 'function' then local uri = guide.getUri(def) local text = files.getText(uri) @@ -173,6 +173,9 @@ local function getSnip(source) end local function buildDesc(source) + if source.type == 'dummy' then + return + end local hover = getHover.get(source) local md = markdown() md:add('lua', hover.label) @@ -273,8 +276,8 @@ local function checkLocal(ast, word, offset, results) if not matchKey(word, name) then goto CONTINUE end - if vm.hasType(source, 'function') then - for _, def in ipairs(vm.getDefs(source, 0)) do + if infer.hasType(source, 'function') then + for _, def in ipairs(vm.getDefs(source)) do if def.type == 'function' or def.type == 'doc.type.function' then local funcLabel = name .. getParams(def, false) @@ -325,7 +328,7 @@ local function checkModule(ast, word, offset, results) and not config.config.diagnostics.globals[stemName] and stemName:match '^[%a_][%w_]*$' and matchKey(word, stemName) then - local targetAst = files.getAst(uri) + local targetAst = files.getState(uri) if not targetAst then goto CONTINUE end @@ -417,7 +420,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset) end local function checkFieldThen(name, src, word, start, offset, parent, oop, results) - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src local kind = define.CompletionItemKind.Field if value.type == 'function' or value.type == 'doc.type.function' then @@ -492,7 +495,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res end local funcLabel if config.config.completion.showParams then - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src if value.type == 'function' or value.type == 'doc.type.function' then funcLabel = name .. getParams(value, oop) @@ -539,16 +542,16 @@ end local function checkGlobal(ast, word, start, offset, parent, oop, results) local locals = guide.getVisibleLocals(ast.ast, offset) - local refs = vm.getGlobalSets '*' - checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global') + local globals = vm.getGlobalSets '*' + checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results, locals, 'global') end local function checkField(ast, word, start, offset, parent, oop, results) if parent.tag == '_ENV' or parent.special == '_G' then - local refs = vm.getGlobalSets '*' - checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) + local globals = vm.getGlobalSets '*' + checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results) else - local refs = vm.getFields(parent, 0) + local refs = vm.getRefs(parent, '*') checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) end end @@ -1043,14 +1046,14 @@ local function mergeEnums(a, b, source) end end -local function checkTypingEnum(ast, text, offset, infers, str, results) +local function checkTypingEnum(ast, text, offset, defs, str, results) local enums = {} - for _, infer in ipairs(infers) do - if infer.source.type == 'doc.type.enum' - or infer.source.type == 'doc.resume' then + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then enums[#enums+1] = { - label = infer.source[1], - description = infer.source.comment and infer.source.comment.text, + label = def[1], + description = def.comment and def.comment.text, kind = define.CompletionItemKind.EnumMember, } end @@ -1074,8 +1077,8 @@ local function checkEqualEnumLeft(ast, text, offset, source, results) return src end end) - local infers = vm.getInfers(source, 0) - checkTypingEnum(ast, text, offset, infers, str, results) + local defs = vm.getDefs(source) + checkTypingEnum(ast, text, offset, defs, str, results) end local function checkEqualEnum(ast, text, offset, results) @@ -1247,7 +1250,30 @@ function (%s)\ end"):format(table.concat(args, ', ')) end -local function getCallEnums(source, index) +local function pushCallEnumsAndFuncs(defs) + local results = {} + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then + results[#results+1] = { + label = def[1], + description = def.comment, + kind = define.CompletionItemKind.EnumMember, + } + end + if def.type == 'doc.type.function' then + results[#results+1] = { + label = infer.viewDocFunction(def), + description = def.comment, + kind = define.CompletionItemKind.Function, + insertText = buildInsertDocFunction(def), + } + end + end + return results +end + +local function getCallEnumsAndFuncs(source, index) if source.type == 'function' and source.bindDocs then if not source.args then return @@ -1266,37 +1292,10 @@ local function getCallEnums(source, index) for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.param' and doc.param[1] == arg[1] then - local enums = {} - for _, enum in ipairs(vm.getDocEnums(doc.extends) or {}) do - enums[#enums+1] = { - label = enum[1], - description = enum.comment, - kind = define.CompletionItemKind.EnumMember, - } - end - for _, unit in ipairs(vm.getDocTypeUnits(doc.extends) or {}) do - if unit.type == 'doc.type.function' then - local text = files.getText(guide.getUri(unit)) - enums[#enums+1] = { - label = text:sub(unit.start, unit.finish), - description = doc.comment, - kind = define.CompletionItemKind.Function, - insertText = buildInsertDocFunction(unit), - } - end - end - return enums + return pushCallEnumsAndFuncs(vm.getDefs(doc.extends)) elseif doc.type == 'doc.vararg' and arg.type == '...' then - local enums = {} - for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do - enums[#enums+1] = { - label = enum[1], - description = enum.comment, - kind = define.CompletionItemKind.EnumMember, - } - end - return enums + return pushCallEnumsAndFuncs(vm.getDefs(doc.vararg)) end end end @@ -1403,12 +1402,12 @@ local function checkTableLiteralFieldByCall(ast, text, offset, call, defs, index return end for _, def in ipairs(defs) do - local func = guide.getObjectValue(def) or def + local func = searcher.getObjectValue(def) or def local param = getFuncParamByCallIndex(func, index) if not param then goto CONTINUE end - local defs = vm.getDefFields(param, 0) + local defs = vm.getDefs(param, '*') for _, field in ipairs(defs) do local name = guide.getKeyName(field) if name and not mark[name] then @@ -1431,10 +1430,10 @@ local function tryCallArg(ast, text, offset, results) if arg and arg.type == 'function' then return end - local defs = vm.getDefs(call.node, 0) + local defs = vm.getDefs(call.node) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def - local enums = getCallEnums(def, argIndex) + def = searcher.getObjectValue(def) or def + local enums = getCallEnumsAndFuncs(def, argIndex) if enums then mergeEnums(myResults, enums, arg) end @@ -1461,7 +1460,8 @@ local function tryTable(ast, text, offset, results) if source.type ~= 'table' then tbl = source.parent end - local defs = vm.getDefFields(tbl, 0) + local parent = tbl.parent + local defs = vm.getDefs(parent, '*') for _, field in ipairs(defs) do local name = guide.getKeyName(field) if name and not mark[name] then @@ -1560,7 +1560,7 @@ end local function tryLuaDocBySource(ast, offset, source, results) if source.type == 'doc.extends.name' then if source.parent.type == 'doc.class' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if doc.type == 'doc.class.name' and doc.parent ~= source.parent and matchKey(source[1], doc[1]) then @@ -1578,7 +1578,7 @@ local function tryLuaDocBySource(ast, offset, source, results) end return true elseif source.type == 'doc.type.name' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') and doc.parent ~= source.parent and matchKey(source[1], doc[1]) then @@ -1652,7 +1652,7 @@ end local function tryLuaDocByErr(ast, offset, err, docState, results) if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if doc.type == 'doc.class.name' and doc.parent ~= docState then results[#results+1] = { @@ -1662,7 +1662,7 @@ local function tryLuaDocByErr(ast, offset, err, docState, results) end end elseif err.type == 'LUADOC_MISS_TYPE_NAME' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then results[#results+1] = { label = doc[1], @@ -1735,14 +1735,14 @@ local function buildLuaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = vm.getInferType(arg) + args[#args+1] = infer.searchAndViewInfers(arg) end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = vm.getInferType(rtns[n]) + returns[n] = infer.searchAndViewInfers(rtns[n]) end end end @@ -1931,7 +1931,7 @@ local function completion(uri, offset, triggerCharacter) return results end tracy.ZoneBeginN 'completion #1' - local ast = files.getAst(uri) + local ast = files.getState(uri) local text = files.getText(uri) results = {} clearStack() diff --git a/script/core/definition.lua b/script/core/definition.lua index b26bb922..27a9e553 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -1,14 +1,15 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local workspace = require 'workspace' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' +local guide = require 'parser.guide' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -20,7 +21,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else @@ -101,7 +102,7 @@ local function convertIndex(source) end return function (uri, offset) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end @@ -127,11 +128,11 @@ return function (uri, offset) end end - local defs = vm.getDefs(source, 0) + local defs = vm.getDefs(source) local values = {} for _, src in ipairs(defs) do - local value = guide.getObjectValue(src) - if value and value ~= src then + local value = searcher.getObjectValue(src) + if value and value ~= src and guide.isLiteral(value) then values[value] = true end end @@ -148,9 +149,6 @@ return function (uri, offset) goto CONTINUE end src = src.field or src.method or src.index or src - if src.type == 'table' and src.parent.type ~= 'return' then - goto CONTINUE - end if src.type == 'doc.class.name' and source.type ~= 'doc.type.name' and source.type ~= 'doc.extends.name' diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua index 19bb4f97..bae39a03 100644 --- a/script/core/diagnostics/ambiguity-1.lua +++ b/script/core/diagnostics/ambiguity-1.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local opMap = { @@ -25,7 +25,7 @@ local literalMap = { } return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua index 702cd904..ae6d4d3b 100644 --- a/script/core/diagnostics/circle-doc-class.lua +++ b/script/core/diagnostics/circle-doc-class.lua @@ -1,11 +1,11 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' +local guide = require 'parser.guide' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end @@ -40,7 +40,7 @@ return function (uri, callback) end if not mark[newName] then mark[newName] = true - local docs = vm.getDocTypes(newName) + local docs = vm.getDocDefines(newName) for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' then list[#list+1] = otherDoc.parent diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index d1983c42..afd259d0 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -1,10 +1,9 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua index f23755ea..21f7e83a 100644 --- a/script/core/diagnostics/code-after-break.lua +++ b/script/core/diagnostics/code-after-break.lua @@ -1,10 +1,10 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua index 65099af8..a16811ab 100644 --- a/script/core/diagnostics/count-down-loop.lua +++ b/script/core/diagnostics/count-down-loop.lua @@ -1,9 +1,9 @@ -local files = require "files" -local guide = require "core.guide" -local lang = require 'language' +local files = require "files" +local guide = require "parser.guide" +local lang = require 'language' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) local text = files.getText(uri) if not state or not text then return diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua index 60d60946..c60084fb 100644 --- a/script/core/diagnostics/deprecated.lua +++ b/script/core/diagnostics/deprecated.lua @@ -1,13 +1,13 @@ -local files = require 'files' -local vm = require 'vm' -local lang = require 'language' -local guide = require 'core.guide' -local config = require 'config' -local define = require 'proto.define' -local await = require 'await' +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local guide = require 'parser.guide' +local config = require 'config' +local define = require 'proto.define' +local await = require 'await' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -20,7 +20,7 @@ return function (uri, callback) return end if src.type == 'getglobal' then - local key = guide.getKeyName(src) + local key = src[1] if not key then return end @@ -34,11 +34,11 @@ return function (uri, callback) await.delay() - if not vm.isDeprecated(src, 0) then + if not vm.isDeprecated(src, true) then return end - local defs = vm.getDefs(src, 0) + local defs = vm.getDefs(src) local validVersions for _, def in ipairs(defs) do if def.bindDocs then diff --git a/script/core/diagnostics/doc-field-no-class.lua b/script/core/diagnostics/doc-field-no-class.lua index f27bbb32..97603c0b 100644 --- a/script/core/diagnostics/doc-field-no-class.lua +++ b/script/core/diagnostics/doc-field-no-class.lua @@ -2,7 +2,7 @@ local files = require 'files' local lang = require 'language' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua index 8c6696a9..20eedb5e 100644 --- a/script/core/diagnostics/duplicate-doc-class.lua +++ b/script/core/diagnostics/duplicate-doc-class.lua @@ -1,11 +1,11 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' +local guide = require 'parser.guide' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end @@ -20,7 +20,7 @@ return function (uri, callback) or doc.type == 'doc.alias' then local name = guide.getKeyName(doc) if not cache[name] then - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) cache[name] = {} for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' @@ -28,7 +28,7 @@ return function (uri, callback) cache[name][#cache[name]+1] = { start = otherDoc.start, finish = otherDoc.finish, - uri = guide.getUri(otherDoc), + uri = searcher.getUri(otherDoc), } end end diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua index b621fd9e..1ee27ff2 100644 --- a/script/core/diagnostics/duplicate-doc-field.lua +++ b/script/core/diagnostics/duplicate-doc-field.lua @@ -2,7 +2,7 @@ local files = require 'files' local lang = require 'language' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end diff --git a/script/core/diagnostics/duplicate-doc-param.lua b/script/core/diagnostics/duplicate-doc-param.lua index 676a6fb4..b54c1978 100644 --- a/script/core/diagnostics/duplicate-doc-param.lua +++ b/script/core/diagnostics/duplicate-doc-param.lua @@ -2,7 +2,7 @@ local files = require 'files' local lang = require 'language' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua index 5e63d39e..91a35212 100644 --- a/script/core/diagnostics/duplicate-index.lua +++ b/script/core/diagnostics/duplicate-index.lua @@ -1,11 +1,11 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua index c1e2285a..492793b1 100644 --- a/script/core/diagnostics/duplicate-set-field.lua +++ b/script/core/diagnostics/duplicate-set-field.lua @@ -1,11 +1,11 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local define = require 'proto.define' +local guide = require "parser.guide" return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -30,7 +30,7 @@ return function (uri, callback) if not name then goto CONTINUE end - local value = guide.getObjectValue(nxt) + local value = searcher.getObjectValue(nxt) if not value or value.type ~= 'function' then goto CONTINUE end diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua index 690a4ca2..fc205d7e 100644 --- a/script/core/diagnostics/empty-block.lua +++ b/script/core/diagnostics/empty-block.lua @@ -1,12 +1,12 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' -- 检查空代码块 -- 但是排除忙等待(repeat/while) return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index de23bc76..d95963e4 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' -- TODO: 检查路径是否可达 @@ -8,7 +8,7 @@ local function mayRun(path) end return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua index a2b831f7..5446a7c3 100644 --- a/script/core/diagnostics/init.lua +++ b/script/core/diagnostics/init.lua @@ -59,16 +59,18 @@ local function check(uri, name, results) if passed >= 0.5 then log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed)) end + if DIAGTIMES then + DIAGTIMES[name] = (DIAGTIMES[name] or 0) + passed + end end return function (uri, response) - local vm = require 'vm' - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end - local isOpen = files.isOpen(uri) + log.debug('do diagnostic @', uri) for _, name in ipairs(diagList) do await.delay() diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua index 9c094701..cba33459 100644 --- a/script/core/diagnostics/lowercase-global.lua +++ b/script/core/diagnostics/lowercase-global.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local config = require 'config' local vm = require 'vm' @@ -18,7 +18,7 @@ end -- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删) return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua index 0727c2fd..2cbc13ee 100644 --- a/script/core/diagnostics/newfield-call.lua +++ b/script/core/diagnostics/newfield-call.lua @@ -1,9 +1,9 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua index 807f76a2..71dc33e2 100644 --- a/script/core/diagnostics/newline-call.lua +++ b/script/core/diagnostics/newline-call.lua @@ -1,9 +1,9 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) local lines = files.getLines(uri) local text = files.getText(uri) if not ast or not lines then diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua index ffaab821..6ff17c81 100644 --- a/script/core/diagnostics/no-implicit-any.lua +++ b/script/core/diagnostics/no-implicit-any.lua @@ -1,11 +1,10 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local infer = require 'core.infer' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -21,7 +20,7 @@ return function (uri, callback) and source.type ~= 'tableindex' then return end - if vm.getInferType(source, 0) == 'any' then + if infer.searchAndViewInfers(source) == 'any' then callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua index 857d80d2..503347d0 100644 --- a/script/core/diagnostics/redefined-local.lua +++ b/script/core/diagnostics/redefined-local.lua @@ -1,9 +1,9 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -13,6 +13,9 @@ return function (uri, callback) or name == ast.ENVMode then return end + if source.tag == 'self' then + return + end local exist = guide.getLocal(source, name, source.start-1) if exist then callback { diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index c5bcd5a5..b25ec77a 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -1,9 +1,8 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' local define = require 'proto.define' -local await = require 'await' local function countCallArgs(source) local result = 0 @@ -67,7 +66,7 @@ local function getFuncArgs(func) end return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -81,14 +80,7 @@ return function (uri, callback) end local func = source.node - local funcArgs = cache[func] - if funcArgs == nil then - funcArgs = getFuncArgs(func) or false - local refs = vm.getRefs(func, 0) - for _, ref in ipairs(refs) do - cache[ref] = funcArgs - end - end + local funcArgs = getFuncArgs(func) if not funcArgs then return diff --git a/script/core/diagnostics/redundant-value.lua b/script/core/diagnostics/redundant-value.lua index be483448..d6cd97a7 100644 --- a/script/core/diagnostics/redundant-value.lua +++ b/script/core/diagnostics/redundant-value.lua @@ -3,7 +3,7 @@ local define = require 'proto.define' local lang = require 'language' return function (uri, callback, code) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua index 0a4b1d57..824eb83f 100644 --- a/script/core/diagnostics/trailing-space.lua +++ b/script/core/diagnostics/trailing-space.lua @@ -1,6 +1,6 @@ local files = require 'files' local lang = require 'language' -local guide = require 'core.guide' +local guide = require 'parser.guide' local function isInString(ast, offset) local result = false @@ -13,7 +13,7 @@ local function isInString(ast, offset) end return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua index b2b2800c..df71f0c9 100644 --- a/script/core/diagnostics/unbalanced-assignments.lua +++ b/script/core/diagnostics/unbalanced-assignments.lua @@ -1,10 +1,10 @@ local files = require 'files' local define = require 'proto.define' local lang = require 'language' -local guide = require 'core.guide' +local guide = require 'parser.guide' return function (uri, callback, code) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua index a91cfa7f..e7133ab9 100644 --- a/script/core/diagnostics/undefined-doc-class.lua +++ b/script/core/diagnostics/undefined-doc-class.lua @@ -1,11 +1,11 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end @@ -25,7 +25,7 @@ return function (uri, callback) end for _, ext in ipairs(doc.extends) do local name = ext[1] - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) if cache[name] == nil then cache[name] = false for _, otherDoc in ipairs(docs) do diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index d8a4363b..91d4b90e 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -1,11 +1,10 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' -local define = require 'proto.define' local vm = require 'vm' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end @@ -14,26 +13,6 @@ return function (uri, callback) return end - local classCache = { - ['any'] = true, - ['nil'] = true, - } - local function hasNameOfClassOrAlias(name) - if classCache[name] ~= nil then - return classCache[name] - end - local docs = vm.getDocTypes(name) - for _, otherDoc in ipairs(docs) do - if otherDoc.type == 'doc.class.name' - or otherDoc.type == 'doc.alias.name' then - classCache[name] = true - return true - end - end - classCache[name] = false - return false - end - local function hasNameOfGeneric(name, source) if not source.typeGeneric then return false @@ -56,7 +35,7 @@ return function (uri, callback) if name == '...' then return end - if hasNameOfClassOrAlias(name) + if vm.isDocDefined(name) or hasNameOfGeneric(name, source) then return end diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua index 0bf371e5..6140b4f0 100644 --- a/script/core/diagnostics/undefined-doc-param.lua +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -17,7 +17,7 @@ local function hasParamName(func, name) end return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua index 89efb8c7..39c8de27 100644 --- a/script/core/diagnostics/undefined-env-child.lua +++ b/script/core/diagnostics/undefined-env-child.lua @@ -1,10 +1,11 @@ -local files = require 'files' -local guide = require 'core.guide' -local vm = require 'vm' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' +local lang = require 'language' +local vm = require "vm.vm" return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -13,7 +14,7 @@ return function (uri, callback) if source.node.tag == '_ENV' then return end - local defs = guide.requestDefinition(source) + local defs = vm.getDefs(source) if #defs > 0 then return end diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index b10c9ab0..9d1f696c 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -2,11 +2,18 @@ local files = require 'files' local vm = require 'vm' local lang = require 'language' local config = require 'config' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' +local SkipCheckClass = { + ['unknown'] = true, + ['any'] = true, + ['table'] = true, + ['nil'] = true, +} + return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -18,7 +25,7 @@ return function (uri, callback) if cache[src] == nil then tracy.ZoneBeginN('undefined-field getInfers') infers = vm.getInfers(src, 0) or false - local refs = vm.getRefs(src, 0) + local refs = vm.getRefs(src) for _, ref in ipairs(refs) do cache[ref] = infers end @@ -47,7 +54,7 @@ return function (uri, callback) elseif inferSource.type == 'doc.class.name' then addTo(allDocClass, inferSource.parent) elseif inferSource.type == 'doc.type.name' then - local docTypes = vm.getDocTypes(inferSource[1]) + local docTypes = vm.getDocDefines(inferSource[1]) for _, docType in ipairs(docTypes) do if docType.type == 'doc.class.name' then addTo(allDocClass, docType.parent) @@ -65,7 +72,7 @@ return function (uri, callback) local empty = true for _, docClass in ipairs(allDocClass) do tracy.ZoneBeginN('undefined-field getDefFields') - local refs = vm.getDefFields(docClass) + local refs = vm.getDefs(docClass, '*') tracy.ZoneEnd() for _, ref in ipairs(refs) do @@ -87,35 +94,37 @@ return function (uri, callback) end local function checkUndefinedField(src) - local fieldName = guide.getKeyName(src) - - local allDocClass = getAllDocClassFromInfer(src.node) - if (not allDocClass) or (#allDocClass == 0) then - return - end - - local fields = getAllFieldsFromAllDocClass(allDocClass) - - -- 没找到任何 field,跳过检查 - if not fields then + if #vm.getDefs(src) > 0 then return end - - if not fields[fieldName] then - local message = lang.script('DIAG_UNDEF_FIELD', fieldName) - if src.type == 'getfield' and src.field then - callback { - start = src.field.start, - finish = src.field.finish, - message = message, - } - elseif src.type == 'getmethod' and src.method then - callback { - start = src.method.start, - finish = src.method.finish, - message = message, - } + local node = src.node + if node then + local defs = vm.getDefs(node) + local ok + for _, def in ipairs(defs) do + if def.type == 'doc.class.name' + and not SkipCheckClass[def[1]] then + ok = true + break + end end + if not ok then + return + end + end + local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src)) + if src.type == 'getfield' and src.field then + callback { + start = src.field.start, + finish = src.field.finish, + message = message, + } + elseif src.type == 'getmethod' and src.method then + callback { + start = src.method.start, + finish = src.method.finish, + message = message, + } end end guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField); diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index 161d8856..549a1922 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -1,9 +1,8 @@ -local files = require 'files' -local vm = require 'vm' -local lang = require 'language' -local config = require 'config' -local guide = require 'core.guide' -local define = require 'proto.define' +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local config = require 'config' +local guide = require 'parser.guide' local requireLike = { ['include'] = true, @@ -13,14 +12,14 @@ local requireLike = { } return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end -- 遍历全局变量,检查所有没有 set 模式的全局变量 guide.eachSourceType(ast.ast, 'getglobal', function (src) - local key = guide.getKeyName(src) + local key = src[1] if not key then return end @@ -30,7 +29,11 @@ return function (uri, callback) if config.config.runtime.special[key] then return end - if #vm.getGlobalSets(key) == 0 then + local node = src.node + if node.tag ~= '_ENV' then + return + end + if #vm.getDefs(src) == 0 then local message = lang.script('DIAG_UNDEF_GLOBAL', key) if requireLike[key:lower()] then message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key)) diff --git a/script/core/diagnostics/unknown-diag-code.lua b/script/core/diagnostics/unknown-diag-code.lua index 45d3b6db..013a702b 100644 --- a/script/core/diagnostics/unknown-diag-code.lua +++ b/script/core/diagnostics/unknown-diag-code.lua @@ -3,7 +3,7 @@ local lang = require 'language' local define = require 'proto.define' return function (uri, callback) - local state = files.getAst(uri) + local state = files.getState(uri) if not state then return end diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua index b6f92e60..59f27e59 100644 --- a/script/core/diagnostics/unused-function.lua +++ b/script/core/diagnostics/unused-function.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local vm = require 'vm' local define = require 'proto.define' local lang = require 'language' @@ -19,7 +19,7 @@ local function isToBeClosed(source) end return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua index e2d5e49a..8ee0bba3 100644 --- a/script/core/diagnostics/unused-label.lua +++ b/script/core/diagnostics/unused-label.lua @@ -1,10 +1,10 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua index fde90cb8..072cbd31 100644 --- a/script/core/diagnostics/unused-local.lua +++ b/script/core/diagnostics/unused-local.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' @@ -77,7 +77,7 @@ local function isDocParam(source) end return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -87,6 +87,9 @@ return function (uri, callback) or name == ast.ENVMode then return end + if source.tag == 'self' then + return + end if isToBeClosed(source) then return end diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua index ec0a05fb..2e07e1ee 100644 --- a/script/core/diagnostics/unused-vararg.lua +++ b/script/core/diagnostics/unused-vararg.lua @@ -1,10 +1,10 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' return function (uri, callback) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua index cc87e3ca..03169cfd 100644 --- a/script/core/document-symbol.lua +++ b/script/core/document-symbol.lua @@ -1,8 +1,8 @@ -local await = require 'await' -local files = require 'files' -local guide = require 'core.guide' -local define = require 'proto.define' -local util = require 'utility' +local await = require 'await' +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local util = require 'utility' local function buildName(source, text) if source.type == 'setmethod' @@ -228,7 +228,7 @@ local function buildSource(source, text, used, symbols) end local function makeSymbol(uri) - local ast = files.getAst(uri) + local ast = files.getState(uri) local text = files.getText(uri) if not ast or not text then return nil diff --git a/script/core/find-source.lua b/script/core/find-source.lua index b36306b6..edbb1e2c 100644 --- a/script/core/find-source.lua +++ b/script/core/find-source.lua @@ -1,4 +1,4 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local function isValidFunctionPos(source, offset) for i = 1, #source.keyword // 2 do diff --git a/script/core/folding.lua b/script/core/folding.lua index 15678995..dad98422 100644 --- a/script/core/folding.lua +++ b/script/core/folding.lua @@ -1,5 +1,5 @@ local files = require "files" -local guide = require "core.guide" +local guide = require "parser.guide" local util = require 'utility' local Care = { @@ -145,7 +145,7 @@ local Care = { } return function (uri) - local ast = files.getAst(uri) + local ast = files.getState(uri) local text = files.getText(uri) if not ast or not text then return nil diff --git a/script/core/generic.lua b/script/core/generic.lua new file mode 100644 index 00000000..15950974 --- /dev/null +++ b/script/core/generic.lua @@ -0,0 +1,234 @@ +local guide = require 'parser.guide' +local noder = require "core.noder" + +---@class generic.value +---@field type string +---@field closure generic.closure +---@field proto parser.guide.object +---@field parent parser.guide.object + +---@class generic.closure +---@field type string +---@field proto parser.guide.object +---@field upvalues table<string, generic.value[]> +---@field params generic.value[] +---@field returns generic.value[] + +local m = {} + +---@param closure generic.closure +---@param proto parser.guide.object +local function instantValue(closure, proto) + ---@type generic.value + local value = { + type = 'generic.value', + closure = closure, + proto = proto, + parent = proto.parent, + } + closure.values[#closure.values+1] = value + return value +end + +---递归实例化对象 +---@param proto parser.guide.object +---@return generic.value +local function createValue(closure, proto, callback, road) + if callback then + road = road or {} + end + if proto.type == 'doc.type' then + local types = {} + local hasGeneric + for i, tp in ipairs(proto.types) do + local genericValue = createValue(closure, tp, callback, road) + if genericValue then + hasGeneric = true + types[i] = genericValue + else + types[i] = tp + end + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.types = types + noder.compileNode(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.name' then + if not proto.typeGeneric then + return nil + end + local key = proto[1] + local value = instantValue(closure, proto) + if callback then + callback(road, key, proto) + end + noder.compileNode(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.function' then + local hasGeneric + local args = {} + local returns = {} + for i, arg in ipairs(proto.args) do + local value = createValue(closure, arg, callback, road) + if value then + hasGeneric = true + end + args[i] = value or arg + end + for i, rtn in ipairs(proto.returns) do + local value = createValue(closure, rtn, callback, road) + if value then + hasGeneric = true + end + returns[i] = value or rtn + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.args = args + value.returns = returns + value.isGeneric = true + noder.pushSource(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.array' then + if road then + road[#road+1] = noder.ANY_FIELD + end + local node = createValue(closure, proto.node, callback, road) + if road then + road[#road] = nil + end + if not node then + return nil + end + local value = instantValue(closure, proto) + value.node = node + return value + end + if proto.type == 'doc.type.table' then + road[#road+1] = noder.TABLE_KEY + local tkey = createValue(closure, proto.tkey, callback, road) + road[#road] = nil + + road[#road+1] = noder.ANY_FIELD + local tvalue = createValue(closure, proto.tvalue, callback, road) + road[#road] = nil + + if not tkey and not tvalue then + return nil + end + local value = instantValue(closure, proto) + value.tkey = tkey or proto.tkey + value.tvalue = tvalue or proto.tvalue + return value + end +end + +local function buildValue(road, key, proto, param, upvalues) + local paramID + if proto.literal then + local str = param.type == 'string' and param[1] + if not str then + return + end + paramID = 'dn:' .. str + else + paramID = noder.getID(param) + end + if not paramID then + return + end + local myUri = guide.getUri(param) + local myHead = noder.URI_CHAR .. myUri .. noder.URI_CHAR + paramID = myHead .. paramID + if not upvalues[key] then + upvalues[key] = {} + end + upvalues[key][#upvalues[key]+1] = paramID .. table.concat(road) +end + +-- 为所有的 param 与 return 创建副本 +---@param closure generic.closure +local function buildValues(closure) + local protoFunction = closure.proto + local upvalues = closure.upvalues + local params = closure.call.args + + if protoFunction.type == 'function' then + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.param' then + local extends = doc.extends + local index = extends.paramIndex + if index then + local param = params and params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + end + end + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + closure.returns[rtn.returnIndex] = createValue(closure, rtn) or rtn + end + end + end + end + if protoFunction.type == 'doc.type.function' then + for index, arg in ipairs(protoFunction.args) do + local extends = arg.extends + local param = params and params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + for index, rtn in ipairs(protoFunction.returns) do + closure.returns[index] = createValue(closure, rtn) or rtn + end + end +end + +---创建一个闭包 +---@param proto parser.guide.object|generic.value # 原型函数|泛型值 +---@return generic.closure +function m.createClosure(proto, call) + local protoFunction, parentClosure + if proto.type == 'function' then + protoFunction = proto + elseif proto.type == 'doc.type.function' then + protoFunction = proto + elseif proto.type == 'generic.value' then + protoFunction = proto.proto + parentClosure = proto.closure + end + ---@type generic.closure + local closure = { + type = 'generic.closure', + parent = protoFunction.parent, + proto = protoFunction, + call = call, + upvalues = parentClosure and parentClosure.upvalues or {}, + params = {}, + returns = {}, + values = {}, + } + buildValues(closure) + + if #closure.returns == 0 then + return nil + end + + noder.compileNode(noder.getNoders(proto), closure) + + return closure +end + +return m diff --git a/script/core/guide.lua b/script/core/guide2.lua index c7a784b7..183555b3 100644 --- a/script/core/guide.lua +++ b/script/core/guide2.lua @@ -292,8 +292,13 @@ end ---@param obj parser.guide.object ---@return parser.guide.object function m.getRoot(obj) + local source = obj + if source._root then + return source._root + end for _ = 1, 1000 do if obj.type == 'main' then + source._root = obj return obj end local parent = obj.parent diff --git a/script/core/highlight.lua b/script/core/highlight.lua index 12ec114f..d1f11906 100644 --- a/script/core/highlight.lua +++ b/script/core/highlight.lua @@ -1,31 +1,18 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local files = require 'files' local vm = require 'vm' local define = require 'proto.define' local findSource = require 'core.find-source' local util = require 'utility' +local guide = require 'parser.guide' local function eachRef(source, callback) - local results = guide.requestReference(source) + local results = vm.getRefs(source) for i = 1, #results do callback(results[i]) end end -local function eachField(source, callback) - if not source then - return - end - local isGlobal = guide.isGlobal(source) - local results = guide.requestReference(source) - for i = 1, #results do - local res = results[i] - if isGlobal == guide.isGlobal(res) then - callback(res) - end - end -end - local function eachLocal(source, callback) callback(source) if source.ref then @@ -43,21 +30,21 @@ local function find(source, uri, callback) eachLocal(source.node, callback) elseif source.type == 'field' or source.type == 'method' then - eachField(source.parent, callback) + eachRef(source.parent, callback) elseif source.type == 'getindex' or source.type == 'setindex' or source.type == 'tableindex' then - eachField(source, callback) + eachRef(source, callback) elseif source.type == 'setglobal' or source.type == 'getglobal' then - eachField(source, callback) + eachRef(source, callback) elseif source.type == 'goto' or source.type == 'label' then eachRef(source, callback) elseif source.type == 'string' and source.parent and source.parent.index == source then - eachField(source.parent, callback) + eachRef(source.parent, callback) elseif source.type == 'string' or source.type == 'boolean' or source.type == 'number' @@ -238,8 +225,18 @@ local accept = { ['nil'] = true, } +local function isLiteralValue(source) + if not guide.isLiteral(source) then + return false + end + if source.parent.index == source then + return false + end + return true +end + return function (uri, offset) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end @@ -249,10 +246,28 @@ return function (uri, offset) local source = findSource(ast, offset, accept) if source then + local isGlobal = guide.isGlobal(source) + local isLiteral = isLiteralValue(source) find(source, uri, function (target) + if not target then + return + end if target.dummy then return end + if mark[target] then + return + end + mark[target] = true + if isGlobal ~= guide.isGlobal(target) then + return + end + if isLiteral ~= isLiteralValue(target) then + return + end + if not files.eq(uri, guide.getUri(target)) then + return + end local kind if target.type == 'getfield' then target = target.field @@ -315,13 +330,6 @@ return function (uri, offset) else return end - if not target then - return - end - if mark[target] then - return - end - mark[target] = true results[#results+1] = { start = target.start, finish = target.finish, diff --git a/script/core/hint.lua b/script/core/hint.lua index 13d01dc7..67c725f7 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -1,10 +1,11 @@ -local files = require 'files' -local guide = require 'core.guide' -local vm = require 'vm' -local config = require 'config' +local files = require 'files' +local infer = require 'core.infer' +local vm = require 'vm' +local config = require 'config' +local guide = require 'parser.guide' local function typeHint(uri, edits, start, finish) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -18,6 +19,9 @@ local function typeHint(uri, edits, start, finish) and source.type ~= 'setindex' then return end + if source.dummy then + return + end if source[1] == '_' then return end @@ -33,9 +37,9 @@ local function typeHint(uri, edits, start, finish) return end end - local infer = vm.getInferType(source, 0) - if infer == 'any' - or infer == 'nil' then + local view = infer.searchAndViewInfers(source) + if view == 'any' + or view == 'nil' then return end local src = source @@ -52,7 +56,7 @@ local function typeHint(uri, edits, start, finish) end mark[src] = true edits[#edits+1] = { - newText = (':%s'):format(infer), + newText = (':%s'):format(view), start = src.finish, finish = src.finish, } @@ -95,7 +99,7 @@ local function paramName(uri, edits, start, finish) if not config.config.hint.paramName then return end - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end @@ -107,7 +111,7 @@ local function paramName(uri, edits, start, finish) if not hasLiteralArgInCall(source) then return end - local defs = vm.getDefs(source.node, 0) + local defs = vm.getDefs(source.node) if not defs then return end diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua index 324d28af..822be2b6 100644 --- a/script/core/hover/arg.lua +++ b/script/core/hover/arg.lua @@ -1,4 +1,5 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' +local infer = require 'core.infer' local vm = require 'vm' local function optionalArg(arg) @@ -21,7 +22,7 @@ local function asFunction(source, oop) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(vm.getInferType(parent.node)) + args[#args+1] = ('self: %s'):format(infer.searchAndViewInfers(parent.node)) end if source.args then for i = 1, #source.args do @@ -34,10 +35,12 @@ local function asFunction(source, oop) args[#args+1] = ('%s%s: %s'):format( name, optionalArg(arg) and '?' or '', - vm.getInferType(arg) + infer.searchAndViewInfers(arg) ) + elseif arg.type == '...' then + args[#args+1] = '...' else - args[#args+1] = ('%s'):format(vm.getInferType(arg)) + args[#args+1] = ('%s'):format(infer.searchAndViewInfers(arg)) end ::CONTINUE:: end @@ -61,7 +64,7 @@ local function asDocFunction(source) args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - vm.getInferType(arg.extends) + infer.searchAndViewInfers(arg.extends) ) else args[i] = ('%s%s'):format( diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index 401ca5a7..bcc3065a 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -2,11 +2,13 @@ local vm = require 'vm' local ws = require 'workspace' local furi = require 'file-uri' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local markdown = require 'provider.markdown' local config = require 'config' local lang = require 'language' local util = require 'utility' +local guide = require 'parser.guide' +local noder = require 'core.noder' local function asStringInRequire(source, literal) local rootPath = ws.path or '' @@ -124,10 +126,10 @@ local function getBindComment(source, docGroup, base) end local function tryDocClassComment(source) - for _, def in ipairs(vm.getDefs(source, 0)) do + for _, def in ipairs(vm.getDefs(source)) do if def.type == 'doc.class.name' or def.type == 'doc.alias.name' then - local class = guide.getDocState(def) + local class = noder.getDocState(def) local comment = getBindComment(class, class.bindGroup, class) if comment then return comment @@ -180,7 +182,7 @@ local function isFunction(source) if source.type == 'function' then return true end - local value = guide.getObjectValue(source) + local value = searcher.getObjectValue(source) if not value then return false end @@ -223,13 +225,14 @@ local function getBindEnums(source, docGroup) end local function tryDocFieldUpComment(source) - if source.type ~= 'doc.field' then + if source.type ~= 'doc.field.name' then return end - if not source.bindGroup then + local docField = source.parent + if not docField.bindGroup then return end - local comment = getBindComment(source, source.bindGroup, source) + local comment = getBindComment(docField, docField.bindGroup, docField) return comment end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index 0c8644ed..41616bc9 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local getLabel = require 'core.hover.label' local getDesc = require 'core.hover.description' @@ -7,6 +7,7 @@ local util = require 'utility' local findSource = require 'core.find-source' local lang = require 'language' local markdown = require 'provider.markdown' +local infer = require 'core.infer' local function eachFunctionAndOverload(value, callback) callback(value) @@ -24,7 +25,7 @@ local function getHoverAsValue(source) local label = getLabel(source) local desc = getDesc(source) if not desc then - local values = vm.getDefs(source, 0) + local values = vm.getDefs(source) for _, def in ipairs(values) do desc = getDesc(def) if desc then @@ -40,7 +41,7 @@ local function getHoverAsValue(source) end local function getHoverAsFunction(source) - local values = vm.getDefs(source, 0) + local values = vm.getDefs(source) local desc = getDesc(source) local labels = {} local defs = 0 @@ -48,7 +49,7 @@ local function getHoverAsFunction(source) local other = 0 local mark = {} for _, def in ipairs(values) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def.type == 'function' or def.type == 'doc.type.function' then eachFunctionAndOverload(def, function (value) @@ -123,7 +124,7 @@ local function getHover(source) if source.type == 'doc.type.name' then return getHoverAsDocName(source) end - local isFunction = vm.hasInferType(source, 'function', 0) + local isFunction = infer.hasType(source, 'function') if isFunction then return getHoverAsFunction(source) else @@ -146,7 +147,7 @@ local accept = { } local function getHoverByUri(uri, offset) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index d93b14e3..d96b149c 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -2,9 +2,10 @@ local buildName = require 'core.hover.name' local buildArg = require 'core.hover.arg' local buildReturn = require 'core.hover.return' local buildTable = require 'core.hover.table' +local infer = require 'core.infer' local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local config = require 'config' local files = require 'files' @@ -31,29 +32,28 @@ local function asDocFunction(source) end local function asDocTypeName(source) - for _, doc in ipairs(vm.getDocTypes(source[1])) do + local defs = vm.getDefs(source) + for _, doc in ipairs(defs) do if doc.type == 'doc.class.name' then - return 'class ' .. source[1] + return 'class ' .. doc[1] end if doc.type == 'doc.alias.name' then local extends = doc.parent.extends - return lang.script('HOVER_EXTENDS', vm.getInferType(extends)) + return lang.script('HOVER_EXTENDS', infer.searchAndViewInfers(extends)) end end end local function asValue(source, title) local name = buildName(source) - local infers = vm.getInfers(source, 0) - local type = vm.getInferType(source, 0) - local class = vm.getClass(source, 0) - local literal = vm.getInferLiteral(source, 0) + local type = infer.searchAndViewInfers(source) + local literal = infer.searchAndViewLiterals(source) local cont - if not vm.hasInferType(source, 'string', 0) + if not infer.hasType(source, 'string') and not type:find('%[%]$') and not type:find('%w%<') then - if #vm.getFields(source, 0) > 0 - or vm.hasInferType(source, 'table', 0) then + if #vm.getRefs(source, '*') > 0 + or infer.hasType(source, 'table') then cont = buildTable(source) end end @@ -66,11 +66,7 @@ local function asValue(source, title) or type == 'nil') then type = nil end - if class then - pack[#pack+1] = class - else - pack[#pack+1] = type - end + pack[#pack+1] = type if literal then pack[#pack+1] = '=' pack[#pack+1] = literal @@ -123,30 +119,21 @@ local function asField(source) return asValue(source, 'field') end -local function asDocField(source) - local name = source.field[1] +local function asDocFieldName(source) + local name = source[1] + local docField = source.parent local class - for _, doc in ipairs(source.bindGroup) do + for _, doc in ipairs(docField.bindGroup) do if doc.type == 'doc.class' then class = doc break end end - local infers = {} - for _, infer in ipairs(vm.getInfers(source.extends) or {}) do - infers[#infers+1] = infer - end + local view = infer.searchAndViewInfers(docField.extends) if not class then - return ('field ?.%s: %s'):format( - name, - guide.viewInferType(infers) - ) - end - return ('field %s.%s: %s'):format( - class.class[1], - name, - guide.viewInferType(infers) - ) + return ('field ?.%s: %s'):format(name, view) + end + return ('field %s.%s: %s'):format(class.class[1], name, view) end local function asString(source) @@ -177,7 +164,7 @@ local function asNumber(source) if type(num) ~= 'number' then return nil end - local uri = guide.getUri(source) + local uri = searcher.getUri(source) local text = files.getText(uri) if not text then return nil @@ -215,7 +202,7 @@ return function (source, oop) return asDocFunction(source) elseif source.type == 'doc.type.name' then return asDocTypeName(source) - elseif source.type == 'doc.field' then - return asDocField(source) + elseif source.type == 'doc.field.name' then + return asDocFieldName(source) end end diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua index d583f1e1..d2b9d30b 100644 --- a/script/core/hover/name.lua +++ b/script/core/hover/name.lua @@ -1,4 +1,6 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' +local infer = require 'core.infer' +local guide = require 'parser.guide' local vm = require 'vm' local buildName @@ -19,7 +21,7 @@ end local function asField(source, oop) local class if source.node.type ~= 'getglobal' then - class = vm.getClass(source.node, 0) + class = infer.getClass(source.node) end local node = class or guide.getKeyName(source.node) or '?' local method = guide.getKeyName(source) diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index c3e9656d..0f0d85e0 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,12 +1,4 @@ -local guide = require 'core.guide' -local vm = require 'vm' - -local function mergeTypes(returns) - if type(returns) == 'string' then - return returns - end - return guide.mergeTypes(returns) -end +local infer = require 'core.infer' local function getReturnDualByDoc(source) local docs = source.bindDocs @@ -55,24 +47,20 @@ local function asFunction(source) local returns = {} for i, rtn in ipairs(dual) do local line = {} - local types = {} + local infers = {} if i == 1 then line[#line+1] = ' -> ' else line[#line+1] = ('% 3d. '):format(i) end for n = 1, #rtn do - local values = vm.getInfers(rtn[n]) - for _, value in ipairs(values) do - if value.type then - for tp in value.type:gmatch '[^|]+' do - types[tp] = true - end - end + local values = infer.searchInfers(rtn[n]) + for tp in pairs(values) do + infers[tp] = true end end - if next(types) or rtn[1] then - local tp = mergeTypes(types) or 'any' + if next(infers) or rtn[1] then + local tp = infer.viewInfers(infers) if rtn[1].name then line[#line+1] = ('%s%s: %s'):format( rtn[1].name[1], @@ -103,7 +91,7 @@ local function asDocFunction(source) local returns = {} for i, rtn in ipairs(source.returns) do local rtnText = ('%s%s'):format( - vm.getInferType(rtn), + infer.searchAndViewInfers(rtn), rtn.optional and '?' or '' ) if i == 1 then diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index edb7751b..159453e6 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -1,26 +1,12 @@ local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local config = require 'config' local lang = require 'language' +local infer = require 'core.infer' -local function getKey(src) - local key = vm.getKeyName(src) - if not key or #key <= 0 then - if not src.index then - return '[any]' - end - local class = vm.getClass(src.index) - if class then - return ('[%s]'):format(class) - end - local tp = vm.getInferType(src.index) - if tp then - return ('[%s]'):format(tp) - end - return '[any]' - end - if guide.getKeyType(src) == 'string' then +local function formatKey(key) + if type(key) == 'string' then if key:match '^[%a_][%w_]*$' then return key else @@ -30,104 +16,16 @@ local function getKey(src) return ('[%s]'):format(key) end -local function getFieldFull(src) - local value = guide.getObjectValue(src) or src - local tp = vm.getInferType(value, 0) - --local class = vm.getClass(src) - local literal = vm.getInferLiteral(value) - if type(literal) == 'string' and #literal >= 50 then - literal = literal:sub(1, 47) .. '...' - end - return tp, literal -end - -local function getFieldFast(src) - if src.bindDocs then - return getFieldFull(src) - end - local value = guide.getObjectValue(src) or src - if not value then - return 'any' - end - if value.type == 'boolean' then - return value.type, util.viewLiteral(value[1]) - end - if value.type == 'number' - or value.type == 'integer' then - if math.tointeger(value[1]) then - if config.config.runtime.version == 'Lua 5.3' - or config.config.runtime.version == 'Lua 5.4' then - return 'integer', util.viewLiteral(value[1]) - end - end - return value.type, util.viewLiteral(value[1]) - end - if value.type == 'table' - or value.type == 'function' then - return value.type - end - if value.type == 'string' then - local literal = value[1] - if type(literal) == 'string' and #literal >= 50 then - literal = literal:sub(1, 47) .. '...' - end - return value.type, util.viewLiteral(literal) - end - if value.type == 'doc.field' then - return vm.getInferType(value) - end -end - -local function getField(src, timeUp, mark, key) - if src.type == 'table' - or src.type == 'function' then - return nil - end - if src.parent then - if src.type == 'string' - or src.type == 'boolean' - or src.type == 'number' - or src.type == 'integer' then - if src.parent.type == 'tableindex' - or src.parent.type == 'setindex' - or src.parent.type == 'getindex' then - if src.parent.index == src then - src = src.parent - end - end - end - end - local tp, literal - tp, literal = getFieldFast(src) - if tp then - return tp, literal - end - if timeUp or mark[key] then - return nil - end - mark[key] = true - tp, literal = getFieldFull(src) - if tp then - return tp, literal - end - return nil -end - -local function buildAsHash(classes, literals, reachMax) - local keys = {} - for k in pairs(classes) do - keys[#keys+1] = k - end - table.sort(keys) +local function buildAsHash(keys, inferMap, literalMap, reachMax) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local class = classes[key] - local literal = literals[key] - if literal then - lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + local inferView = inferMap[key] + local literalView = literalMap[key] + if literalView then + lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView) else - lines[#lines+1] = (' %s: %s,'):format(key, class) + lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView) end end if reachMax then @@ -137,23 +35,19 @@ local function buildAsHash(classes, literals, reachMax) return table.concat(lines, '\n') end -local function buildAsConst(classes, literals, reachMax) - local keys = {} - for k in pairs(classes) do - keys[#keys+1] = k - end +local function buildAsConst(keys, inferMap, literalMap, reachMax) table.sort(keys, function (a, b) - return tonumber(literals[a]) < tonumber(literals[b]) + return tonumber(literalMap[a]) < tonumber(literalMap[b]) end) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local class = classes[key] - local literal = literals[key] - if literal then - lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + local inferView = inferMap[key] + local literalView = literalMap[key] + if literalView then + lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView) else - lines[#lines+1] = (' %s: %s,'):format(key, class) + lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView) end end if reachMax then @@ -163,111 +57,79 @@ local function buildAsConst(classes, literals, reachMax) return table.concat(lines, '\n') end -local function mergeLiteral(literals) - local results = {} +local typeSorter = { + ['string'] = 1, + ['number'] = 2, + ['boolean'] = 3, +} + +local function getKeyMap(fields) + local keys = {} local mark = {} - for _, value in ipairs(literals) do - if not mark[value] then - mark[value] = true - results[#results+1] = value + for _, field in ipairs(fields) do + local key = vm.getKeyName(field) + local tp = vm.getKeyType(field) + if tp == 'number' then + key = tonumber(key) + elseif tp == 'boolean' then + key = key == 'true' end - end - if #results == 0 then - return nil - end - table.sort(results) - return table.concat(results, '|') -end - -local function mergeTypes(types) - local results = {} - local mark = { - -- 讲道理table的keyvalue不会是nil - ['nil'] = true, - } - for _, tv in ipairs(types) do - for tp in tv:gmatch '[^|]+' do - if not mark[tp] then - mark[tp] = true - results[tp] = true - end + if key and not mark[key] then + mark[key] = true + keys[#keys+1] = key end end - return guide.mergeTypes(results) -end - -local function clearClasses(classes) - classes['[nil]'] = nil - classes['[any]'] = nil - classes['[string]'] = nil + table.sort(keys, function (a, b) + local ta = typeSorter[type(a)] + local tb = typeSorter[type(b)] + if ta == tb then + return tostring(a) < tostring(b) + else + return ta < tb + end + end) + return keys end return function (source) - if config.config.hover.previewFields <= 0 then + local maxFields = config.config.hover.previewFields + if maxFields <= 0 then return 'table' end - local literals = {} - local classes = {} - local clock = os.clock() - local timeUp - local mark = {} - local fields = vm.getFields(source, 0) - local keyCount = 0 - local reachMax - for _, src in ipairs(fields) do - local key = getKey(src) - if not key then - goto CONTINUE - end - if not classes[key] then - classes[key] = {} - keyCount = keyCount + 1 - end - if not literals[key] then - literals[key] = {} - end - if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then - timeUp = true - end - local class, literal = getField(src, timeUp, mark, key) - if literal == 'nil' then - literal = nil - end - classes[key][#classes[key]+1] = class - literals[key][#literals[key]+1] = literal - if keyCount >= config.config.hover.previewFields then - reachMax = true - break - end - ::CONTINUE:: - end - - clearClasses(classes) - for key, class in pairs(classes) do - literals[key] = mergeLiteral(literals[key]) - classes[key] = mergeTypes(class) - end + local fields = vm.getRefs(source, '*') + local keys = getKeyMap(fields) - if not next(classes) then + if #keys == 0 then return '{}' end - local intValue = true - for key, class in pairs(classes) do - if class ~= 'integer' or not tonumber(literals[key]) then - intValue = false - break + local inferMap = {} + local literalMap = {} + + local reachMax = maxFields < #keys + + local isConsts = true + for i = 1, math.min(maxFields, #keys) do + local key = keys[i] + inferMap[key] = infer.searchAndViewInfers(source, key) + literalMap[key] = infer.searchAndViewLiterals(source, key) + if not tonumber(literalMap[key]) then + isConsts = false end end + local result - if intValue then - result = buildAsConst(classes, literals, reachMax) + + if isConsts then + result = buildAsConst(keys, inferMap, literalMap, reachMax) else - result = buildAsHash(classes, literals, reachMax) - end - if timeUp then - result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + result = buildAsHash(keys, inferMap, literalMap, reachMax) end + + --if timeUp then + -- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + --end + return result end diff --git a/script/core/infer.lua b/script/core/infer.lua new file mode 100644 index 00000000..a2c12fba --- /dev/null +++ b/script/core/infer.lua @@ -0,0 +1,639 @@ +local searcher = require 'core.searcher' +local config = require 'config' +local noder = require 'core.noder' +local util = require 'utility' +local vm = require "vm.vm" + +local STRING_OR_TABLE = {'STRING_OR_TABLE'} +local BE_RETURN = {'BE_RETURN'} +local BE_CONNACT = {'BE_CONNACT'} +local CLASS = {'CLASS'} +local TABLE = {'TABLE'} + +local TypeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['true'] = 101, + ['false'] = 102, + ['nil'] = 999, +} + +local m = {} + +local function mergeTable(a, b) + if not b then + return + end + for v in pairs(b) do + a[v] = true + end +end + +local function searchInferOfUnary(value, infers, mark) + local op = value.op.type + if op == 'not' then + infers['boolean'] = true + return + end + if op == '#' then + infers['integer'] = true + return + end + if op == '-' then + if m.hasType(value[1], 'integer', mark) then + infers['integer'] = true + else + infers['number'] = true + end + return + end + if op == '~' then + infers['integer'] = true + return + end +end + +local function searchInferOfBinary(value, infers, mark) + local op = value.op.type + if op == 'and' then + if m.isTrue(value[1], mark) then + mergeTable(infers, m.searchInfers(value[2], nil, mark)) + else + mergeTable(infers, m.searchInfers(value[1], nil, mark)) + end + return + end + if op == 'or' then + if m.isTrue(value[1], mark) then + mergeTable(infers, m.searchInfers(value[1], nil, mark)) + else + mergeTable(infers, m.searchInfers(value[2], nil, mark)) + end + return + end + if op == '==' + or op == '~=' + or op == '<' + or op == '>' + or op == '<=' + or op == '>=' then + infers['boolean'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '&' + or op == '|' then + infers['integer'] = true + return + end + if op == '..' then + infers['string'] = true + return + end + if op == '^' + or op == '/' then + infers['number'] = true + return + end + if op == '+' + or op == '-' + or op == '*' + or op == '%' + or op == '//' then + if m.hasType(value[1], 'integer', mark) + and m.hasType(value[2], 'integer', mark) then + infers['integer'] = true + else + infers['number'] = true + end + return + end +end + +local function searchInferOfValue(value, infers, mark) + if value.type == 'string' then + infers['string'] = true + return true + end + if value.type == 'boolean' then + infers['boolean'] = true + return true + end + if value.type == 'table' then + if value.array then + local node = m.searchAndViewInfers(value.array, nil, mark) + local infer = node .. '[]' + infers[infer] = true + else + infers['table'] = true + end + return true + end + if value.type == 'number' then + if math.type(value[1]) == 'integer' then + infers['integer'] = true + else + infers['number'] = true + end + return true + end + if value.type == 'nil' then + infers['nil'] = true + return true + end + if value.type == 'function' then + infers['function'] = true + return true + end + if value.type == 'unary' then + searchInferOfUnary(value, infers, mark) + return true + end + if value.type == 'binary' then + searchInferOfBinary(value, infers, mark) + return true + end + return false +end + +local function searchLiteralOfValue(value, literals, mark) + if value.type == 'string' + or value.type == 'boolean' + or value.type == 'number' + or value.type == 'integer' then + local v = value[1] + if v ~= nil then + literals[v] = true + end + return + end + if value.type == 'unary' then + local op = value.op.type + if op == '-' then + local subLiterals = m.searchLiterals(value[1], nil, mark) + if subLiterals then + for subLiteral in pairs(subLiterals) do + local num = tonumber(subLiteral) + if num then + literals[-num] = true + end + end + end + end + if op == '~' then + local subLiterals = m.searchLiterals(value[1], nil, mark) + if subLiterals then + for subLiteral in pairs(subLiterals) do + local num = math.tointeger(subLiteral) + if num then + literals[~num] = true + end + end + end + end + end + return +end + +local function bindClassOrType(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' + or doc.type == 'doc.type' then + return true + end + end + return false +end + +local function cleanInfers(infers) + local version = config.config.runtime.version + local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4' + infers['unknown'] = nil + if infers['any'] and infers['nil'] then + infers['nil'] = nil + end + if infers['number'] then + enableInteger = false + end + if not enableInteger and infers['integer'] then + infers['integer'] = nil + infers['number'] = true + end + -- stringlib 就是 string + if infers['stringlib'] and infers['string'] then + infers['stringlib'] = nil + end + -- 如果是通过 .. 来推测的,且结果里没有 number 与 integer,则推测为string + if infers[BE_CONNACT] then + infers[BE_CONNACT] = nil + if not infers['number'] and not infers['integer'] then + infers['string'] = true + end + end + -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型 + if infers[STRING_OR_TABLE] then + infers[STRING_OR_TABLE] = nil + if not infers['table'] and not infers['string'] then + infers['table'] = true + infers['string'] = true + end + end + -- 如果有doc标记,则先移除table类型 + if infers[CLASS] then + infers[CLASS] = nil + infers['table'] = nil + end + -- 用doc标记的table,加入table类型 + if infers[TABLE] then + infers[TABLE] = nil + infers['table'] = true + end + if infers[BE_RETURN] then + infers[BE_RETURN] = nil + infers['nil'] = nil + end + infers['any'] = nil +end + +---合并对象的推断类型 +---@param infers string[] +---@return string +function m.viewInfers(infers) + if infers[0] then + return infers[0] + end + -- 如果有显性的 any ,则直接显示为 any + if infers['any'] then + infers[0] = 'any' + return 'any' + end + local result = {} + local count = 0 + for infer in pairs(infers) do + count = count + 1 + result[count] = infer + end + -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any + if count == 0 then + infers[0] = 'any' + return 'any' + end + table.sort(result, function (a, b) + local sa = TypeSort[a] or 100 + local sb = TypeSort[b] or 100 + if sa == sb then + return a < b + else + return sa < sb + end + end) + infers[0] = table.concat(result, '|') + return infers[0] +end + +---合并对象的值 +---@param literals string[] +---@return string +function m.viewLiterals(literals) + local result = {} + local count = 0 + for infer in pairs(literals) do + count = count + 1 + result[count] = util.viewLiteral(infer) + end + if count == 0 then + return nil + end + table.sort(result) + local view = table.concat(result, '|') + return view +end + +function m.viewDocName(doc) + if not doc then + return nil + end + if doc.type == 'doc.type' then + local list = {} + for _, tp in ipairs(doc.types) do + list[#list+1] = m.getDocName(tp) + end + for _, enum in ipairs(doc.enums) do + list[#list+1] = m.getDocName(enum) + end + return table.concat(list, '|') + end + return m.getDocName(doc) +end + +function m.getDocName(doc) + if not doc then + return nil + end + if doc.type == 'doc.class.name' + or doc.type == 'doc.type.name' then + local name = doc[1] or '?' + if doc.typeGeneric then + return '<' .. name .. '>' + else + return name + end + end + if doc.type == 'doc.type.array' then + local nodeName = m.viewDocName(doc.node) or '?' + return nodeName .. '[]' + end + if doc.type == 'doc.type.table' then + local key = m.viewDocName(doc.tkey) or '?' + local value = m.viewDocName(doc.tvalue) or '?' + return ('table<%s, %s>'):format(key, value) + end + if doc.type == 'doc.type.function' then + return 'function' + end + if doc.type == 'doc.type.enum' + or doc.type == 'doc.resume' then + local value = doc[1] or '?' + return value + end +end + +function m.viewDocFunction(doc) + if doc.type ~= 'doc.type.function' then + return '' + end + local args = {} + for i, arg in ipairs(doc.args) do + args[i] = ('%s: %s'):format(arg.name[1], m.viewDocName(arg.extends)) + end + local label = ('fun(%s)'):format(table.concat(args, ', ')) + if #doc.returns > 0 then + local returns = {} + for i, rtn in ipairs(doc.returns) do + returns[i] = m.viewDocName(rtn) + end + label = ('%s:%s'):format(label, table.concat(returns)) + end + return label +end + +---显示对象的推断类型 +---@param source parser.guide.object +---@param mark table +---@return string +local function searchInfer(source, infers, mark) + if bindClassOrType(source) then + return + end + if searchInferOfValue(source, infers, mark) then + return + end + local value = searcher.getObjectValue(source) + if value then + searchInferOfValue(value, infers, mark) + return + end + -- check LuaDoc + local docName = m.getDocName(source) + if docName then + infers[docName] = true + if docName ~= 'unknown' then + infers[CLASS] = true + end + if docName == 'table' then + infers[TABLE] = true + end + end + if source.parent.type == 'unary' then + local op = source.parent.op.type + -- # XX -> string | table + if op == '#' then + infers[STRING_OR_TABLE] = true + return + end + if op == '-' then + infers['number'] = true + return + end + if op == '~' then + infers['integer'] = true + return + end + return + end + if source.parent.type == 'binary' then + local op = source.parent.op.type + if op == '+' + or op == '-' + or op == '*' + or op == '/' + or op == '//' + or op == '^' + or op == '%' then + infers['number'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '|' + or op == '&' then + infers['integer'] = true + return + end + if op == '..' then + infers[BE_CONNACT] = true + return + end + end + -- X.a -> table + if source.next and source.next.node == source then + if source.next.type == 'setfield' + or source.next.type == 'setindex' + or source.next.type == 'setmethod' + or source.next.type == 'getfield' + or source.next.type == 'getindex' then + infers['table'] = true + end + if source.next.type == 'getmethod' then + infers[STRING_OR_TABLE] = true + end + end + -- return XX + if source.parent.type == 'return' then + infers[BE_RETURN] = true + end +end + +local function searchLiteral(source, literals, mark) + local value = searcher.getObjectValue(source) + if value then + searchLiteralOfValue(value, literals, mark) + return + end +end + +---搜索对象的推断类型 +---@param source parser.guide.object +---@param field? string +---@param mark? table +---@return string[] +function m.searchInfers(source, field, mark) + if not source then + return nil + end + local defs = vm.getDefs(source, field) + local infers = {} + mark = mark or {} + if not field then + mark[source] = true + searchInfer(source, infers) + local id = noder.getID(source) + if id then + local node = noder.getNodeByID(source, id) + if node and node.sources then + for _, src in ipairs(node.sources) do + if not mark[src] then + mark[src] = true + searchInfer(src, infers, mark) + end + end + end + end + end + if source.type == 'field' or source.type == 'method' then + mark[source.parent] = true + searchInfer(source.parent, infers, mark) + end + for _, def in ipairs(defs) do + if not mark[def] then + mark[def] = true + searchInfer(def, infers, mark) + end + end + if source.docParam then + local docType = source.docParam.extends + if docType.type == 'doc.type' then + for _, def in ipairs(docType.types) do + if def.typeGeneric and not mark[def] then + mark[def] = true + searchInfer(def, infers, mark) + end + end + end + end + if source.type == 'doc.type' then + if source.type == 'doc.type' then + for _, def in ipairs(source.types) do + if def.typeGeneric and not mark[def] then + mark[def] = true + searchInfer(def, infers, mark) + end + end + end + end + cleanInfers(infers) + return infers +end + +---搜索对象的字面量值 +---@param source parser.guide.object +---@param field? string +---@param mark? table +---@return table +function m.searchLiterals(source, field, mark) + local defs = vm.getDefs(source, field) + local literals = {} + mark = mark or {} + if not field then + mark[source] = true + searchLiteral(source, literals, mark) + end + for _, def in ipairs(defs) do + if not mark[def] then + mark[def] = true + searchLiteral(def, literals, mark) + end + end + return literals +end + +---搜索并显示推断值 +---@param source parser.guide.object +---@param field? string +---@return string +function m.searchAndViewLiterals(source, field, mark) + if not source then + return nil + end + local literals = m.searchLiterals(source, field, mark) + local view = m.viewLiterals(literals) + return view +end + +---判断对象的推断值是否是 true +---@param source parser.guide.object +---@param mark? table +function m.isTrue(source, mark) + if not source then + return false + end + local literals = m.searchLiterals(source, nil, mark) + for literal in pairs(literals) do + if literal ~= false then + return true + end + end + return false +end + +---判断对象的推断类型是否包含某个类型 +function m.hasType(source, tp, mark) + local infers = m.searchInfers(source, nil, mark) + return infers[tp] or false +end + +---搜索并显示推断类型 +---@param source parser.guide.object +---@param field? string +---@return string +function m.searchAndViewInfers(source, field, mark) + if not source then + return 'any' + end + local infers = m.searchInfers(source, field, mark) + local view = m.viewInfers(infers) + return view +end + +---搜索并显示推断的class +---@param source parser.guide.object +---@return string? +function m.getClass(source) + if not source then + return nil + end + local infers = {} + local defs = vm.getDefs(source) + for _, def in ipairs(defs) do + if def.type == 'doc.class.name' then + infers[def[1]] = true + end + end + local view = m.viewInfers(infers) + if view == 'any' then + return nil + end + return view +end + +return m diff --git a/script/core/keyword.lua b/script/core/keyword.lua index 71ea4969..73892f18 100644 --- a/script/core/keyword.lua +++ b/script/core/keyword.lua @@ -1,6 +1,6 @@ local define = require 'proto.define' -local guide = require 'core.guide' local files = require 'files' +local guide = require 'parser.guide' local keyWordMap = { {'do', function (info, results) diff --git a/script/core/noder.lua b/script/core/noder.lua new file mode 100644 index 00000000..43d349ee --- /dev/null +++ b/script/core/noder.lua @@ -0,0 +1,1007 @@ +local util = require 'utility' +local guide = require 'parser.guide' +local collector = require 'core.collector' + +local LastIDCache = {} +local FirstIDCache = {} +local SPLIT_CHAR = '\x1F' +local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$' +local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*' +local ANY_FIELD_CHAR = '*' +local INDEX_CHAR = '[' +local RETURN_INDEX = SPLIT_CHAR .. '#' +local PARAM_INDEX = SPLIT_CHAR .. '&' +local TABLE_KEY = SPLIT_CHAR .. '<' +local INDEX_FIELD = SPLIT_CHAR .. INDEX_CHAR +local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR +local URI_CHAR = '@' +local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(.*)' + +---@class node +-- 当前节点的id +---@field id string +-- 使用该ID的单元 +---@field sources parser.guide.object[] +-- 前进的关联ID +---@field forward string[] +-- 后退的关联ID +---@field backward string[] +-- 函数调用参数信息(用于泛型) +---@field call parser.guide.object + +---@alias noders table<string, node[]> + +---创建source的链接信息 +---@param noders noders +---@param id string +---@return node +local function getNode(noders, id) + if not noders[id] then + noders[id] = { + id = id, + } + end + return noders[id] +end + +---获取语法树单元的key +---@param source parser.guide.object +---@return string? key +---@return parser.guide.object? node +local function getKey(source) + if source.type == 'local' then + return tostring(source.start), nil + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return tostring(source.node.start), nil + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + local node = source.node + if node.tag == '_ENV' then + return ('%q'):format(source[1] or ''), nil + else + return ('%q'):format(source[1] or ''), node + end + elseif source.type == 'getfield' + or source.type == 'setfield' then + return ('%q'):format(source.field and source.field[1] or ''), source.node + elseif source.type == 'tablefield' then + return ('%q'):format(source.field and source.field[1] or ''), source.parent + elseif source.type == 'getmethod' + or source.type == 'setmethod' then + return ('%q'):format(source.method and source.method[1] or ''), source.node + elseif source.type == 'setindex' + or source.type == 'getindex' then + local index = source.index + if not index then + return INDEX_CHAR, source.node + end + if index.type == 'string' + or index.type == 'boolean' + or index.type == 'number' then + return ('%q'):format(index[1] or ''), source.node + else + return INDEX_CHAR, source.node + end + elseif source.type == 'tableindex' then + local index = source.index + if not index then + return ANY_FIELD_CHAR, source.parent + end + if index.type == 'string' + or index.type == 'boolean' + or index.type == 'number' then + return ('%q'):format(index[1] or ''), source.parent + elseif index.type ~= 'function' + and index.type ~= 'table' then + return ANY_FIELD_CHAR, source.parent + end + elseif source.type == 'table' then + return source.start, nil + elseif source.type == 'label' then + return source.start, nil + elseif source.type == 'goto' then + if source.node then + return source.node.start, nil + end + return nil, nil + elseif source.type == 'function' then + return source.start, nil + elseif source.type == 'string' then + return '', nil + elseif source.type == 'integer' + or source.type == 'number' + or source.type == 'boolean' + or source.type == 'nil' then + return source.start, nil + elseif source.type == '...' then + return source.start, nil + elseif source.type == 'varargs' then + if source.node then + return source.node.start, nil + end + elseif source.type == 'select' then + return ('%d%s%d'):format(source.start, RETURN_INDEX, source.sindex) + elseif source.type == 'call' then + local node = source.node + if node.special == 'rawget' + or node.special == 'rawset' then + if not source.args then + return nil, nil + end + local tbl, key = source.args[1], source.args[2] + if not tbl or not key then + return nil, nil + end + if key.type == 'string' then + return ('%q'):format(key[1] or ''), tbl + else + return '', tbl + end + end + return source.finish, nil + elseif source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.see.name' then + local name = source[1] + return name, nil + elseif source.type == 'doc.type.name' then + local name = source[1] + if source.typeGeneric then + return source.typeGeneric[name][1].start, nil + else + return name, nil + end + elseif source.type == 'doc.class' + or source.type == 'doc.type' + or source.type == 'doc.param' + or source.type == 'doc.vararg' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.table' + or source.type == 'doc.type.array' + or source.type == 'doc.type.function' then + return source.start, nil + elseif source.type == 'doc.see.field' then + return ('%q'):format(source[1]), source.parent.name + elseif source.type == 'generic.closure' then + return source.call.start, nil + elseif source.type == 'generic.value' then + return ('%s|%s'):format( + source.closure.call.start, + getKey(source.proto) + ) + end + return nil, nil +end + +local function checkMode(source) + if source.type == 'table' then + return 't:' + end + if source.type == 'select' then + return 's:' + end + if source.type == 'function' then + return 'f:' + end + if source.type == 'string' then + return 'str:' + end + if source.type == 'number' + or source.type == 'integer' + or source.type == 'boolean' + or source.type == 'nil' then + return 'i:' + end + if source.type == 'call' then + return 'c:' + end + if source.type == '...' + or source.type == 'varargs' then + return 'va:' + end + if source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' then + if source.typeGeneric then + return 'dg:' + end + return 'dn:' + end + if source.type == 'doc.field.name' then + return 'dfn:' + end + if source.type == 'doc.see.name' then + return 'dsn:' + end + if source.type == 'doc.class' then + return 'dc:' + end + if source.type == 'doc.type' then + return 'dt:' + end + if source.type == 'doc.param' then + return 'dp:' + end + if source.type == 'doc.type.function' then + return 'dfun:' + end + if source.type == 'doc.type.table' then + return 'dtable:' + end + if source.type == 'doc.type.array' then + return 'darray:' + end + if source.type == 'doc.vararg' then + return 'dv:' + end + if source.type == 'doc.type.enum' + or source.type == 'doc.resume' then + return 'de:' + end + if source.type == 'generic.closure' then + return 'gc:' + end + if source.type == 'generic.value' then + local id = 'gv:' + if guide.getUri(source.closure.call) ~= guide.getUri(source.proto) then + id = id .. URI_CHAR .. guide.getUri(source.closure.call) + end + return id + end + if guide.isGlobal(source) then + return 'g:' + end + if source.type == 'getlocal' + or source.type == 'setlocal' then + source = source.node + end + if source.parent.type == 'funcargs' then + return 'p:' + end + return 'l:' +end + +local IDList = {} +---获取语法树单元的字符串ID +---@param source parser.guide.object +---@return string? id +local function getID(source) + if not source then + return nil + end + if source._id ~= nil then + return source._id or nil + end + if source.type == 'field' + or source.type == 'method' then + source._id = false + return nil + end + local current = source + local index = 0 + while true do + if current.type == 'paren' then + current = current.exp + if not current then + break + end + goto CONTINUE + end + local id, node = getKey(current) + if not id then + break + end + index = index + 1 + IDList[index] = id + if not node then + break + end + current = node + if current.special == '_G' then + for i = index, 2, -1 do + if IDList[i] == '"_G"' then + IDList[i] = nil + end + end + break + end + ::CONTINUE:: + end + if index == 0 then + source._id = false + return nil + end + for i = index + 1, #IDList do + IDList[i] = nil + end + local mode = checkMode(current) + if not mode then + source._id = false + return nil + end + util.revertTable(IDList) + local id = mode .. table.concat(IDList, SPLIT_CHAR) + source._id = id + return id +end + +---添加关联的前进ID +---@param noders noders +---@param id string +---@param forwardID string +local function pushForward(noders, id, forwardID, tag) + if not id + or not forwardID + or forwardID == '' + or id == forwardID then + return + end + local node = getNode(noders, id) + if not node.forward then + node.forward = {} + end + if node.forward[forwardID] ~= nil then + return + end + node.forward[forwardID] = tag or false + node.forward[#node.forward+1] = forwardID +end + +---添加关联的后退ID +---@param noders noders +---@param id string +---@param backwardID string +local function pushBackward(noders, id, backwardID, tag) + if not id + or not backwardID + or backwardID == '' + or id == backwardID then + return + end + local node = getNode(noders, id) + if not node.backward then + node.backward = {} + end + if node.backward[backwardID] ~= nil then + return + end + node.backward[backwardID] = tag or false + node.backward[#node.backward+1] = backwardID +end + +local m = {} + +m.SPLIT_CHAR = SPLIT_CHAR +m.RETURN_INDEX = RETURN_INDEX +m.PARAM_INDEX = PARAM_INDEX +m.TABLE_KEY = TABLE_KEY +m.ANY_FIELD = ANY_FIELD +m.URI_CHAR = URI_CHAR +m.INDEX_FIELD = INDEX_FIELD + +--- 寻找doc的主体 +---@param obj parser.guide.object +---@return parser.guide.object +local function getDocStateWithoutCrossFunction(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return obj + end + if parent.type == 'doc' then + return obj + end + if parent.type == 'doc.type.function' then + return nil + end + obj = parent + end + error('guide.getDocState overstack') +end + +---添加关联单元 +---@param noders noders +---@param source parser.guide.object +function m.pushSource(noders, source) + local id = m.getID(source) + if not id then + return + end + local node = getNode(noders, id) + if not node.sources then + node.sources = {} + end + node.sources[#node.sources+1] = source +end + +local function bindValue(noders, source, id) + local value = source.value + if not value then + return + end + local valueID = getID(value) + if not valueID then + return + end + if source.type == 'getlocal' + or source.type == 'setlocal' then + source = source.node + end + if source.bindDocs and value.type ~= 'table' then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' + or doc.type == 'doc.type' then + return + end + end + end + -- x = y : x -> y + pushForward(noders, id, valueID, 'set') + -- 参数/call禁止反向查找赋值 + local valueType = valueID:match '^.-:' + if valueType ~= 'p:' + and valueType ~= 's:' + and valueType ~= 'c:' then + pushBackward(noders, valueID, id, 'set') + end +end + +local function compileCall(noders, call, sourceID, returnIndex) + if not sourceID then + return + end + local node = call.node + local nodeID = getID(node) + if not nodeID then + return + end + local callID = getID(call) + if not callID then + return + end + -- 将setmetatable映射到 param1 以及 param2.__index 上 + if node.special == 'setmetatable' then + local tblID = getID(call.args and call.args[1]) + local metaID = getID(call.args and call.args[2]) + local indexID + if metaID then + indexID = ('%s%s%q'):format( + metaID, + SPLIT_CHAR, + '__index' + ) + end + pushForward(noders, sourceID, tblID) + pushForward(noders, sourceID, indexID) + pushBackward(noders, tblID, sourceID) + return + --pushBackward(noders, indexID, callID) + end + if node.special == 'require' then + local arg1 = call.args and call.args[1] + if arg1 and arg1.type == 'string' then + getNode(noders, sourceID).require = arg1[1] + end + return + end + if node.special == 'pcall' + or node.special == 'xpcall' then + local index = returnIndex - 1 + if index <= 0 then + return + end + local funcID = call.args and getID(call.args[1]) + if not funcID then + return + end + local pfuncXID = ('%s%s%s'):format( + funcID, + RETURN_INDEX, + index + ) + pushForward(noders, sourceID, pfuncXID) + --pushBackward(noders, funcXID, id) + return + end + local funcXID = ('%s%s%s'):format( + nodeID, + RETURN_INDEX, + returnIndex + ) + getNode(noders, sourceID).call = call + pushForward(noders, sourceID, funcXID) +end + +---@param uri uri +---@param noders noders +---@param source parser.guide.object +---@return parser.guide.object[] +function m.compileNode(uri, noders, source) + local id = getID(source) + bindValue(noders, source, id) + if source.special == 'setmetatable' + or source.special == 'require' + or source.special == 'dofile' + or source.special == 'loadfile' + or source.special == 'rawset' + or source.special == 'rawget' then + local node = getNode(noders, id) + node.skip = true + end + -- self -> mt:xx + if source.type == 'local' and source[1] == 'self' then + local func = guide.getParentFunction(source) + if func.isGeneric then + return + end + if source.parent.type ~= 'funcargs' then + return + end + local setmethod = func.parent + -- guess `self` + if setmethod and ( setmethod.type == 'setmethod' + or setmethod.type == 'setfield' + or setmethod.type == 'setindex') then + pushForward(noders, id, getID(setmethod.node), 'method') + --pushBackward(noders, getID(setmethod.node), id, 'method') + end + end + -- 分解 @type + if source.type == 'doc.type' then + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) + end + end + for _, enumUnit in ipairs(source.enums) do + pushForward(noders, id, getID(enumUnit)) + end + for _, resumeUnit in ipairs(source.resumes) do + pushForward(noders, id, getID(resumeUnit)) + end + for _, typeUnit in ipairs(source.types) do + local unitID = getID(typeUnit) + pushForward(noders, id, unitID) + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushBackward(noders, unitID, getID(src)) + end + end + end + end + -- 分解 @alias + if source.type == 'doc.alias' then + pushForward(noders, getID(source.alias), getID(source.extends)) + end + -- 分解 @class + if source.type == 'doc.class' then + pushForward(noders, id, getID(source.class)) + pushForward(noders, getID(source.class), id) + if source.extends then + for _, ext in ipairs(source.extends) do + pushBackward(noders, id, getID(ext)) + end + end + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) + end + end + for _, field in ipairs(source.fields) do + local key = field.field[1] + if key then + local keyID = ('%s%s%q'):format( + id, + SPLIT_CHAR, + key + ) + pushForward(noders, keyID, getID(field.field)) + pushForward(noders, getID(field.field), keyID) + pushForward(noders, keyID, getID(field.extends)) + pushBackward(noders, getID(field.extends), keyID) + end + end + end + if source.type == 'doc.param' then + pushForward(noders, id, getID(source.extends)) + for _, src in ipairs(source.bindSources) do + if src.type == 'local' and src.parent.type == 'in' then + pushForward(noders, getID(src), id) + end + end + end + if source.type == 'doc.vararg' then + pushForward(noders, getID(source), getID(source.vararg)) + end + if source.type == 'doc.see' then + local nameID = getID(source.name) + local classID = nameID:gsub('^dsn:', 'dn:') + pushForward(noders, nameID, classID) + if source.field then + local fieldID = getID(source.field) + local fieldClassID = fieldID:gsub('^dsn:', 'dn:') + pushForward(noders, fieldID, fieldClassID) + end + end + if source.type == 'call' then + if source.parent.type ~= 'select' then + compileCall(noders, source, id, 1) + end + end + if source.type == 'select' then + if source.vararg.type == 'call' then + local call = source.vararg + compileCall(noders, call, id, source.sindex) + end + if source.vararg.type == 'varargs' then + pushForward(noders, id, getID(source.vararg)) + end + end + if source.type == 'doc.type.function' then + if source.returns then + for index, rtn in ipairs(source.returns) do + local returnID = ('%s%s%s'):format( + id, + RETURN_INDEX, + index + ) + pushForward(noders, returnID, getID(rtn)) + end + end + -- @type fun(x: T):T 的情况 + local docType = getDocStateWithoutCrossFunction(source) + if docType and docType.type == 'doc.type' then + guide.eachSourceType(source, 'doc.type.name', function (typeName) + if typeName.typeGeneric then + source.isGeneric = true + return false + end + end) + end + end + if source.type == 'doc.type.table' then + if source.tkey then + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, getID(source.tkey)) + end + if source.tvalue then + local valueID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, valueID, getID(source.tvalue)) + end + end + if source.type == 'doc.type.array' then + if source.node then + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source.node)) + end + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, 'dn:integer') + end + if source.type == 'doc.type.name' then + collector.subscribe(uri, id, getNode(noders, id)) + end + if source.type == 'doc.class.name' then + collector.subscribe(uri, id, getNode(noders, id)) + collector.subscribe(uri, 'def:' .. id, getNode(noders, id)) + collector.subscribe(uri, 'def:dn', getNode(noders, id)) + end + if source.type == 'doc.alias.name' then + collector.subscribe(uri, id, getNode(noders, id)) + collector.subscribe(uri, 'def:' .. id, getNode(noders, id)) + collector.subscribe(uri, 'def:dn', getNode(noders, id)) + end + if guide.isGlobal(source) then + collector.subscribe(uri, id, getNode(noders, id)) + if guide.isSet(source) then + collector.subscribe(uri, 'def:' .. id, getNode(noders, id)) + collector.subscribe(uri, 'def:g', getNode(noders, id)) + end + end + -- 将函数的返回值映射到具体的返回值上 + if source.type == 'function' then + local hasDocReturn = {} + -- 检查 luadoc + if source.bindDocs then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + local fullID = ('%s%s%s'):format( + id, + RETURN_INDEX, + rtn.returnIndex + ) + pushForward(noders, fullID, getID(rtn)) + hasDocReturn[rtn.returnIndex] = true + end + end + if doc.type == 'doc.param' then + local paramName = doc.param[1] + if source.docParamMap then + local paramIndex = source.docParamMap[paramName] + local param = source.args[paramIndex] + if param then + pushForward(noders, getID(param), getID(doc)) + param.docParam = doc + end + end + end + if doc.type == 'doc.vararg' then + for _, param in ipairs(source.args) do + if param.type == '...' then + pushForward(noders, getID(param), getID(doc)) + end + end + end + if doc.type == 'doc.generic' then + source.isGeneric = true + end + if doc.type == 'doc.overload' then + pushForward(noders, id, getID(doc.overload)) + end + end + end + -- 检查实体返回值 + if source.returns then + local returns = {} + for _, rtn in ipairs(source.returns) do + for index, rtnObj in ipairs(rtn) do + if not hasDocReturn[index] then + if not returns[index] then + returns[index] = {} + end + returns[index][#returns[index]+1] = rtnObj + end + end + end + for index, rtnObjs in ipairs(returns) do + local returnID = ('%s%s%s'):format( + id, + RETURN_INDEX, + index + ) + for _, rtnObj in ipairs(rtnObjs) do + pushForward(noders, returnID, getID(rtnObj)) + if rtnObj.type == 'function' + or rtnObj.type == 'call' then + --pushBackward(noders, getID(rtnObj), returnID) + end + end + end + end + end + if source.type == 'table' then + if #source == 1 and source[1].type == 'varargs' then + source.array = source[1] + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source[1])) + end + end + if source.type == 'main' then + if source.returns then + for _, rtn in ipairs(source.returns) do + local rtnObj = rtn[1] + if rtnObj then + pushForward(noders, 'mainreturn', getID(rtnObj)) + --pushBackward(noders, getID(rtnObj), 'mainreturn') + end + end + end + end + if source.type == 'generic.closure' then + for i, rtn in ipairs(source.returns) do + local closureID = ('%s%s%s'):format( + id, + RETURN_INDEX, + i + ) + local returnID = getID(rtn) + pushForward(noders, closureID, returnID) + end + end + if source.type == 'generic.value' then + local proto = source.proto + local closure = source.closure + local upvalues = closure.upvalues + if proto.type == 'doc.type.name' then + local key = proto[1] + if upvalues[key] then + for _, paramID in ipairs(upvalues[key]) do + pushForward(noders, id, paramID) + pushBackward(noders, paramID, id) + end + end + end + if proto.type == 'doc.type' then + for _, tp in ipairs(source.types) do + pushForward(noders, id, getID(tp)) + pushBackward(noders, getID(tp), id) + end + end + if proto.type == 'doc.type.array' then + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source.node)) + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, 'dn:integer') + end + if proto.type == 'doc.type.table' then + if source.tkey then + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, getID(source.tkey)) + end + if source.tvalue then + local valueID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, valueID, getID(source.tvalue)) + end + end + end +end + +---根据ID来获取所有的node +---@param root parser.guide.object +---@param id string +---@return node? +function m.getNodeByID(root, id) + root = guide.getRoot(root) + local noders = root._noders + if not noders then + return nil + end + return noders[id] +end + +---根据ID来获取第一个节点的ID +---@param id string +---@return string +function m.getFirstID(id) + if FirstIDCache[id] then + return FirstIDCache[id] or nil + end + local firstID, count = id:match(FIRST_REGEX) + if count == 0 then + FirstIDCache[id] = false + return nil + end + FirstIDCache[id] = firstID + return firstID +end + +---根据ID来获取上个节点的ID +---@param id string +---@return string +function m.getLastID(id) + if LastIDCache[id] then + return LastIDCache[id] or nil + end + local lastID, count = id:gsub(LAST_REGEX, '') + if count == 0 then + LastIDCache[id] = false + return nil + end + LastIDCache[id] = lastID + return lastID +end + +---测试id是否包含field,如果遇到函数调用则中断 +---@param id string +---@return boolean +function m.hasField(id) + local firstID = m.getFirstID(id) + if firstID == id then + return false + end + local nextChar = id:sub(#firstID + 1, #firstID + 1) + if nextChar ~= SPLIT_CHAR then + return false + end + local next2Char = id:sub(#firstID + 2, #firstID + 2) + if next2Char == RETURN_INDEX + or next2Char == PARAM_INDEX then + return false + end + return true +end + +---把形如 `@file:\\\XXXXX@gv:1|1`拆分成uri与id +---@param id string +---@return uri? string +---@return string id +function m.getUriAndID(id) + local uri, newID = id:match(URI_REGEX) + return uri, newID +end + +---获取source的ID +---@param source parser.guide.object +---@return string +function m.getID(source) + return getID(source) +end + +---获取source的key +---@param source parser.guide.object +---@return string +function m.getKey(source) + return getKey(source) +end + +---清除临时id(用于泛型的临时对象) +---@param root parser.guide.object +---@param id string +function m.removeID(root, id) + root = guide.getRoot(root) + local noders = root._noders + noders[id] = nil +end + +---寻找doc的主体 +---@param doc parser.guide.object +function m.getDocState(doc) + return getDocStateWithoutCrossFunction(doc) +end + +---获取对象的noders +---@param source parser.guide.object +---@return noders +function m.getNoders(source) + local root = guide.getRoot(source) + if not root._noders then + root._noders = {} + end + return root._noders +end + +---编译整个文件的node +---@param source parser.guide.object +---@return table +function m.compileNodes(source) + local root = guide.getRoot(source) + local noders = m.getNoders(source) + if next(noders) then + return noders + end + local uri = guide.getUri(root) + collector.dropUri(uri) + log.debug('compileNodes:', guide.getUri(root)) + guide.eachSource(root, function (src) + m.pushSource(noders, src) + m.compileNode(uri, noders, src) + end) + log.debug('compileNodes finish:', guide.getUri(root)) + return noders +end + +return m diff --git a/script/core/reference.lua b/script/core/reference.lua index 7620b09e..6ea79f5f 100644 --- a/script/core/reference.lua +++ b/script/core/reference.lua @@ -1,4 +1,5 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' @@ -6,8 +7,8 @@ local findSource = require 'core.find-source' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -19,7 +20,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else @@ -52,7 +53,7 @@ local accept = { } return function (uri, offset) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end @@ -64,11 +65,23 @@ return function (uri, offset) local metaSource = vm.isMetaFile(uri) + local refs = vm.getRefs(source) + local values = {} + for _, src in ipairs(refs) do + local value = searcher.getObjectValue(src) + if value and value ~= src and guide.isLiteral(value) then + values[value] = true + end + end + local results = {} - for _, src in ipairs(vm.getRefs(source, 5)) do + for _, src in ipairs(refs) do if src.dummy then goto CONTINUE end + if values[src] then + goto CONTINUE + end local root = guide.getRoot(src) if not root then goto CONTINUE diff --git a/script/core/rename.lua b/script/core/rename.lua index da82b0a6..bc85ac14 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -1,11 +1,11 @@ local files = require 'files' local vm = require 'vm' -local guide = require 'core.guide' local proto = require 'proto' local define = require 'proto.define' local util = require 'utility' local findSource = require 'core.find-source' -local ws = require 'workspace' +local guide = require 'parser.guide' +local noder = require 'core.noder' local Forcing @@ -185,7 +185,7 @@ local function renameField(source, newname, callback) end callback(source, source.start, source.finish, newname) elseif parent.type == 'setmethod' then - local uri = guide.getUri(source) + local uri = guide.getUri(source) local text = files.getText(uri) local func = parent.value -- function mt:name () end --> mt['newname'] = function (self) end @@ -292,14 +292,14 @@ local function ofField(source, newname, callback) else node = source.node end - for _, src in ipairs(vm.getFields(node, 5)) do + for _, src in ipairs(vm.getRefs(node, '*')) do ofFieldThen(key, src, newname, callback) end end local function ofGlobal(source, newname, callback) local key = guide.getKeyName(source) - for _, src in ipairs(vm.getRefs(source, 0)) do + for _, src in ipairs(vm.getRefs(source)) do ofFieldThen(key, src, newname, callback) end end @@ -308,24 +308,27 @@ local function ofLabel(source, newname, callback) if not isValidName(newname) and not askForcing(newname)then return false end - for _, src in ipairs(vm.getRefs(source, 0)) do + for _, src in ipairs(vm.getRefs(source)) do callback(src, src.start, src.finish, newname) end end local function ofDocTypeName(source, newname, callback) - for _, doc in ipairs(vm.getDocTypes(source[1])) do + local oldname = source[1] + for _, doc in ipairs(vm.getRefs(source)) do if doc.type == 'doc.class.name' or doc.type == 'doc.type.name' or doc.type == 'doc.alias.name' then - callback(doc, doc.start, doc.finish, newname) + if oldname == doc[1] then + callback(doc, doc.start, doc.finish, newname) + end end end end local function ofDocParamName(source, newname, callback) callback(source, source.start, source.finish, newname) - local doc = guide.getDocState(source) + local doc = noder.getDocState(source) if doc.bindSources then for _, src in ipairs(doc.bindSources) do if src.type == 'local' @@ -440,7 +443,7 @@ function m.rename(uri, pos, newname) if not newname then return nil end - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end @@ -489,7 +492,7 @@ function m.rename(uri, pos, newname) end function m.prepareRename(uri, pos) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end diff --git a/script/core/searcher.lua b/script/core/searcher.lua new file mode 100644 index 00000000..5a417765 --- /dev/null +++ b/script/core/searcher.lua @@ -0,0 +1,838 @@ +local noder = require 'core.noder' +local guide = require 'parser.guide' +local files = require 'files' +local generic = require 'core.generic' +local ws = require 'workspace' +local vm = require 'vm.vm' +local await = require 'await' +local collector = require 'core.collector' + +local NONE = {'NONE'} +local LAST = {'LAST'} + +local ignoredIDs = { + ['dn:unknown'] = true, + ['dn:nil'] = true, + ['dn:any'] = true, + ['dn:boolean'] = true, + ['dn:string'] = true, + ['dn:table'] = true, + ['dn:number'] = true, + ['dn:integer'] = true, + ['dn:userdata'] = true, + ['dn:lightuserdata'] = true, + ['dn:function'] = true, + ['dn:thread'] = true, +} + +local m = {} + +---@alias guide.searchmode '"ref"'|'"def"' + +---添加结果 +---@param status guide.status +---@param mode guide.searchmode +---@param source parser.guide.object +---@param force boolean +function m.pushResult(status, mode, source, force) + if not source then + return + end + local results = status.results + local mark = status.mark + if mark[source] then + return + end + mark[source] = true + if force then + results[#results+1] = source + return + end + local parent = source.parent + if mode == 'def' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'label' + or source.type == 'setfield' + or source.type == 'setmethod' + or source.type == 'setindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if noder.getID(source) ~= status.id then + results[#results+1] = source + end + end + elseif mode == 'ref' or mode == 'field' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'label' + or source.type == 'goto' + or source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' + or source.type == 'setindex' + or source.type == 'getindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' + or source.type == 'nil' + or source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if noder.getID(source) ~= status.id then + results[#results+1] = source + end + end + end +end + +---获取uri +---@param obj parser.guide.object +---@return uri +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = guide.getRoot(obj) + if root then + return root.uri + end + return '' +end + +---@param obj parser.guide.object +---@return parser.guide.object? +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent and obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args and obj.args[3] + else + return obj + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +local function crossSearch(status, uri, expect, mode, sourceUri) + if status.lock[uri] then + return + end + status.lock[uri] = true + await.delay() + if TRACE then + log.debug('crossSearch', uri, expect) + end + if FOOTPRINT then + status.footprint[#status.footprint+1] = ('cross search:%s %s'):format(uri, expect) + end + m.searchRefsByID(status, uri, expect, mode) + status.lock[uri] = nil + if FOOTPRINT then + status.footprint[#status.footprint+1] = ('cross search finish, back to: %s'):format(sourceUri) + end +end + +local function checkCache(status, uri, expect, mode) + local cache = vm.getCache('search:' .. mode) + local fileCache = cache[uri] + if not fileCache then + fileCache = {} + cache[uri] = fileCache + end + if fileCache[expect] then + for _, res in ipairs(fileCache[expect]) do + m.pushResult(status, mode, res, true) + end + return true + end + fileCache[expect] = status.results + return false +end + +function m.searchRefsByID(status, uri, expect, mode) + local ast = files.getState(uri) + if not ast then + return + end + local root = ast.ast + local searchStep + noder.compileNodes(root) + + status.id = expect + + local callStack = status.callStack + + local mark = {} + + local function search(id, field) + local firstID = noder.getFirstID(id) + if ignoredIDs[firstID] and (field or firstID ~= id) then + return + end + local cmark = mark[id] + if not cmark then + cmark = {} + mark[id] = cmark + end + if cmark[field or NONE] then + return + end + if TRACE then + log.debug('search:', id, field) + end + if FOOTPRINT then + if field then + status.footprint[#status.footprint+1] = 'search\t' .. id .. '\t' .. field + else + status.footprint[#status.footprint+1] = 'search\t' .. id + end + end + cmark[field or NONE] = true + searchStep(id, field) + if TRACE then + log.debug('pop:', id, field) + end + if FOOTPRINT then + if field then + status.footprint[#status.footprint+1] = 'pop\t' .. id .. '\t' .. field + else + status.footprint[#status.footprint+1] = 'pop\t' .. id + end + end + end + + local function checkLastID(id, field) + local cmark = mark[id] + if not cmark then + cmark = {} + mark[id] = cmark + end + if cmark[LAST] then + return + end + local lastID = noder.getLastID(id) + if not lastID then + return + end + local newField = id:sub(#lastID + 1) + if field then + newField = newField .. field + end + cmark[LAST] = true + search(lastID, newField) + return lastID + end + + local function searchID(id, field) + if not id then + return + end + if field then + id = id .. field + end + search(id, nil) + end + + local function isCallID(field) + if not field then + return false + end + if field:sub(1, 2) == noder.RETURN_INDEX then + return true + end + return false + end + + local function findLastCall() + for i = #callStack, 1, -1 do + local call = callStack[i] + if call then + -- 标记此处的call失效,等待在堆栈平衡时弹出 + callStack[i] = false + return call + end + end + return nil + end + + local genericCallArgs = {} + local closureCache = {} + local function checkGeneric(source, field) + if not source.isGeneric then + return + end + if not isCallID(field) then + return + end + local call = findLastCall() + if not call then + return + end + + if call.args then + for _, arg in ipairs(call.args) do + genericCallArgs[arg] = true + end + end + + local cacheID = noder.getID(source) .. noder.getID(call) + local closure = closureCache[cacheID] + if closure == false then + return + end + if not closure then + closure = generic.createClosure(source, call) + closureCache[cacheID] = closure or false + if not closure then + return + end + end + local id = noder.getID(closure) + searchID(id, field) + end + + local function checkENV(source, field) + if not field then + return + end + if source.special ~= '_G' then + return + end + local newID = 'g:' .. field:sub(2) + searchID(newID) + end + + local function checkThenPushTag(ward, tag) + if not tag then + return true + end + local checkTags + local pushTags + if ward == 'forward' then + checkTags = status.btag + pushTags = status.ftag + else + checkTags = status.ftag + pushTags = status.btag + end + if checkTags[tag] and checkTags[tag] > 0 then + return false + end + pushTags[tag] = (pushTags[tag] or 0) + 1 + return true + end + + local function popTag(ward, tag) + if not tag then + return + end + local popTags + if ward == 'forward' then + popTags = status.ftag + else + popTags = status.btag + end + popTags[tag] = popTags[tag] - 1 + end + + local function checkForward(id, node, field) + for _, forwardID in ipairs(node.forward) do + local tag = node.forward[forwardID] + if not checkThenPushTag('forward', tag) then + goto CONTINUE + end + local targetUri, targetID = noder.getUriAndID(forwardID) + if targetUri and not files.eq(targetUri, uri) then + crossSearch(status, targetUri, targetID .. (field or ''), mode, uri) + else + searchID(targetID or forwardID, field) + end + popTag('forward', tag) + ::CONTINUE:: + end + end + + local function checkBackward(id, node, field) + if mode ~= 'ref' and mode ~= 'field' and not field then + return + end + for _, backwardID in ipairs(node.backward) do + local tag = node.backward[backwardID] + if not checkThenPushTag('backward', tag) then + goto CONTINUE + end + local targetUri, targetID = noder.getUriAndID(backwardID) + if targetUri and not files.eq(targetUri, uri) then + crossSearch(status, targetUri, targetID .. (field or ''), mode, uri) + else + searchID(targetID or backwardID, field) + end + popTag('backward', tag) + ::CONTINUE:: + end + end + + local function checkSpecial(id, node, field) + -- Special rule: ('').XX -> stringlib.XX + if id == 'str:' + or id == 'dn:string' then + if field or mode == 'field' then + searchID('dn:stringlib', field) + end + return true + end + return false + end + + local function checkRequire(requireName, field) + local tid = 'mainreturn' .. (field or '') + local uris = ws.findUrisByRequirePath(requireName) + if FOOTPRINT then + status.footprint[#status.footprint+1] = ('require %q:\n%s'):format(requireName, table.concat(uris, '\n')) + end + for _, ruri in ipairs(uris) do + if not files.eq(uri, ruri) then + crossSearch(status, ruri, tid, mode, uri) + end + end + end + + local function checkGlobal(id, node, field) + if status.crossed[id] then + return + end + status.crossed[id] = true + --if not checkThenPushTag('forward', 'set') then + -- return + --end + local isCall = field and field:sub(2, 2) == noder.RETURN_INDEX + local tid = id .. (field or '') + if FOOTPRINT then + status.footprint[#status.footprint+1] = ('checkGlobal:%s + %s, isCall: %s'):format(id, field, isCall, tid) + end + for guri, def in collector.each(id) do + if def then + crossSearch(status, guri, tid, mode, uri) + goto CONTINUE + end + if isCall then + goto CONTINUE + end + if not field then + goto CONTINUE + end + if mode == 'def' then + goto CONTINUE + end + if not files.eq(uri, guri) then + goto CONTINUE + end + crossSearch(status, guri, tid, mode, uri) + ::CONTINUE:: + end + --popTag('forward', 'set') + end + + local function checkClass(id, node, field) + if status.crossed[id] then + return + end + status.crossed[id] = true + local tid = id .. (field or '') + for guri in collector.each(id) do + if not files.eq(uri, guri) then + crossSearch(status, guri, tid, mode, uri) + end + end + end + + local function searchNode(id, node, field) + if node.call then + callStack[#callStack+1] = node.call + end + if field == nil and node.sources then + for _, source in ipairs(node.sources) do + local force = genericCallArgs[source] + m.pushResult(status, mode, source, force) + end + end + + if node.require then + checkRequire(node.require, field) + return + end + + local isSepcial = checkSpecial(id, node, field) + if not isSepcial then + if node.forward then + checkForward(id, node, field) + end + if node.backward then + checkBackward(id, node, field) + end + end + + if node.sources then + checkGeneric(node.sources[1], field) + checkENV(node.sources[1], field) + end + + --checkMainReturn(id, node, field) + + if node.call then + callStack[#callStack] = nil + end + + return false + end + + local function checkAnyField(id, field) + if mode == 'ref' or mode == 'field' then + return + end + local lastID = noder.getLastID(id) + if not lastID then + return + end + local originField = id:sub(#lastID + 1) + if originField == noder.TABLE_KEY then + return + end + local anyFieldID = lastID .. noder.ANY_FIELD + local anyFieldNode = noder.getNodeByID(root, anyFieldID) + if anyFieldNode then + searchNode(anyFieldID, anyFieldNode, field) + end + end + + local stepCount = 0 + function searchStep(id, field) + stepCount = stepCount + 1 + status.count = status.count + 1 + if stepCount > 1000 + or status.count > 10000 then + if TEST then + if FOOTPRINT then + log.debug(table.concat(status.footprint, '\n')) + end + error('too large!') + else + log.warn('too large!') + if FOOTPRINT then + log.debug(table.concat(status.footprint, '\n')) + end + return + end + end + local node = noder.getNodeByID(root, id) + if node then + searchNode(id, node, field) + if node.skip and field then + return + end + end + checkGlobal(id, node, field) + checkClass(id, node, field) + checkLastID(id, field) + checkAnyField(id, field) + end + + search(expect) + + --清除来自泛型的临时对象 + for _, closure in pairs(closureCache) do + noder.removeID(root, noder.getID(closure)) + if closure then + for _, value in ipairs(closure.values) do + noder.removeID(root, noder.getID(value)) + end + end + end +end + +local function prepareSearch(source) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local root = guide.getRoot(source) + noder.compileNodes(root) + local uri = guide.getUri(source) + local id = noder.getID(source) + return uri, id +end + +local function getField(status, source, mode) + if source.type == 'table' then + for _, field in ipairs(source) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + m.pushResult(status, mode, field) + end + end + return + end + if source.type == 'doc.class.name' then + local class = source.parent + for _, field in ipairs(class.fields) do + m.pushResult(status, mode, field.field) + end + return + end + local field = source.next + if field then + if field.type == 'getmethod' + or field.type == 'setmethod' + or field.type == 'getfield' + or field.type == 'setfield' + or field.type == 'getindex' + or field.type == 'setindex' then + m.pushResult(status, mode, field) + end + return + end +end + +local function searchAllGlobalByUri(status, mode, uri, fullID) + local ast = files.getState(uri) + if not ast then + return + end + local root = ast.ast + noder.compileNodes(root) + local noders = noder.getNoders(root) + if fullID then + for id, node in pairs(noders) do + if node.sources + and id == fullID then + for _, source in ipairs(node.sources) do + m.pushResult(status, mode, source) + end + end + end + else + for id, node in pairs(noders) do + if node.sources + and id:sub(1, 2) == 'g:' + and not id:find(noder.SPLIT_CHAR) then + for _, source in ipairs(node.sources) do + m.pushResult(status, mode, source) + end + end + end + end +end + +local function searchAllGlobals(status, mode) + for uri in files.eachFile() do + searchAllGlobalByUri(status, mode, uri) + end +end + +---查找全局变量 +---@param uri uri +---@param mode guide.searchmode +---@param name string +---@return parser.guide.object[] +function m.findGlobals(uri, mode, name) + local status = m.status(mode) + + if name then + local fullID = ('g:%q'):format(name) + searchAllGlobalByUri(status, mode, uri, fullID) + else + searchAllGlobalByUri(status, mode, uri) + end + + return status.results +end + +---搜索对象的引用 +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +function m.searchRefs(status, source, mode) + local uri, id = prepareSearch(source) + if not id then + return + end + + if checkCache(status, uri, id, mode) then + return + end + + if TRACE then + log.debug('searchRefs:', id) + end + m.searchRefsByID(status, uri, id, mode) +end + +---搜索对象的field +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +---@param field string +function m.searchFields(status, source, mode, field) + local uri, id = prepareSearch(source) + if not id then + return + end + if TRACE then + log.debug('searchFields:', id, field) + end + if field == '*' then + if source.special == '_G' then + if checkCache(status, uri, '*', mode) then + return + end + searchAllGlobals(status, mode) + else + if checkCache(status, uri, id .. '*', mode) then + return + end + local newStatus = m.status('field') + m.searchRefsByID(newStatus, uri, id, 'field') + for _, def in ipairs(newStatus.results) do + getField(status, def, mode) + end + end + else + if source.special == '_G' then + local fullID = ('g:%q'):format(field) + if checkCache(status, uri, fullID, mode) then + return + end + m.searchRefsByID(status, uri, fullID, mode) + else + local fullID = ('%s%s%q'):format(id, noder.SPLIT_CHAR, field) + if checkCache(status, uri, fullID, mode) then + return + end + m.searchRefsByID(status, uri, fullID, mode) + end + end +end + +---@class guide.status +---搜索结果 +---@field results parser.guide.object[] + +---创建搜索状态 +---@param mode guide.searchmode +---@return guide.status +function m.status(mode) + local status = { + callStack = {}, + crossed = {}, + lock = {}, + results = {}, + mark = {}, + footprint = {}, + count = 0, + ftag = {}, + btag = {}, + cache = vm.getCache('searcher:' .. mode) + } + return status +end + +--- 请求对象的引用 +---@param obj parser.guide.object +---@param field? string +---@return parser.guide.object[] +function m.requestReference(obj, field) + local status = m.status('ref') + + if field then + m.searchFields(status, obj, 'ref', field) + else + m.searchRefs(status, obj, 'ref') + end + + return status.results +end + +--- 请求对象的定义 +---@param obj parser.guide.object +---@param field? string +---@return parser.guide.object[] +function m.requestDefinition(obj, field) + local status = m.status('def') + + if field then + m.searchFields(status, obj, 'def', field) + else + m.searchRefs(status, obj, 'def') + end + + return status.results +end + +return m diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index f8feaa09..f310e3f1 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -1,9 +1,10 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local await = require 'await' local define = require 'proto.define' local vm = require 'vm' local util = require 'utility' +local guide = require 'parser.guide' local Care = {} Care['setglobal'] = function (source, results) @@ -212,7 +213,7 @@ local function buildTokens(uri, results) end return function (uri, start, finish) - local ast = files.getAst(uri) + local ast = files.getState(uri) local lines = files.getLines(uri) local text = files.getText(uri) if not ast then diff --git a/script/core/signature.lua b/script/core/signature.lua index eb740784..8de1c374 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -1,8 +1,9 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local hoverLabel = require 'core.hover.label' local hoverDesc = require 'core.hover.description' +local guide = require 'parser.guide' local function findNearCall(uri, ast, pos) local text = files.getText(uri) @@ -96,10 +97,10 @@ local function makeSignatures(call, pos) index = 1 end local signs = {} - local defs = vm.getDefs(node, 0) + local defs = vm.getDefs(node) local mark = {} for _, src in ipairs(defs) do - src = guide.getObjectValue(src) or src + src = searcher.getObjectValue(src) or src if src.type == 'function' or src.type == 'doc.type.function' then if not mark[src] then @@ -132,7 +133,7 @@ local function skipSpace(text, offset) end return function (uri, pos) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return nil end diff --git a/script/core/type-formatting.lua b/script/core/type-formatting.lua index c2290ef3..b01a1999 100644 --- a/script/core/type-formatting.lua +++ b/script/core/type-formatting.lua @@ -1,6 +1,6 @@ local files = require 'files' local lookBackward = require 'core.look-backward' -local guide = require 'core.guide' +local guide = require "parser.guide" local function insertIndentation(uri, offset, edits) local lines = files.getLines(uri) @@ -69,7 +69,7 @@ local function checkSplitOneLine(results, uri, offset, ch) end return function (uri, offset, ch) - local ast = files.getAst(uri) + local ast = files.getState(uri) local text = files.getOriginText(uri) if not ast or not text then return nil diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua index ae420d32..18ab1eeb 100644 --- a/script/core/workspace-symbol.lua +++ b/script/core/workspace-symbol.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local matchKey = require 'core.matchkey' local define = require 'proto.define' local await = require 'await' @@ -47,7 +47,7 @@ local function buildSource(uri, source, key, results) end local function searchFile(uri, key, results) - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return end diff --git a/script/files.lua b/script/files.lua index 9cc6b549..a653b364 100644 --- a/script/files.lua +++ b/script/files.lua @@ -9,7 +9,7 @@ local await = require 'await' local timer = require 'timer' local plugin = require 'plugin' local util = require 'utility' -local guide = require 'core.guide' +local guide = require 'parser.guide' local smerger = require 'string-merger' local progress = require "progress" @@ -34,9 +34,9 @@ m.assocMatcher = nil m.globalVersion = 0 m.fileCount = 0 m.astCount = 0 -m.linesMap = setmetatable({}, { __mode = 'v' }) -m.originLinesMap = setmetatable({}, { __mode = 'v' }) -m.astMap = setmetatable({}, { __mode = 'v' }) +m.linesMap = {} --setmetatable({}, { __mode = 'v' }) +m.originLinesMap = {} --setmetatable({}, { __mode = 'v' }) +m.astMap = {} --setmetatable({}, { __mode = 'v' }) local uriMap = {} local function getUriKey(uri) @@ -345,6 +345,7 @@ function m.getAllUris() i = i + 1 files[i] = uri end + table.sort(files) end return m._pairsCache end @@ -377,7 +378,7 @@ function m.eachDll() return pairs(map) end -function m.compileAst(uri, text) +function m.compileState(uri, text) local ws = require 'workspace' if not m.isOpen(uri) and #text >= config.config.workspace.preloadFileSize * 1000 then if not m.notifyCache['preloadFileSize'] then @@ -445,8 +446,8 @@ end --- 获取文件语法树 ---@param uri uri ----@return table ast -function m.getAst(uri) +---@return table state +function m.getState(uri) uri = getUriKey(uri) if uri ~= '' and not m.isLua(uri) then return nil @@ -457,7 +458,7 @@ function m.getAst(uri) end local ast = m.astMap[uri] if not ast then - ast = m.compileAst(uri, file.text) + ast = m.compileState(uri, file.text) m.astMap[uri] = ast --await.delay() end diff --git a/script/parser/ast.lua b/script/parser/ast.lua index 45d77631..40b5788e 100644 --- a/script/parser/ast.lua +++ b/script/parser/ast.lua @@ -110,7 +110,7 @@ local function getSelect(vararg, index) start = vararg.start, finish = vararg.finish, vararg = vararg, - index = index, + sindex = index, } end @@ -1460,8 +1460,14 @@ local Defs = { local values if func then local call = createCall(exp, func.finish + 1, exp.finish) + if #exp == 0 then + exp[1] = getSelect(func, 2) + exp[2] = getSelect(func, 3) + exp[3] = getSelect(func, 4) + end call.node = func - call.start = func.start + call.start = inA + call.finish = doB - 1 func.next = call func.iterator = true values = { call } diff --git a/script/parser/compile.lua b/script/parser/compile.lua index a7e0dc1f..21be406d 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -125,6 +125,7 @@ local vmMap = { vararg.ref = {} end vararg.ref[#vararg.ref+1] = obj + obj.node = vararg end end end, @@ -150,8 +151,8 @@ local vmMap = { local value = obj.value local localself = { type = 'local', - start = 0, - finish = 0, + start = value.start, + finish = value.finish, method = obj, effect = obj.finish, tag = 'self', diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 2369e84f..ad07e90e 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -1,34 +1,8 @@ -local util = require 'utility' local error = error local type = type -local next = next -local tostring = tostring -local print = print -local ipairs = ipairs -local tableInsert = table.insert -local tableUnpack = table.unpack -local tableRemove = table.remove -local tableMove = table.move -local tableSort = table.sort -local tableConcat = table.concat -local mathType = math.type -local pairs = pairs -local setmetatable = setmetatable -local assert = assert -local select = select -local osClock = os.clock -local tonumber = tonumber -local tointeger = math.tointeger -local DEVELOP = _G.DEVELOP -local log = log -local _G = _G ---@class parser.guide.object -local function logWarn(...) - log.warn(...) -end - ---@class guide ---@field debugMode boolean local m = {} @@ -91,7 +65,7 @@ m.childMap = { ['doc'] = {'#'}, ['doc.class'] = {'class', '#extends', 'comment'}, - ['doc.type'] = {'#types', '#enums', 'name', 'comment'}, + ['doc.type'] = {'#types', '#enums', '#resumes', 'name', 'comment'}, ['doc.alias'] = {'alias', 'extends', 'comment'}, ['doc.param'] = {'param', 'extends', 'comment'}, ['doc.return'] = {'#returns', 'comment'}, @@ -100,9 +74,9 @@ m.childMap = { ['doc.generic.object'] = {'generic', 'extends', 'comment'}, ['doc.vararg'] = {'vararg', 'comment'}, ['doc.type.array'] = {'node'}, - ['doc.type.table'] = {'node', 'key', 'value', 'comment'}, + ['doc.type.table'] = {'tkey', 'tvalue', 'comment'}, ['doc.type.function'] = {'#args', '#returns', 'comment'}, - ['doc.type.typeliteral'] = {'node'}, + ['doc.type.literal'] = {'node'}, ['doc.type.arg'] = {'extends'}, ['doc.overload'] = {'overload', 'comment'}, ['doc.see'] = {'name', 'field'}, @@ -123,19 +97,31 @@ m.actionMap = { ['funcargs'] = {'#'}, } -local TypeSort = { - ['boolean'] = 1, - ['string'] = 2, - ['integer'] = 3, - ['number'] = 4, - ['table'] = 5, - ['function'] = 6, - ['true'] = 101, - ['false'] = 102, - ['nil'] = 999, -} +local inf = 1 / 0 +local nan = 0 / 0 + +local function isInteger(n) + if math.type then + return math.type(n) == 'integer' + else + return type(n) == 'number' and n % 1 == 0 + end +end -local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) +local function formatNumber(n) + if n == inf + or n == -inf + or n == nan + or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则 + return ('%q'):format(n) + end + if isInteger(n) then + return tostring(n) + end + local str = ('%.10f'):format(n) + str = str:gsub('%.?0*$', '') + return str +end --- 是否是字面量 ---@param obj parser.guide.object @@ -182,23 +168,6 @@ function m.getParentFunction(obj) return nil end ---- 寻找父的table类型 doc.type.table ----@param obj parser.guide.object ----@return parser.guide.object -function m.getParentDocTypeTable(obj) - for _ = 1, 1000 do - local parent = obj.parent - if not parent then - return nil - end - if parent.type == 'doc.type.table' then - return obj - end - obj = parent - end - error('guide.getParentDocTypeTable overstack') -end - --- 寻找所在区块 ---@param obj parser.guide.object ---@return parser.guide.object @@ -293,10 +262,19 @@ end ---@param obj parser.guide.object ---@return parser.guide.object function m.getRoot(obj) + local source = obj + if source._root then + return source._root + end for _ = 1, 1000 do if obj.type == 'main' then + source._root = obj return obj end + if obj._root then + source._root = obj._root + return source._root + end local parent = obj.parent if not parent then return nil @@ -501,8 +479,8 @@ function m.addChilds(list, obj, map) for i = 1, #keys do local key = keys[i] if key == '#' then - for i = 1, #obj do - list[#list+1] = obj[i] + for j = 1, #obj do + list[#list+1] = obj[j] end elseif obj[key] then list[#list+1] = obj[key] @@ -510,8 +488,8 @@ function m.addChilds(list, obj, map) and key:sub(1, 1) == '#' then key = key:sub(2) if obj[key] then - for i = 1, #obj[key] do - list[#list+1] = obj[key][i] + for j = 1, #obj[key] do + list[#list+1] = obj[key][j] end end end @@ -613,9 +591,16 @@ function m.eachSource(ast, callback) index = index + 1 if not mark[obj] then mark[obj] = true - callback(obj) + local res = callback(obj) + if res == true then + goto CONTINUE + end + if res == false then + return + end m.addChilds(list, obj, m.childMap) end + ::CONTINUE:: end end @@ -718,4 +703,294 @@ function m.lineData(lines, row) return lines[row] end +function m.isSet(source) + local tp = source.type + if tp == 'setglobal' + or tp == 'local' + or tp == 'setlocal' + or tp == 'setfield' + or tp == 'setmethod' + or tp == 'setindex' + or tp == 'tablefield' + or tp == 'tableindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawset' then + return true + end + end + return false +end + +function m.isGet(source) + local tp = source.type + if tp == 'getglobal' + or tp == 'getlocal' + or tp == 'getfield' + or tp == 'getmethod' + or tp == 'getindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawget' then + return true + end + end + return false +end + +function m.getSpecial(source) + if not source then + return nil + end + return source.special +end + +function m.getKeyNameOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return obj[1] + elseif tp == 'string' then + local s = obj[1] + if s then + return s + end + elseif tp == 'number' then + local n = obj[1] + if n then + return ('%s'):format(formatNumber(obj[1])) + end + elseif tp == 'boolean' then + local b = obj[1] + if b then + return tostring(b) + end + end +end + +function m.getKeyName(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return obj[1] + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return obj[1] + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + if obj.field then + return obj.field[1] + end + elseif tp == 'getmethod' + or tp == 'setmethod' then + if obj.method then + return obj.method[1] + end + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyNameOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' + or tp == 'doc.see.field' then + return obj[1] + elseif tp == 'doc.class' then + return obj.class[1] + elseif tp == 'doc.alias' then + return obj.alias[1] + elseif tp == 'doc.field' then + return obj.field[1] + elseif tp == 'doc.field.name' then + return obj[1] + elseif tp == 'dummy' then + return obj[1] + end + return m.getKeyNameOfLiteral(obj) +end + +function m.getKeyTypeOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return 'string' + elseif tp == 'string' then + return 'string' + elseif tp == 'number' then + return 'number' + elseif tp == 'boolean' then + return 'boolean' + end +end + +function m.getKeyType(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return 'string' + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return 'local' + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + return 'string' + elseif tp == 'getmethod' + or tp == 'setmethod' then + return 'string' + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyTypeOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' + or tp == 'doc.see.field' then + return 'string' + elseif tp == 'doc.class' then + return 'string' + elseif tp == 'doc.alias' then + return 'string' + elseif tp == 'doc.field' then + return 'string' + elseif tp == 'dummy' then + return 'string' + end + if tp == 'doc.field.name' then + return 'string' + end + return m.getKeyTypeOfLiteral(obj) +end + +--- 测试 a 到 b 的路径(不经过函数,不考虑 goto), +--- 每个路径是一个 block 。 +--- +--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>` +--- +--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>` +--- +--- 否则返回 `false` +--- +--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。 +---@param a table +---@param b table +---@return string|boolean mode +---@return table pathA? +---@return table pathB? +function m.getPath(a, b, sameFunction) + --- 首先测试双方在同一个函数内 + if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then + return false + end + local mode + local objA + local objB + if a.finish < b.start then + mode = 'before' + objA = a + objB = b + elseif a.start > b.finish then + mode = 'after' + objA = b + objB = a + else + return 'equal', {}, {} + end + local pathA = {} + local pathB = {} + for _ = 1, 1000 do + objA = m.getParentBlock(objA) + pathA[#pathA+1] = objA + if (not sameFunction and objA.type == 'function') or objA.type == 'main' then + break + end + end + for _ = 1, 1000 do + objB = m.getParentBlock(objB) + pathB[#pathB+1] = objB + if (not sameFunction and objA.type == 'function') or objB.type == 'main' then + break + end + end + -- pathA: {1, 2, 3, 4, 5} + -- pathB: {5, 6, 2, 3} + local top = #pathB + local start + for i = #pathA, 1, -1 do + local currentBlock = pathA[i] + if currentBlock == pathB[top] then + start = i + break + end + end + if not start then + return nil + end + -- pathA: { 1, 2, 3} + -- pathB: {5, 6, 2, 3} + local extra = 0 + local align = top - start + for i = start, 1, -1 do + local currentA = pathA[i] + local currentB = pathB[i+align] + if currentA ~= currentB then + extra = i + break + end + end + -- pathA: {1} + local resultA = {} + for i = extra, 1, -1 do + resultA[#resultA+1] = pathA[i] + end + -- pathB: {5, 6} + local resultB = {} + for i = extra + align, 1, -1 do + resultB[#resultB+1] = pathB[i] + end + return mode, resultA, resultB +end + +---是否是全局变量(包括 _G.XXX 形式) +---@param source parser.guide.object +---@return boolean +function m.isGlobal(source) + if source._isGlobal ~= nil then + return source._isGlobal + end + if source.type == 'setglobal' + or source.type == 'getglobal' then + if source.node and source.node.tag == '_ENV' then + source._isGlobal = true + return true + end + end + if source.type == 'field' then + source = source.parent + end + if source.special == '_G' then + source._isGlobal = true + return true + end + source._isGlobal = false + return false +end + return m diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index ae8e3f34..335c8f24 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -1,7 +1,7 @@ local m = require 'lpeglabel' local re = require 'parser.relabel' local lines = require 'parser.lines' -local guide = require 'core.guide' +local guide = require 'parser.guide' local grammar = require 'parser.grammar' local TokenTypes, TokenStarts, TokenFinishs, TokenContents @@ -194,6 +194,7 @@ local function parseClass(parent) local result = { type = 'doc.class', parent = parent, + fields = {}, } result.class = parseName('doc.class.name', result) if not result.class then @@ -300,8 +301,8 @@ local function parseTypeUnitTable(parent, node) node.parent = result; result.finish = getFinish() - result.key = key - result.value = value + result.tkey = key + result.tvalue = value return result end @@ -425,9 +426,10 @@ local function parseTypeUnit(parent, content) return result end -local function parseResume() +local function parseResume(parent) local result = { - type = 'doc.resume' + type = 'doc.resume', + parent = parent, } if checkToken('symbol', '>', 1) then @@ -456,7 +458,6 @@ local function parseResume() return result end -local LastType function parseType(parent) local result = { type = 'doc.type', @@ -484,13 +485,7 @@ function parseType(parent) break end -- TypeLiteral,指代类型的字面值。比如,对于类 Cat 来说,它的 TypeLiteral 是 "Cat" - typeLiteral = { - type = 'doc.type.typeliteral', - parent = result, - start = getStart(), - finish = nil, - node = nil, - } + typeLiteral = true end if tp == 'name' then @@ -501,10 +496,7 @@ function parseType(parent) end if typeLiteral then nextToken() - typeLiteral.finish = getFinish() - typeLiteral.node = typeUnit - typeUnit.parent = typeLiteral - typeUnit = typeLiteral + typeUnit.literal = true end result.types[#result.types+1] = typeUnit if not result.start then @@ -566,7 +558,7 @@ function parseType(parent) row = row + i + 1 local finishPos = nextComm.text:find('#', 3) or #nextComm.text parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1) - local resume = parseResume() + local resume = parseResume(result) if resume then if comments then resume.comment = table.concat(comments, '\n') @@ -1122,17 +1114,25 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish) end local src = sources[index] if src.start < start then - left = index + left = index + 1 else right = index end end - for i = index - 1, max do + + -- 从前往后进行绑定 + for i = index, max do local src = sources[i] if src then if src.start > finish then break end + -- 遇到table后中断,处理以下情况: + -- ---@type AAA + -- local t = {x = 1, y = 2} + if src.type == 'table' then + break + end if src.start >= start then src.bindDocs = binded bindSources[#bindSources+1] = src @@ -1152,21 +1152,22 @@ local function bindParamAndReturnIndex(binded) if not func then return end - if not func.args then - return - end - local paramIndex = 0 - local paramMap = {} - for _, param in ipairs(func.args) do - paramIndex = paramIndex + 1 - if param[1] then - paramMap[param[1]] = paramIndex + local paramMap + if func.args then + local paramIndex = 0 + paramMap = {} + for _, param in ipairs(func.args) do + paramIndex = paramIndex + 1 + if param[1] then + paramMap[param[1]] = paramIndex + end end + func.docParamMap = paramMap end local returnIndex = 0 for _, doc in ipairs(binded) do if doc.type == 'doc.param' then - if doc.extends then + if paramMap and doc.extends then doc.extends.paramIndex = paramMap[doc.param[1]] end elseif doc.type == 'doc.return' then @@ -1178,6 +1179,24 @@ local function bindParamAndReturnIndex(binded) end end +local function bindClassAndFields(binded) + local class + for _, doc in ipairs(binded) do + if doc.type == 'doc.class' then + -- 多个class连续写在一起,只有最后一个class可以绑定source + if class then + class.bindSources = nil + end + class = doc + elseif doc.type == 'doc.field' then + if class then + class.fields[#class.fields+1] = doc + doc.class = class + end + end + end +end + local function bindDoc(sources, lns, binded) if not binded then return @@ -1200,6 +1219,7 @@ local function bindDoc(sources, lns, binded) bindDocsBetween(sources, binded, bindSources, nstart, nfinish) end bindParamAndReturnIndex(binded) + bindClassAndFields(binded) end local function bindDocs(state) @@ -1214,6 +1234,7 @@ local function bindDocs(state) or src.type == 'tablefield' or src.type == 'tableindex' or src.type == 'function' + or src.type == 'table' or src.type == '...' then sources[#sources+1] = src end diff --git a/script/proto/define.lua b/script/proto/define.lua index abfaa9b0..f2ee7ab5 100644 --- a/script/proto/define.lua +++ b/script/proto/define.lua @@ -103,7 +103,7 @@ m.DiagnosticDefaultNeededFileStatus = { ['unused-local'] = 'Opened', ['unused-function'] = 'Opened', ['undefined-global'] = 'Any', - ['undefined-field'] = 'Opened', + ['undefined-field'] = 'Any', ['global-in-nil-env'] = 'Any', ['unused-label'] = 'Opened', ['unused-vararg'] = 'Opened', @@ -124,7 +124,7 @@ m.DiagnosticDefaultNeededFileStatus = { ['close-non-object'] = 'Any', ['count-down-loop'] = 'Any', ['no-implicit-any'] = 'None', - ['deprecated'] = 'None', + ['deprecated'] = 'Opened', ['duplicate-doc-class'] = 'Any', ['undefined-doc-class'] = 'Any', @@ -284,4 +284,19 @@ m.BuiltIn = { ['utf8'] = 'default', } +m.BuiltinClass = { + ['unknown'] = true, + ['any'] = true, + ['nil'] = true, + ['boolean'] = true, + ['number'] = true, + ['integer'] = true, + ['thread'] = true, + ['table'] = true, + ['string'] = true, + ['userdata'] = true, + ['lightuserdata'] = true, + ['Function'] = true, +} + return m diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua index 883ae68c..4a207115 100644 --- a/script/provider/diagnostic.lua +++ b/script/provider/diagnostic.lua @@ -190,7 +190,7 @@ function m.doDiagnostic(uri) await.delay() - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then m.clear(uri) return diff --git a/script/service/service.lua b/script/service/service.lua index 44fd9aa4..247cb5b5 100644 --- a/script/service/service.lua +++ b/script/service/service.lua @@ -1,3 +1,4 @@ +---@diagnostic disable: deprecated local pub = require 'pub' local thread = require 'bee.thread' local await = require 'await' diff --git a/script/utility.lua b/script/utility.lua index 04597a39..16c5e0c9 100644 --- a/script/utility.lua +++ b/script/utility.lua @@ -317,12 +317,12 @@ end --- 排序后遍历 ---@param t table -function m.sortPairs(t) +function m.sortPairs(t, sorter) local keys = {} for k in pairs(t) do keys[#keys+1] = k end - tableSort(keys) + tableSort(keys, sorter) local i = 0 return function () i = i + 1 diff --git a/script/vm/eachDef.lua b/script/vm/eachDef.lua index d72c8f01..6f7af295 100644 --- a/script/vm/eachDef.lua +++ b/script/vm/eachDef.lua @@ -1,49 +1,7 @@ ---@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local files = require 'files' -local util = require 'utility' -local await = require 'await' -local config = require 'config' +local vm = require 'vm.vm' +local searcher = require 'core.searcher' -local function getDefs(source, deep) - local results = {} - local lock = vm.lock('eachDef', source) - if not lock then - return results - end - - await.delay() - - deep = config.config.intelliSense.searchDepth + (deep or 0) - - local clock = os.clock() - local myResults, count = guide.requestDefinition(source, vm.interface, deep) - if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestDefinition', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) - end - vm.mergeResults(results, myResults) - - lock() - - return results -end - -function vm.getDefs(source, deep) - deep = deep or -999 - if guide.isGlobal(source) then - local key = guide.getKeyName(source) - if not key then - return {} - end - return vm.getGlobalSets(key) - else - local cache = vm.getCache('eachDef')[source] - if not cache or cache.deep < deep then - cache = getDefs(source, deep) - cache.deep = deep - vm.getCache('eachDef')[source] = cache - end - return cache - end +function vm.getDefs(source, field) + return searcher.requestDefinition(source, field) end diff --git a/script/vm/eachField.lua b/script/vm/eachField.lua deleted file mode 100644 index 59f35f0c..00000000 --- a/script/vm/eachField.lua +++ /dev/null @@ -1,109 +0,0 @@ ----@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local await = require 'await' -local config = require 'config' - -local function getFields(source, deep, filterKey) - local unlock = vm.lock('eachField', source) - if not unlock then - return {} - end - - while source.type == 'paren' do - source = source.exp - if not source then - return {} - end - end - deep = config.config.intelliSense.searchDepth + (deep or 0) - - await.delay() - local results = guide.requestFields(source, vm.interface, deep, filterKey) - - unlock() - return results -end - -local function getDefFields(source, deep, filterKey) - local unlock = vm.lock('eachDefField', source) - if not unlock then - return {} - end - - while source.type == 'paren' do - source = source.exp - if not source then - return {} - end - end - deep = config.config.intelliSense.searchDepth + (deep or 0) - - await.delay() - local results = guide.requestDefFields(source, vm.interface, deep, filterKey) - - unlock() - return results -end - -local function getFieldsBySource(source, deep, filterKey) - deep = deep or -999 - local cache = vm.getCache('eachField')[source] - if not cache or cache.deep < deep then - cache = getFields(source, deep, filterKey) - cache.deep = deep - if not filterKey then - vm.getCache('eachField')[source] = cache - end - end - return cache -end - -local function getDefFieldsBySource(source, deep, filterKey) - deep = deep or -999 - local cache = vm.getCache('eachDefField')[source] - if not cache or cache.deep < deep then - cache = getDefFields(source, deep, filterKey) - cache.deep = deep - if not filterKey then - vm.getCache('eachDefField')[source] = cache - end - end - return cache -end - -function vm.getFields(source, deep) - if source.special == '_G' then - return vm.getGlobals '*' - end - if guide.isGlobal(source) then - local name = guide.getKeyName(source) - if not name then - return {} - end - local cache = vm.getCache('eachFieldOfGlobal')[name] - or getFieldsBySource(source, deep) - vm.getCache('eachFieldOfGlobal')[name] = cache - return cache - else - return getFieldsBySource(source, deep) - end -end - -function vm.getDefFields(source, deep) - if source.special == '_G' then - return vm.getGlobalSets '*' - end - if guide.isGlobal(source) then - local name = guide.getKeyName(source) - if not name then - return {} - end - local cache = vm.getCache('eachDefFieldOfGlobal')[name] - or getDefFieldsBySource(source, deep) - vm.getCache('eachDefFieldOfGlobal')[name] = cache - return cache - else - return getDefFieldsBySource(source, deep) - end -end diff --git a/script/vm/eachRef.lua b/script/vm/eachRef.lua index 9d0f061c..5aca198e 100644 --- a/script/vm/eachRef.lua +++ b/script/vm/eachRef.lua @@ -1,48 +1,7 @@ ---@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local util = require 'utility' -local await = require 'await' -local config = require 'config' +local vm = require 'vm.vm' +local searcher = require 'core.searcher' -local function getRefs(source, deep) - local results = {} - local lock = vm.lock('eachRef', source) - if not lock then - return results - end - - await.delay() - - deep = config.config.intelliSense.searchDepth + (deep or 0) - - local clock = os.clock() - local myResults, count = guide.requestReference(source, vm.interface, deep) - if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestReference', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) - end - vm.mergeResults(results, myResults) - - lock() - - return results -end - -function vm.getRefs(source, deep) - deep = deep or -999 - if guide.isGlobal(source) then - local key = guide.getKeyName(source) - if not key then - return {} - end - return vm.getGlobals(key) - else - local cache = vm.getCache('eachRef')[source] - if not cache or cache.deep < deep then - cache = getRefs(source, deep) - cache.deep = deep - vm.getCache('eachRef')[source] = cache - end - return cache - end +function vm.getRefs(source, field) + return searcher.requestReference(source, field) end diff --git a/script/vm/getClass.lua b/script/vm/getClass.lua deleted file mode 100644 index 5c68e0bb..00000000 --- a/script/vm/getClass.lua +++ /dev/null @@ -1,64 +0,0 @@ ----@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' - -local function lookUpDocClass(source) - local infers = vm.getInfers(source, 0) - for _, infer in ipairs(infers) do - if infer.source.type == 'doc.class' - or infer.source.type == 'doc.type' then - return guide.viewInferType(infers) - end - end - return nil -end - -local function getClass(source, classes, depth, deep) - local docClass = lookUpDocClass(source) - if docClass then - classes[docClass] = true - return - end - if depth > 3 then - return - end - local value = guide.getObjectValue(source) or source - if not deep then - if value and value.type == 'string' then - classes[value[1]] = true - end - else - for _, src in ipairs(vm.getDefFields(value)) do - local key = vm.getKeyName(src) - if not key then - goto CONTINUE - end - local lkey = key:lower() - if lkey == 'type' - or lkey == '__name' - or lkey == 'name' - or lkey == 'class' then - local value = guide.getObjectValue(src) - if value and value.type == 'string' then - classes[value[1]] = true - end - end - ::CONTINUE:: - end - end - if next(classes) then - return - end - vm.eachMeta(source, function (mt) - getClass(mt, classes, depth + 1, deep) - end) -end - -function vm.getClass(source, deep) - local classes = {} - getClass(source, classes, 1, deep) - if not next(classes) then - return nil - end - return guide.mergeTypes(classes) -end diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua index cfa9326f..16b82278 100644 --- a/script/vm/getDocs.lua +++ b/script/vm/getDocs.lua @@ -1,152 +1,65 @@ -local files = require 'files' -local util = require 'utility' -local guide = require 'core.guide' ----@type vm -local vm = require 'vm.vm' -local config = require 'config' +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm.vm' +local config = require 'config' +local collector = require 'core.collector' +local define = require 'proto.define' +local noder = require 'core.noder' -local function getTypesOfFile(uri) - local types = {} - local ast = files.getAst(uri) - if not ast or not ast.ast.docs then - return types +---获取class与alias +---@param name? string +---@return parser.guide.object[] +function vm.getDocDefines(name) + local cache = vm.getCache 'getDocDefines' + if cache[name] then + return cache[name] end - guide.eachSource(ast.ast.docs, function (src) - if src.type == 'doc.type.name' - or src.type == 'doc.class.name' - or src.type == 'doc.extends.name' - or src.type == 'doc.alias.name' then - if src.type == 'doc.type.name' then - if guide.getParentDocTypeTable(src) then - return - end - end - local name = src[1] - if name then - if not types[name] then - types[name] = {} - end - types[name][#types[name]+1] = src - end - end - end) - return types -end - -local function getDocTypes(name) local results = {} - if name == 'any' - or name == 'nil' then - return results - end - for uri in files.eachFile() do - local cache = files.getCache(uri) - cache.classes = cache.classes or getTypesOfFile(uri) - if name == '*' then - for _, sources in util.sortPairs(cache.classes) do - for _, source in ipairs(sources) do - results[#results+1] = source - end - end - else - if cache.classes[name] then - for _, source in ipairs(cache.classes[name]) do + local id = 'def:dn:' .. (name or '') + for node in collector.each(id) do + if node.sources then + for _, source in ipairs(node.sources) do + if source.type == 'doc.class.name' + or source.type == 'doc.alias.name' then results[#results+1] = source end end end end + cache[name] = results return results end -function vm.getDocEnums(doc, mark, results) - if not doc then - return nil - end - mark = mark or {} - if mark[doc] then - return nil - end - mark[doc] = true - results = results or {} - for _, enum in ipairs(doc.enums) do - results[#results+1] = enum - end - for _, resume in ipairs(doc.resumes) do - results[#results+1] = resume +function vm.isDocDefined(name) + if define.BuiltinClass[name] then + return true end - for _, unit in ipairs(doc.types) do - if unit.type == 'doc.type.name' then - for _, other in ipairs(vm.getDocTypes(unit[1])) do - if other.type == 'doc.alias.name' then - vm.getDocEnums(other.parent.extends, mark, results) - end - end - end + local id = 'def:dn:' .. name + if collector.has(id) then + return true end - return results + return false end -function vm.getDocTypeUnits(doc, mark, results) +function vm.getDocEnums(doc) if not doc then return nil end - mark = mark or {} - if mark[doc] then - return nil - end - mark[doc] = true - results = results or {} - for _, enum in ipairs(doc.enums) do - results[#results+1] = enum - end - for _, resume in ipairs(doc.resumes) do - results[#results+1] = resume - end - for _, unit in ipairs(doc.types) do - if unit.type == 'doc.type.name' then - for _, other in ipairs(vm.getDocTypes(unit[1])) do - if other.type == 'doc.alias.name' then - vm.getDocTypeUnits(other.parent.extends, mark, results) - elseif other.type == 'doc.class.name' then - results[#results+1] = other - end - end - else - results[#results+1] = unit - end - end - return results -end - -function vm.getDocTypes(name) - local cache = vm.getCache('getDocTypes')[name] - if cache ~= nil then - return cache - end - cache = getDocTypes(name) - vm.getCache('getDocTypes')[name] = cache - return cache -end + local defs = vm.getDefs(doc) + local results = {} -function vm.getDocClass(name) - local cache = vm.getCache('getDocClass')[name] - if cache ~= nil then - return cache - end - cache = {} - local results = getDocTypes(name) - for _, doc in ipairs(results) do - if doc.type == 'doc.class.name' then - cache[#cache+1] = doc + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then + results[#results+1] = def end end - vm.getCache('getDocClass')[name] = cache - return cache + + return results end function vm.isMetaFile(uri) - local status = files.getAst(uri) + local status = files.getState(uri) if not status then return false end @@ -224,7 +137,7 @@ end function vm.isDeprecated(value, deep) if deep then - local defs = vm.getDefs(value, 0) + local defs = vm.getDefs(value) if #defs == 0 then return false end @@ -300,7 +213,7 @@ local function makeDiagRange(uri, doc, results) end function vm.isDiagDisabledAt(uri, offset, name) - local status = files.getAst(uri) + local status = files.getState(uri) if not status then return false end diff --git a/script/vm/getGlobals.lua b/script/vm/getGlobals.lua index 2752ce09..e5bcafc0 100644 --- a/script/vm/getGlobals.lua +++ b/script/vm/getGlobals.lua @@ -1,5 +1,6 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local await = require "await" +local searcher = require "core.searcher" ---@type vm local vm = require 'vm.vm' local files = require 'files' @@ -17,12 +18,8 @@ local function getGlobalsOfFile(uri) end local globals = {} cache.globals = globals - local ast = files.getAst(uri) - if not ast then - return globals - end tracy.ZoneBeginN 'getGlobalsOfFile' - local results = guide.findGlobals(ast.ast) + local results = searcher.findGlobals(uri, 'ref') local subscribe = ws.getCache 'globalSubscribe' subscribe[uri] = {} local mark = {} @@ -34,7 +31,7 @@ local function getGlobalsOfFile(uri) goto CONTINUE end mark[res] = true - local name = guide.getSimpleName(res) + local name = guide.getKeyName(res) if name then if not globals[name] then globals[name] = {} @@ -59,12 +56,8 @@ local function getGlobalSetsOfFile(uri) end local globals = {} cache.globalSets = globals - local ast = files.getAst(uri) - if not ast then - return globals - end tracy.ZoneBeginN 'getGlobalSetsOfFile' - local results = guide.findGlobals(ast.ast) + local results = searcher.findGlobals(uri, 'def') local subscribe = ws.getCache 'globalSetsSubscribe' subscribe[uri] = {} local mark = {} @@ -76,16 +69,14 @@ local function getGlobalSetsOfFile(uri) goto CONTINUE end mark[res] = true - if vm.isSet(res) then - local name = guide.getSimpleName(res) - if name then - if not globals[name] then - globals[name] = {} - subscribe[uri][#subscribe[uri]+1] = name - end - globals[name][#globals[name]+1] = res - globals['*'][#globals['*']+1] = res + local name = guide.getKeyName(res) + if name then + if not globals[name] then + globals[name] = {} + subscribe[uri][#subscribe[uri]+1] = name end + globals[name][#globals[name]+1] = res + globals['*'][#globals['*']+1] = res end ::CONTINUE:: end @@ -265,7 +256,7 @@ files.watch(function (ev, uri) end needUpdateGlobals[uri] = true elseif ev == 'create' then - getGlobalsOfFile(uri) - getGlobalSetsOfFile(uri) + --getGlobalsOfFile(uri) + --getGlobalSetsOfFile(uri) end end) diff --git a/script/vm/getInfer.lua b/script/vm/getInfer.lua deleted file mode 100644 index 5447ca23..00000000 --- a/script/vm/getInfer.lua +++ /dev/null @@ -1,104 +0,0 @@ ----@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local util = require 'utility' -local await = require 'await' -local config = require 'config' - -NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) - ---- 是否包含某种类型 -function vm.hasType(source, type, deep) - local defs = vm.getDefs(source, deep) - for i = 1, #defs do - local def = defs[i] - local value = guide.getObjectValue(def) or def - if value.type == type then - return true - end - end - return false -end - ---- 是否包含某种类型 -function vm.hasInferType(source, type, deep) - local infers = vm.getInfers(source, deep) - for i = 1, #infers do - local infer = infers[i] - if infer.type == type then - return true - end - end - return false -end - -function vm.getInferType(source, deep) - local infers = vm.getInfers(source, deep) - return guide.viewInferType(infers) -end - -function vm.getInferLiteral(source, deep) - local infers = vm.getInfers(source, deep) - local literals = {} - local mark = {} - for _, infer in ipairs(infers) do - local value = infer.value - if value and not mark[value] then - mark[value] = true - literals[#literals+1] = util.viewLiteral(value) - end - end - if #literals == 0 then - return nil - end - table.sort(literals) - return table.concat(literals, '|') -end - -local function getInfers(source, deep) - local results = {} - local lock = vm.lock('getInfers', source) - if not lock then - return results - end - - deep = config.config.intelliSense.searchDepth + (deep or 0) - - await.delay() - - local clock = os.clock() - local myResults, count = guide.requestInfer(source, vm.interface, deep) - if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestInfer', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) - end - vm.mergeResults(results, myResults) - - lock() - - return results -end - -local function getInfersBySource(source, deep) - deep = deep or -999 - local cache = vm.getCache('getInfers')[source] - if not cache or cache.deep < deep then - cache = getInfers(source, deep) - cache.deep = deep - vm.getCache('getInfers')[source] = cache - end - return cache -end - ---- 获取对象的值 ---- 会尝试穿透函数调用 -function vm.getInfers(source, deep) - if guide.isGlobal(source) then - local name = guide.getKeyName(source) - local cache = vm.getCache('getInfersOfGlobal')[name] - or getInfersBySource(source, deep) - vm.getCache('getInfersOfGlobal')[name] = cache - return cache - else - return getInfersBySource(source, deep) - end -end diff --git a/script/vm/getLibrary.lua b/script/vm/getLibrary.lua index b52f7240..a3c8feb0 100644 --- a/script/vm/getLibrary.lua +++ b/script/vm/getLibrary.lua @@ -1,8 +1,11 @@ ---@type vm local vm = require 'vm.vm' -function vm.getLibraryName(source, deep) - local defs = vm.getDefs(source, deep) +function vm.getLibraryName(source) + if source.special then + return source.special + end + local defs = vm.getDefs(source) for _, def in ipairs(defs) do if def.special then return def.special diff --git a/script/vm/getLinks.lua b/script/vm/getLinks.lua index 91a5f1a0..14b34987 100644 --- a/script/vm/getLinks.lua +++ b/script/vm/getLinks.lua @@ -1,12 +1,11 @@ -local guide = require 'core.guide' ----@type vm +local guide = require 'parser.guide' local vm = require 'vm.vm' local files = require 'files' local function getFileLinks(uri) local ws = require 'workspace' local links = {} - local ast = files.getAst(uri) + local ast = files.getState(uri) if not ast then return links end @@ -33,11 +32,17 @@ local function getFileLinks(uri) return links end +local function getFileLinksOrCache(uri) + local cache = files.getCache(uri) + cache.links = cache.links or getFileLinks(uri) + return cache.links +end + local function getLinksTo(uri) uri = files.asKey(uri) local links = {} for u in files.eachFile() do - local ls = vm.getFileLinks(u) + local ls = getFileLinksOrCache(u) if ls[uri] then for _, l in ipairs(ls[uri]) do links[#links+1] = l @@ -47,6 +52,7 @@ local function getLinksTo(uri) return links end +-- 获取所有 require(uri) 的文件 function vm.getLinksTo(uri) local cache = vm.getCache('getLinksTo')[uri] if cache ~= nil then @@ -56,9 +62,3 @@ function vm.getLinksTo(uri) vm.getCache('getLinksTo')[uri] = cache return cache end - -function vm.getFileLinks(uri) - local cache = files.getCache(uri) - cache.links = cache.links or getFileLinks(uri) - return cache.links -end diff --git a/script/vm/getMeta.lua b/script/vm/getMeta.lua deleted file mode 100644 index 44d1874a..00000000 --- a/script/vm/getMeta.lua +++ /dev/null @@ -1,53 +0,0 @@ ----@type vm -local vm = require 'vm.vm' - -local function eachMetaOfArg1(source, callback) - local node, index = vm.getArgInfo(source) - local special = vm.getSpecial(node) - if special == 'setmetatable' and index == 1 then - local mt = node.next.args[2] - if mt then - callback(mt) - end - end -end - -local function eachMetaOfRecv(source, callback) - if not source or source.type ~= 'select' then - return - end - if source.index ~= 1 then - return - end - local call = source.vararg - if not call or call.type ~= 'call' then - return - end - local special = vm.getSpecial(call.node) - if special ~= 'setmetatable' then - return - end - local mt = call.args[2] - if mt then - callback(mt) - end -end - -function vm.eachMetaValue(source, callback) - vm.eachMeta(source, function (mt) - for _, src in ipairs(vm.getDefFields(mt)) do - if vm.getKeyName(src) == '__index' then - if src.value then - for _, valueSrc in ipairs(vm.getDefFields(src.value)) do - callback(valueSrc) - end - end - end - end - end) -end - -function vm.eachMeta(source, callback) - eachMetaOfArg1(source, callback) - eachMetaOfRecv(source.value, callback) -end diff --git a/script/vm/guideInterface.lua b/script/vm/guideInterface.lua index ae060481..a07b6644 100644 --- a/script/vm/guideInterface.lua +++ b/script/vm/guideInterface.lua @@ -2,7 +2,7 @@ local vm = require 'vm.vm' local files = require 'files' local ws = require 'workspace' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local await = require 'await' local config = require 'config' @@ -27,11 +27,11 @@ function m.require(args, index) return nil end local results = {} - local myUri = guide.getUri(args[1]) + local myUri = searcher.getUri(args[1]) local uris = ws.findUrisByRequirePath(reqName) for _, uri in ipairs(uris) do if not files.eq(myUri, uri) then - local ast = files.getAst(uri) + local ast = files.getState(uri) if ast then m.searchFileReturn(results, ast.ast, index) end @@ -47,11 +47,11 @@ function m.dofile(args, index) return end local results = {} - local myUri = guide.getUri(args[1]) + local myUri = searcher.getUri(args[1]) local uris = ws.findUrisByFilePath(reqName) for _, uri in ipairs(uris) do if not files.eq(myUri, uri) then - local ast = files.getAst(uri) + local ast = files.getState(uri) if ast then m.searchFileReturn(results, ast.ast, index) end @@ -87,9 +87,9 @@ function vm.interface.global(name, onlyDef) end end -function vm.interface.docType(name) +function vm.interface.doc(name, type) await.delay() - return vm.getDocTypes(name) + return vm.getDocNames(name, type) end function vm.interface.link(uri) diff --git a/script/vm/init.lua b/script/vm/init.lua index b9e8e147..c38f01d5 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -2,10 +2,6 @@ local vm = require 'vm.vm' require 'vm.getGlobals' require 'vm.getDocs' require 'vm.getLibrary' -require 'vm.getInfer' -require 'vm.getClass' -require 'vm.getMeta' -require 'vm.eachField' require 'vm.eachDef' require 'vm.eachRef' require 'vm.getLinks' diff --git a/script/vm/vm.lua b/script/vm/vm.lua index 0248ad8c..0e7f3176 100644 --- a/script/vm/vm.lua +++ b/script/vm/vm.lua @@ -1,18 +1,14 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local util = require 'utility' local files = require 'files' local timer = require 'timer' local setmetatable = setmetatable -local assert = assert -local require = require -local type = type local running = coroutine.running local ipairs = ipairs local log = log local xpcall = xpcall local mathHuge = math.huge -local collectgarbage = collectgarbage _ENV = nil @@ -64,7 +60,10 @@ function m.getArgInfo(source) end function m.getSpecial(source) - return guide.getSpecial(source) + if not source then + return nil + end + return source.special end function m.getKeyName(source) |