diff options
Diffstat (limited to 'script/core')
59 files changed, 2963 insertions, 565 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua index bae3df81..9ed000e9 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -1,10 +1,9 @@ -local files = require 'files' -local lang = require 'language' -local define = require 'proto.define' -local guide = require 'core.guide' -local util = require 'utility' -local sp = require 'bee.subprocess' -local vm = require 'vm' +local files = require 'files' +local lang = require 'language' +local util = require 'utility' +local sp = require 'bee.subprocess' +local vm = require 'vm' +local guide = require "parser.guide" local function checkDisableByLuaDocExits(uri, row, mode, code) local lines = files.getLines(uri) diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua index 527af8d5..ba1ee8eb 100644 --- a/script/core/command/removeSpace.lua +++ b/script/core/command/removeSpace.lua @@ -1,11 +1,10 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local proto = require 'proto' +local lang = require 'language' local function isInString(ast, offset) - return guide.eachSourceContain(ast.ast, offset, function (source) + return searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'string' then return true end @@ -23,10 +22,10 @@ return function (data) local textEdit = {} for i = 1, #lines do - local line = guide.lineContent(lines, text, i, true) + local line = searcher.lineContent(lines, text, i, true) local pos = line:find '[ \t]+$' if pos then - local start, finish = guide.lineRange(lines, i, true) + local start, finish = searcher.lineRange(lines, i, true) start = start + pos - 1 if isInString(ast, start) then goto NEXT_LINE diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua index 995a2109..dc23e7af 100644 --- a/script/core/command/solve.lua +++ b/script/core/command/solve.lua @@ -1,8 +1,7 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' local opMap = { ['+'] = true, diff --git a/script/core/completion.lua b/script/core/completion.lua index e3980eca..acaaa276 100644 --- a/script/core/completion.lua +++ b/script/core/completion.lua @@ -1,19 +1,14 @@ local define = require 'proto.define' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local matchKey = require 'core.matchkey' local vm = require 'vm' -local getLabel = require 'core.hover.label' local getName = require 'core.hover.name' local getArg = require 'core.hover.arg' -local getReturn = require 'core.hover.return' -local getDesc = require 'core.hover.description' local getHover = require 'core.hover' local config = require 'config' local util = require 'utility' local markdown = require 'provider.markdown' -local findSource = require 'core.find-source' -local await = require 'await' local parser = require 'parser' local keyWordMap = require 'core.keyword' local workspace = require 'workspace' @@ -21,6 +16,8 @@ local furi = require 'file-uri' local rpath = require 'workspace.require-path' local lang = require 'language' local lookBackward = require 'core.look-backward' +local guide = require 'parser.guide' +local infer = require 'core.infer' local DiagnosticModes = { 'disable-next-line', @@ -135,8 +132,8 @@ local function buildFunctionSnip(source, value, oop) end local function buildDetail(source) - local types = vm.getInferType(source, 0) - local literals = vm.getInferLiteral(source, 0) + local types = infer.searchAndViewInfers(source) + local literals = infer.searchAndViewLiterals(source) if literals then return types .. ' = ' .. literals else @@ -149,9 +146,9 @@ local function getSnip(source) if context <= 0 then return nil end - local defs = vm.getRefs(source, 0) + local defs = vm.getRefs(source) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def ~= source and def.type == 'function' then local uri = guide.getUri(def) local text = files.getText(uri) @@ -273,8 +270,8 @@ local function checkLocal(ast, word, offset, results) if not matchKey(word, name) then goto CONTINUE end - if vm.hasType(source, 'function') then - for _, def in ipairs(vm.getDefs(source, 0)) do + if infer.hasType(source, 'function') then + for _, def in ipairs(vm.getDefs(source)) do if def.type == 'function' or def.type == 'doc.type.function' then local funcLabel = name .. getParams(def, false) @@ -417,7 +414,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset) end local function checkFieldThen(name, src, word, start, offset, parent, oop, results) - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src local kind = define.CompletionItemKind.Field if value.type == 'function' or value.type == 'doc.type.function' then @@ -492,7 +489,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res end local funcLabel if config.config.completion.showParams then - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src if value.type == 'function' or value.type == 'doc.type.function' then funcLabel = name .. getParams(value, oop) @@ -539,16 +536,16 @@ end local function checkGlobal(ast, word, start, offset, parent, oop, results) local locals = guide.getVisibleLocals(ast.ast, offset) - local refs = vm.getGlobalSets '*' - checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global') + local globals = vm.getGlobalSets '*' + checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results, locals, 'global') end local function checkField(ast, word, start, offset, parent, oop, results) if parent.tag == '_ENV' or parent.special == '_G' then - local refs = vm.getGlobalSets '*' - checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) + local globals = vm.getGlobalSets '*' + checkFieldOfRefs(globals, ast, word, start, offset, parent, oop, results) else - local refs = vm.getFields(parent, 0) + local refs = vm.getRefs(parent, '*') checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) end end @@ -1043,14 +1040,14 @@ local function mergeEnums(a, b, source) end end -local function checkTypingEnum(ast, text, offset, infers, str, results) +local function checkTypingEnum(ast, text, offset, defs, str, results) local enums = {} - for _, infer in ipairs(infers) do - if infer.source.type == 'doc.type.enum' - or infer.source.type == 'doc.resume' then + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then enums[#enums+1] = { - label = infer.source[1], - description = infer.source.comment and infer.source.comment.text, + label = def[1], + description = def.comment and def.comment.text, kind = define.CompletionItemKind.EnumMember, } end @@ -1074,8 +1071,8 @@ local function checkEqualEnumLeft(ast, text, offset, source, results) return src end end) - local infers = vm.getInfers(source, 0) - checkTypingEnum(ast, text, offset, infers, str, results) + local defs = vm.getDefs(source) + checkTypingEnum(ast, text, offset, defs, str, results) end local function checkEqualEnum(ast, text, offset, results) @@ -1247,7 +1244,30 @@ function (%s)\ end"):format(table.concat(args, ', ')) end -local function getCallEnums(source, index) +local function pushCallEnumsAndFuncs(defs) + local results = {} + for _, def in ipairs(defs) do + if def.type == 'doc.type.enum' + or def.type == 'doc.resume' then + results[#results+1] = { + label = def[1], + description = def.comment, + kind = define.CompletionItemKind.EnumMember, + } + end + if def.type == 'doc.type.function' then + results[#results+1] = { + label = infer.viewDocFunction(def), + description = def.comment, + kind = define.CompletionItemKind.Function, + insertText = buildInsertDocFunction(def), + } + end + end + return results +end + +local function getCallEnumsAndFuncs(source, index) if source.type == 'function' and source.bindDocs then if not source.args then return @@ -1266,37 +1286,10 @@ local function getCallEnums(source, index) for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.param' and doc.param[1] == arg[1] then - local enums = {} - for _, enum in ipairs(vm.getDocEnums(doc.extends) or {}) do - enums[#enums+1] = { - label = enum[1], - description = enum.comment, - kind = define.CompletionItemKind.EnumMember, - } - end - for _, unit in ipairs(vm.getDocTypeUnits(doc.extends) or {}) do - if unit.type == 'doc.type.function' then - local text = files.getText(guide.getUri(unit)) - enums[#enums+1] = { - label = text:sub(unit.start, unit.finish), - description = doc.comment, - kind = define.CompletionItemKind.Function, - insertText = buildInsertDocFunction(unit), - } - end - end - return enums + return pushCallEnumsAndFuncs(vm.getDefs(doc.extends)) elseif doc.type == 'doc.vararg' and arg.type == '...' then - local enums = {} - for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do - enums[#enums+1] = { - label = enum[1], - description = enum.comment, - kind = define.CompletionItemKind.EnumMember, - } - end - return enums + return pushCallEnumsAndFuncs(vm.getDefs(doc.vararg)) end end end @@ -1403,12 +1396,12 @@ local function checkTableLiteralFieldByCall(ast, text, offset, call, defs, index return end for _, def in ipairs(defs) do - local func = guide.getObjectValue(def) or def + local func = searcher.getObjectValue(def) or def local param = getFuncParamByCallIndex(func, index) if not param then goto CONTINUE end - local defs = vm.getDefFields(param, 0) + local defs = vm.getDefs(param, '*') for _, field in ipairs(defs) do local name = guide.getKeyName(field) if name and not mark[name] then @@ -1431,10 +1424,10 @@ local function tryCallArg(ast, text, offset, results) if arg and arg.type == 'function' then return end - local defs = vm.getDefs(call.node, 0) + local defs = vm.getDefs(call.node) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def - local enums = getCallEnums(def, argIndex) + def = searcher.getObjectValue(def) or def + local enums = getCallEnumsAndFuncs(def, argIndex) if enums then mergeEnums(myResults, enums, arg) end @@ -1461,7 +1454,8 @@ local function tryTable(ast, text, offset, results) if source.type ~= 'table' then tbl = source.parent end - local defs = vm.getDefFields(tbl, 0) + local parent = tbl.parent + local defs = vm.getDefs(parent, '*') for _, field in ipairs(defs) do local name = guide.getKeyName(field) if name and not mark[name] then @@ -1560,7 +1554,7 @@ end local function tryLuaDocBySource(ast, offset, source, results) if source.type == 'doc.extends.name' then if source.parent.type == 'doc.class' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if doc.type == 'doc.class.name' and doc.parent ~= source.parent and matchKey(source[1], doc[1]) then @@ -1578,7 +1572,7 @@ local function tryLuaDocBySource(ast, offset, source, results) end return true elseif source.type == 'doc.type.name' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') and doc.parent ~= source.parent and matchKey(source[1], doc[1]) then @@ -1652,7 +1646,7 @@ end local function tryLuaDocByErr(ast, offset, err, docState, results) if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if doc.type == 'doc.class.name' and doc.parent ~= docState then results[#results+1] = { @@ -1662,7 +1656,7 @@ local function tryLuaDocByErr(ast, offset, err, docState, results) end end elseif err.type == 'LUADOC_MISS_TYPE_NAME' then - for _, doc in ipairs(vm.getDocTypes '*') do + for _, doc in ipairs(vm.getDocDefines()) do if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then results[#results+1] = { label = doc[1], @@ -1735,14 +1729,14 @@ local function buildLuaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = vm.getInferType(arg) + args[#args+1] = infer.searchAndViewInfers(arg) end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = vm.getInferType(rtns[n]) + returns[n] = infer.searchAndViewInfers(rtns[n]) end end end diff --git a/script/core/definition.lua b/script/core/definition.lua index b26bb922..3ced05a2 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -1,14 +1,15 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local workspace = require 'workspace' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' +local guide = require 'parser.guide' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -20,7 +21,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else @@ -127,11 +128,11 @@ return function (uri, offset) end end - local defs = vm.getDefs(source, 0) + local defs = vm.getDefs(source) local values = {} for _, src in ipairs(defs) do - local value = guide.getObjectValue(src) - if value and value ~= src then + local value = searcher.getObjectValue(src) + if value and value ~= src and guide.isLiteral(value) then values[value] = true end end @@ -148,9 +149,6 @@ return function (uri, offset) goto CONTINUE end src = src.field or src.method or src.index or src - if src.type == 'table' and src.parent.type ~= 'return' then - goto CONTINUE - end if src.type == 'doc.class.name' and source.type ~= 'doc.type.name' and source.type ~= 'doc.extends.name' diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua index 19bb4f97..37815fb5 100644 --- a/script/core/diagnostics/ambiguity-1.lua +++ b/script/core/diagnostics/ambiguity-1.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local opMap = { diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua index 702cd904..d2e26378 100644 --- a/script/core/diagnostics/circle-doc-class.lua +++ b/script/core/diagnostics/circle-doc-class.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' +local guide = require 'parser.guide' return function (uri, callback) local state = files.getAst(uri) @@ -40,7 +40,7 @@ return function (uri, callback) end if not mark[newName] then mark[newName] = true - local docs = vm.getDocTypes(newName) + local docs = vm.getDocDefines(newName) for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' then list[#list+1] = otherDoc.parent diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index d1983c42..7828efe9 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -1,7 +1,6 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' return function (uri, callback) local state = files.getAst(uri) diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua index f23755ea..f300a61a 100644 --- a/script/core/diagnostics/code-after-break.lua +++ b/script/core/diagnostics/code-after-break.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' return function (uri, callback) local state = files.getAst(uri) diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua index 65099af8..ee245781 100644 --- a/script/core/diagnostics/count-down-loop.lua +++ b/script/core/diagnostics/count-down-loop.lua @@ -1,6 +1,6 @@ -local files = require "files" -local guide = require "core.guide" -local lang = require 'language' +local files = require "files" +local guide = require "parser.guide" +local lang = require 'language' return function (uri, callback) local state = files.getAst(uri) diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua index 60d60946..a6f8a47e 100644 --- a/script/core/diagnostics/deprecated.lua +++ b/script/core/diagnostics/deprecated.lua @@ -1,10 +1,10 @@ -local files = require 'files' -local vm = require 'vm' -local lang = require 'language' -local guide = require 'core.guide' -local config = require 'config' -local define = require 'proto.define' -local await = require 'await' +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local guide = require 'parser.guide' +local config = require 'config' +local define = require 'proto.define' +local await = require 'await' return function (uri, callback) local ast = files.getAst(uri) @@ -20,7 +20,7 @@ return function (uri, callback) return end if src.type == 'getglobal' then - local key = guide.getKeyName(src) + local key = src[1] if not key then return end @@ -34,11 +34,11 @@ return function (uri, callback) await.delay() - if not vm.isDeprecated(src, 0) then + if not vm.isDeprecated(src, true) then return end - local defs = vm.getDefs(src, 0) + local defs = vm.getDefs(src) local validVersions for _, def in ipairs(defs) do if def.bindDocs then diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua index 8c6696a9..daecb836 100644 --- a/script/core/diagnostics/duplicate-doc-class.lua +++ b/script/core/diagnostics/duplicate-doc-class.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' +local guide = require 'parser.guide' return function (uri, callback) local state = files.getAst(uri) @@ -20,7 +20,7 @@ return function (uri, callback) or doc.type == 'doc.alias' then local name = guide.getKeyName(doc) if not cache[name] then - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) cache[name] = {} for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' @@ -28,7 +28,7 @@ return function (uri, callback) cache[name][#cache[name]+1] = { start = otherDoc.start, finish = otherDoc.finish, - uri = guide.getUri(otherDoc), + uri = searcher.getUri(otherDoc), } end end diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua index 5e63d39e..d1ba9261 100644 --- a/script/core/diagnostics/duplicate-index.lua +++ b/script/core/diagnostics/duplicate-index.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' return function (uri, callback) local ast = files.getAst(uri) diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua index c1e2285a..e1883fe5 100644 --- a/script/core/diagnostics/duplicate-set-field.lua +++ b/script/core/diagnostics/duplicate-set-field.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local define = require 'proto.define' +local guide = require "parser.guide" return function (uri, callback) local ast = files.getAst(uri) @@ -30,7 +30,7 @@ return function (uri, callback) if not name then goto CONTINUE end - local value = guide.getObjectValue(nxt) + local value = searcher.getObjectValue(nxt) if not value or value.type ~= 'function' then goto CONTINUE end diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua index 690a4ca2..2024f4e3 100644 --- a/script/core/diagnostics/empty-block.lua +++ b/script/core/diagnostics/empty-block.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index de23bc76..9a0d4f35 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' -- TODO: 检查路径是否可达 diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua index a2b831f7..1d1ab9af 100644 --- a/script/core/diagnostics/init.lua +++ b/script/core/diagnostics/init.lua @@ -62,14 +62,11 @@ local function check(uri, name, results) end return function (uri, response) - local vm = require 'vm' local ast = files.getAst(uri) if not ast then return nil end - local isOpen = files.isOpen(uri) - for _, name in ipairs(diagList) do await.delay() local results = {} diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua index 9c094701..8c7ae793 100644 --- a/script/core/diagnostics/lowercase-global.lua +++ b/script/core/diagnostics/lowercase-global.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local config = require 'config' local vm = require 'vm' diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua index 0727c2fd..75681cbc 100644 --- a/script/core/diagnostics/newfield-call.lua +++ b/script/core/diagnostics/newfield-call.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua index 807f76a2..159a60c9 100644 --- a/script/core/diagnostics/newline-call.lua +++ b/script/core/diagnostics/newline-call.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua index ffaab821..23af570a 100644 --- a/script/core/diagnostics/no-implicit-any.lua +++ b/script/core/diagnostics/no-implicit-any.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -10,7 +10,7 @@ return function (uri, callback) return end - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) if source.type ~= 'local' and source.type ~= 'setlocal' and source.type ~= 'setglobal' diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua index 857d80d2..48093417 100644 --- a/script/core/diagnostics/redefined-local.lua +++ b/script/core/diagnostics/redefined-local.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' return function (uri, callback) @@ -13,6 +13,9 @@ return function (uri, callback) or name == ast.ENVMode then return end + if source.tag == 'self' then + return + end local exist = guide.getLocal(source, name, source.start-1) if exist then callback { diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index c5bcd5a5..eca7fc91 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' local define = require 'proto.define' @@ -84,7 +84,7 @@ return function (uri, callback) local funcArgs = cache[func] if funcArgs == nil then funcArgs = getFuncArgs(func) or false - local refs = vm.getRefs(func, 0) + local refs = vm.getRefs(func) for _, ref in ipairs(refs) do cache[ref] = funcArgs end diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua index 0a4b1d57..e54a6e60 100644 --- a/script/core/diagnostics/trailing-space.lua +++ b/script/core/diagnostics/trailing-space.lua @@ -1,6 +1,6 @@ local files = require 'files' local lang = require 'language' -local guide = require 'core.guide' +local guide = require 'parser.guide' local function isInString(ast, offset) local result = false diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua index b2b2800c..35aebb45 100644 --- a/script/core/diagnostics/unbalanced-assignments.lua +++ b/script/core/diagnostics/unbalanced-assignments.lua @@ -1,7 +1,7 @@ local files = require 'files' local define = require 'proto.define' local lang = require 'language' -local guide = require 'core.guide' +local guide = require 'parser.guide' return function (uri, callback, code) local ast = files.getAst(uri) diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua index a91cfa7f..d79f7ea4 100644 --- a/script/core/diagnostics/undefined-doc-class.lua +++ b/script/core/diagnostics/undefined-doc-class.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -25,7 +25,7 @@ return function (uri, callback) end for _, ext in ipairs(doc.extends) do local name = ext[1] - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) if cache[name] == nil then cache[name] = false for _, otherDoc in ipairs(docs) do diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index d8a4363b..871f16e1 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -22,7 +22,7 @@ return function (uri, callback) if classCache[name] ~= nil then return classCache[name] end - local docs = vm.getDocTypes(name) + local docs = vm.getDocDefines(name) for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.class.name' or otherDoc.type == 'doc.alias.name' then diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua index 0bf371e5..4a97947d 100644 --- a/script/core/diagnostics/undefined-doc-param.lua +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua index 89efb8c7..c97c3fe8 100644 --- a/script/core/diagnostics/undefined-env-child.lua +++ b/script/core/diagnostics/undefined-env-child.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local vm = require 'vm' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' +local lang = require 'language' return function (uri, callback) local ast = files.getAst(uri) @@ -13,7 +13,7 @@ return function (uri, callback) if source.node.tag == '_ENV' then return end - local defs = guide.requestDefinition(source) + local defs = searcher.requestDefinition(source) if #defs > 0 then return end diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index b10c9ab0..2d357d5b 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -2,9 +2,16 @@ local files = require 'files' local vm = require 'vm' local lang = require 'language' local config = require 'config' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' +local SkipCheckClass = { + ['unknown'] = true, + ['any'] = true, + ['table'] = true, + ['nil'] = true, +} + return function (uri, callback) local ast = files.getAst(uri) if not ast then @@ -18,7 +25,7 @@ return function (uri, callback) if cache[src] == nil then tracy.ZoneBeginN('undefined-field getInfers') infers = vm.getInfers(src, 0) or false - local refs = vm.getRefs(src, 0) + local refs = vm.getRefs(src) for _, ref in ipairs(refs) do cache[ref] = infers end @@ -47,7 +54,7 @@ return function (uri, callback) elseif inferSource.type == 'doc.class.name' then addTo(allDocClass, inferSource.parent) elseif inferSource.type == 'doc.type.name' then - local docTypes = vm.getDocTypes(inferSource[1]) + local docTypes = vm.getDocDefines(inferSource[1]) for _, docType in ipairs(docTypes) do if docType.type == 'doc.class.name' then addTo(allDocClass, docType.parent) @@ -65,7 +72,7 @@ return function (uri, callback) local empty = true for _, docClass in ipairs(allDocClass) do tracy.ZoneBeginN('undefined-field getDefFields') - local refs = vm.getDefFields(docClass) + local refs = vm.getDefs(docClass, '*') tracy.ZoneEnd() for _, ref in ipairs(refs) do @@ -87,35 +94,37 @@ return function (uri, callback) end local function checkUndefinedField(src) - local fieldName = guide.getKeyName(src) - - local allDocClass = getAllDocClassFromInfer(src.node) - if (not allDocClass) or (#allDocClass == 0) then - return - end - - local fields = getAllFieldsFromAllDocClass(allDocClass) - - -- 没找到任何 field,跳过检查 - if not fields then + if #vm.getDefs(src) > 0 then return end - - if not fields[fieldName] then - local message = lang.script('DIAG_UNDEF_FIELD', fieldName) - if src.type == 'getfield' and src.field then - callback { - start = src.field.start, - finish = src.field.finish, - message = message, - } - elseif src.type == 'getmethod' and src.method then - callback { - start = src.method.start, - finish = src.method.finish, - message = message, - } + local node = src.node + if node then + local defs = vm.getDefs(node) + local ok + for _, def in ipairs(defs) do + if def.type == 'doc.class.name' + and not SkipCheckClass[def[1]] then + ok = true + break + end end + if not ok then + return + end + end + local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src)) + if src.type == 'getfield' and src.field then + callback { + start = src.field.start, + finish = src.field.finish, + message = message, + } + elseif src.type == 'getmethod' and src.method then + callback { + start = src.method.start, + finish = src.method.finish, + message = message, + } end end guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField); diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index 161d8856..825b14f1 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -1,9 +1,8 @@ -local files = require 'files' -local vm = require 'vm' -local lang = require 'language' -local config = require 'config' -local guide = require 'core.guide' -local define = require 'proto.define' +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local config = require 'config' +local guide = require 'parser.guide' local requireLike = { ['include'] = true, @@ -20,7 +19,7 @@ return function (uri, callback) -- 遍历全局变量,检查所有没有 set 模式的全局变量 guide.eachSourceType(ast.ast, 'getglobal', function (src) - local key = guide.getKeyName(src) + local key = src[1] if not key then return end @@ -30,7 +29,11 @@ return function (uri, callback) if config.config.runtime.special[key] then return end - if #vm.getGlobalSets(key) == 0 then + local node = src.node + if node.tag ~= '_ENV' then + return + end + if #vm.getDefs(src) == 0 then local message = lang.script('DIAG_UNDEF_GLOBAL', key) if requireLike[key:lower()] then message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key)) diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua index b6f92e60..41c239f9 100644 --- a/script/core/diagnostics/unused-function.lua +++ b/script/core/diagnostics/unused-function.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local vm = require 'vm' local define = require 'proto.define' local lang = require 'language' diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua index e2d5e49a..e6d998ba 100644 --- a/script/core/diagnostics/unused-label.lua +++ b/script/core/diagnostics/unused-label.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua index fde90cb8..1a77a45f 100644 --- a/script/core/diagnostics/unused-local.lua +++ b/script/core/diagnostics/unused-local.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' @@ -87,6 +87,9 @@ return function (uri, callback) or name == ast.ENVMode then return end + if source.tag == 'self' then + return + end if isToBeClosed(source) then return end diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua index ec0a05fb..74cc08e7 100644 --- a/script/core/diagnostics/unused-vararg.lua +++ b/script/core/diagnostics/unused-vararg.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua index cc87e3ca..e36ba29b 100644 --- a/script/core/document-symbol.lua +++ b/script/core/document-symbol.lua @@ -1,8 +1,8 @@ -local await = require 'await' -local files = require 'files' -local guide = require 'core.guide' -local define = require 'proto.define' -local util = require 'utility' +local await = require 'await' +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local util = require 'utility' local function buildName(source, text) if source.type == 'setmethod' diff --git a/script/core/find-source.lua b/script/core/find-source.lua index b36306b6..edbb1e2c 100644 --- a/script/core/find-source.lua +++ b/script/core/find-source.lua @@ -1,4 +1,4 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local function isValidFunctionPos(source, offset) for i = 1, #source.keyword // 2 do diff --git a/script/core/folding.lua b/script/core/folding.lua index 15678995..1bbae944 100644 --- a/script/core/folding.lua +++ b/script/core/folding.lua @@ -1,5 +1,5 @@ local files = require "files" -local guide = require "core.guide" +local searcher = require "core.searcher" local util = require 'utility' local Care = { @@ -153,7 +153,7 @@ return function (uri) local regions = {} local status = {} - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) local tp = source.type if Care[tp] then Care[tp](source, text, regions) diff --git a/script/core/generic.lua b/script/core/generic.lua new file mode 100644 index 00000000..15950974 --- /dev/null +++ b/script/core/generic.lua @@ -0,0 +1,234 @@ +local guide = require 'parser.guide' +local noder = require "core.noder" + +---@class generic.value +---@field type string +---@field closure generic.closure +---@field proto parser.guide.object +---@field parent parser.guide.object + +---@class generic.closure +---@field type string +---@field proto parser.guide.object +---@field upvalues table<string, generic.value[]> +---@field params generic.value[] +---@field returns generic.value[] + +local m = {} + +---@param closure generic.closure +---@param proto parser.guide.object +local function instantValue(closure, proto) + ---@type generic.value + local value = { + type = 'generic.value', + closure = closure, + proto = proto, + parent = proto.parent, + } + closure.values[#closure.values+1] = value + return value +end + +---递归实例化对象 +---@param proto parser.guide.object +---@return generic.value +local function createValue(closure, proto, callback, road) + if callback then + road = road or {} + end + if proto.type == 'doc.type' then + local types = {} + local hasGeneric + for i, tp in ipairs(proto.types) do + local genericValue = createValue(closure, tp, callback, road) + if genericValue then + hasGeneric = true + types[i] = genericValue + else + types[i] = tp + end + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.types = types + noder.compileNode(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.name' then + if not proto.typeGeneric then + return nil + end + local key = proto[1] + local value = instantValue(closure, proto) + if callback then + callback(road, key, proto) + end + noder.compileNode(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.function' then + local hasGeneric + local args = {} + local returns = {} + for i, arg in ipairs(proto.args) do + local value = createValue(closure, arg, callback, road) + if value then + hasGeneric = true + end + args[i] = value or arg + end + for i, rtn in ipairs(proto.returns) do + local value = createValue(closure, rtn, callback, road) + if value then + hasGeneric = true + end + returns[i] = value or rtn + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.args = args + value.returns = returns + value.isGeneric = true + noder.pushSource(noder.getNoders(proto), value) + return value + end + if proto.type == 'doc.type.array' then + if road then + road[#road+1] = noder.ANY_FIELD + end + local node = createValue(closure, proto.node, callback, road) + if road then + road[#road] = nil + end + if not node then + return nil + end + local value = instantValue(closure, proto) + value.node = node + return value + end + if proto.type == 'doc.type.table' then + road[#road+1] = noder.TABLE_KEY + local tkey = createValue(closure, proto.tkey, callback, road) + road[#road] = nil + + road[#road+1] = noder.ANY_FIELD + local tvalue = createValue(closure, proto.tvalue, callback, road) + road[#road] = nil + + if not tkey and not tvalue then + return nil + end + local value = instantValue(closure, proto) + value.tkey = tkey or proto.tkey + value.tvalue = tvalue or proto.tvalue + return value + end +end + +local function buildValue(road, key, proto, param, upvalues) + local paramID + if proto.literal then + local str = param.type == 'string' and param[1] + if not str then + return + end + paramID = 'dn:' .. str + else + paramID = noder.getID(param) + end + if not paramID then + return + end + local myUri = guide.getUri(param) + local myHead = noder.URI_CHAR .. myUri .. noder.URI_CHAR + paramID = myHead .. paramID + if not upvalues[key] then + upvalues[key] = {} + end + upvalues[key][#upvalues[key]+1] = paramID .. table.concat(road) +end + +-- 为所有的 param 与 return 创建副本 +---@param closure generic.closure +local function buildValues(closure) + local protoFunction = closure.proto + local upvalues = closure.upvalues + local params = closure.call.args + + if protoFunction.type == 'function' then + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.param' then + local extends = doc.extends + local index = extends.paramIndex + if index then + local param = params and params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + end + end + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + closure.returns[rtn.returnIndex] = createValue(closure, rtn) or rtn + end + end + end + end + if protoFunction.type == 'doc.type.function' then + for index, arg in ipairs(protoFunction.args) do + local extends = arg.extends + local param = params and params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + for index, rtn in ipairs(protoFunction.returns) do + closure.returns[index] = createValue(closure, rtn) or rtn + end + end +end + +---创建一个闭包 +---@param proto parser.guide.object|generic.value # 原型函数|泛型值 +---@return generic.closure +function m.createClosure(proto, call) + local protoFunction, parentClosure + if proto.type == 'function' then + protoFunction = proto + elseif proto.type == 'doc.type.function' then + protoFunction = proto + elseif proto.type == 'generic.value' then + protoFunction = proto.proto + parentClosure = proto.closure + end + ---@type generic.closure + local closure = { + type = 'generic.closure', + parent = protoFunction.parent, + proto = protoFunction, + call = call, + upvalues = parentClosure and parentClosure.upvalues or {}, + params = {}, + returns = {}, + values = {}, + } + buildValues(closure) + + if #closure.returns == 0 then + return nil + end + + noder.compileNode(noder.getNoders(proto), closure) + + return closure +end + +return m diff --git a/script/core/guide.lua b/script/core/guide2.lua index e4871060..576c0c20 100644 --- a/script/core/guide.lua +++ b/script/core/guide2.lua @@ -292,8 +292,13 @@ end ---@param obj parser.guide.object ---@return parser.guide.object function m.getRoot(obj) + local source = obj + if source._root then + return source._root + end for _ = 1, 1000 do if obj.type == 'main' then + source._root = obj return obj end local parent = obj.parent diff --git a/script/core/highlight.lua b/script/core/highlight.lua index 12ec114f..45001134 100644 --- a/script/core/highlight.lua +++ b/script/core/highlight.lua @@ -1,31 +1,18 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local files = require 'files' local vm = require 'vm' local define = require 'proto.define' local findSource = require 'core.find-source' local util = require 'utility' +local guide = require 'parser.guide' local function eachRef(source, callback) - local results = guide.requestReference(source) + local results = searcher.requestReference(source) for i = 1, #results do callback(results[i]) end end -local function eachField(source, callback) - if not source then - return - end - local isGlobal = guide.isGlobal(source) - local results = guide.requestReference(source) - for i = 1, #results do - local res = results[i] - if isGlobal == guide.isGlobal(res) then - callback(res) - end - end -end - local function eachLocal(source, callback) callback(source) if source.ref then @@ -43,21 +30,21 @@ local function find(source, uri, callback) eachLocal(source.node, callback) elseif source.type == 'field' or source.type == 'method' then - eachField(source.parent, callback) + eachRef(source.parent, callback) elseif source.type == 'getindex' or source.type == 'setindex' or source.type == 'tableindex' then - eachField(source, callback) + eachRef(source, callback) elseif source.type == 'setglobal' or source.type == 'getglobal' then - eachField(source, callback) + eachRef(source, callback) elseif source.type == 'goto' or source.type == 'label' then eachRef(source, callback) elseif source.type == 'string' and source.parent and source.parent.index == source then - eachField(source.parent, callback) + eachRef(source.parent, callback) elseif source.type == 'string' or source.type == 'boolean' or source.type == 'number' @@ -238,6 +225,16 @@ local accept = { ['nil'] = true, } +local function isLiteralValue(source) + if not guide.isLiteral(source) then + return false + end + if source.parent.index == source then + return false + end + return true +end + return function (uri, offset) local ast = files.getAst(uri) if not ast then @@ -249,10 +246,25 @@ return function (uri, offset) local source = findSource(ast, offset, accept) if source then + local isGlobal = guide.isGlobal(source) + local isLiteral = isLiteralValue(source) find(source, uri, function (target) + if not target then + return + end if target.dummy then return end + if mark[target] then + return + end + mark[target] = true + if isGlobal ~= guide.isGlobal(target) then + return + end + if isLiteral ~= isLiteralValue(target) then + return + end local kind if target.type == 'getfield' then target = target.field @@ -315,13 +327,6 @@ return function (uri, offset) else return end - if not target then - return - end - if mark[target] then - return - end - mark[target] = true results[#results+1] = { start = target.start, finish = target.finish, diff --git a/script/core/hint.lua b/script/core/hint.lua index 13d01dc7..43b8726e 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local vm = require 'vm' -local config = require 'config' +local files = require 'files' +local searcher = require 'core.searcher' +local vm = require 'vm' +local config = require 'config' local function typeHint(uri, edits, start, finish) local ast = files.getAst(uri) @@ -9,7 +9,7 @@ local function typeHint(uri, edits, start, finish) return end local mark = {} - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) if source.type ~= 'local' and source.type ~= 'setglobal' and source.type ~= 'tablefield' @@ -21,7 +21,7 @@ local function typeHint(uri, edits, start, finish) if source[1] == '_' then return end - if source.value and guide.isLiteral(source.value) then + if source.value and searcher.isLiteral(source.value) then return end if source.parent.type == 'funcargs' then @@ -84,7 +84,7 @@ local function hasLiteralArgInCall(call) return false end for _, arg in ipairs(call.args) do - if guide.isLiteral(arg) then + if searcher.isLiteral(arg) then return true end end @@ -100,14 +100,14 @@ local function paramName(uri, edits, start, finish) return end local mark = {} - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) if source.type ~= 'call' then return end if not hasLiteralArgInCall(source) then return end - local defs = vm.getDefs(source.node, 0) + local defs = vm.getDefs(source.node) if not defs then return end @@ -130,7 +130,7 @@ local function paramName(uri, edits, start, finish) table.remove(args, 1) end for i, arg in ipairs(source.args) do - if not mark[arg] and guide.isLiteral(arg) then + if not mark[arg] and searcher.isLiteral(arg) then mark[arg] = true if args[i] and args[i] ~= '' then edits[#edits+1] = { diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua index 324d28af..822be2b6 100644 --- a/script/core/hover/arg.lua +++ b/script/core/hover/arg.lua @@ -1,4 +1,5 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' +local infer = require 'core.infer' local vm = require 'vm' local function optionalArg(arg) @@ -21,7 +22,7 @@ local function asFunction(source, oop) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(vm.getInferType(parent.node)) + args[#args+1] = ('self: %s'):format(infer.searchAndViewInfers(parent.node)) end if source.args then for i = 1, #source.args do @@ -34,10 +35,12 @@ local function asFunction(source, oop) args[#args+1] = ('%s%s: %s'):format( name, optionalArg(arg) and '?' or '', - vm.getInferType(arg) + infer.searchAndViewInfers(arg) ) + elseif arg.type == '...' then + args[#args+1] = '...' else - args[#args+1] = ('%s'):format(vm.getInferType(arg)) + args[#args+1] = ('%s'):format(infer.searchAndViewInfers(arg)) end ::CONTINUE:: end @@ -61,7 +64,7 @@ local function asDocFunction(source) args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - vm.getInferType(arg.extends) + infer.searchAndViewInfers(arg.extends) ) else args[i] = ('%s%s'):format( diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index 401ca5a7..bcc3065a 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -2,11 +2,13 @@ local vm = require 'vm' local ws = require 'workspace' local furi = require 'file-uri' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local markdown = require 'provider.markdown' local config = require 'config' local lang = require 'language' local util = require 'utility' +local guide = require 'parser.guide' +local noder = require 'core.noder' local function asStringInRequire(source, literal) local rootPath = ws.path or '' @@ -124,10 +126,10 @@ local function getBindComment(source, docGroup, base) end local function tryDocClassComment(source) - for _, def in ipairs(vm.getDefs(source, 0)) do + for _, def in ipairs(vm.getDefs(source)) do if def.type == 'doc.class.name' or def.type == 'doc.alias.name' then - local class = guide.getDocState(def) + local class = noder.getDocState(def) local comment = getBindComment(class, class.bindGroup, class) if comment then return comment @@ -180,7 +182,7 @@ local function isFunction(source) if source.type == 'function' then return true end - local value = guide.getObjectValue(source) + local value = searcher.getObjectValue(source) if not value then return false end @@ -223,13 +225,14 @@ local function getBindEnums(source, docGroup) end local function tryDocFieldUpComment(source) - if source.type ~= 'doc.field' then + if source.type ~= 'doc.field.name' then return end - if not source.bindGroup then + local docField = source.parent + if not docField.bindGroup then return end - local comment = getBindComment(source, source.bindGroup, source) + local comment = getBindComment(docField, docField.bindGroup, docField) return comment end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index 0c8644ed..5dd00c43 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local getLabel = require 'core.hover.label' local getDesc = require 'core.hover.description' @@ -7,6 +7,7 @@ local util = require 'utility' local findSource = require 'core.find-source' local lang = require 'language' local markdown = require 'provider.markdown' +local infer = require 'core.infer' local function eachFunctionAndOverload(value, callback) callback(value) @@ -24,7 +25,7 @@ local function getHoverAsValue(source) local label = getLabel(source) local desc = getDesc(source) if not desc then - local values = vm.getDefs(source, 0) + local values = vm.getDefs(source) for _, def in ipairs(values) do desc = getDesc(def) if desc then @@ -40,7 +41,7 @@ local function getHoverAsValue(source) end local function getHoverAsFunction(source) - local values = vm.getDefs(source, 0) + local values = vm.getDefs(source) local desc = getDesc(source) local labels = {} local defs = 0 @@ -48,7 +49,7 @@ local function getHoverAsFunction(source) local other = 0 local mark = {} for _, def in ipairs(values) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def.type == 'function' or def.type == 'doc.type.function' then eachFunctionAndOverload(def, function (value) @@ -123,7 +124,7 @@ local function getHover(source) if source.type == 'doc.type.name' then return getHoverAsDocName(source) end - local isFunction = vm.hasInferType(source, 'function', 0) + local isFunction = infer.hasType(source, 'function') if isFunction then return getHoverAsFunction(source) else diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index d93b14e3..032f19c0 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -2,9 +2,10 @@ local buildName = require 'core.hover.name' local buildArg = require 'core.hover.arg' local buildReturn = require 'core.hover.return' local buildTable = require 'core.hover.table' +local infer = require 'core.infer' local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local config = require 'config' local files = require 'files' @@ -31,29 +32,28 @@ local function asDocFunction(source) end local function asDocTypeName(source) - for _, doc in ipairs(vm.getDocTypes(source[1])) do + local defs = searcher.requestDefinition(source) + for _, doc in ipairs(defs) do if doc.type == 'doc.class.name' then - return 'class ' .. source[1] + return 'class ' .. doc[1] end if doc.type == 'doc.alias.name' then local extends = doc.parent.extends - return lang.script('HOVER_EXTENDS', vm.getInferType(extends)) + return lang.script('HOVER_EXTENDS', infer.searchAndViewInfers(extends)) end end end local function asValue(source, title) local name = buildName(source) - local infers = vm.getInfers(source, 0) - local type = vm.getInferType(source, 0) - local class = vm.getClass(source, 0) - local literal = vm.getInferLiteral(source, 0) + local type = infer.searchAndViewInfers(source) + local literal = infer.searchAndViewLiterals(source) local cont - if not vm.hasInferType(source, 'string', 0) + if not infer.hasType(source, 'string') and not type:find('%[%]$') and not type:find('%w%<') then - if #vm.getFields(source, 0) > 0 - or vm.hasInferType(source, 'table', 0) then + if #vm.getRefs(source, '*') > 0 + or infer.hasType(source, 'table') then cont = buildTable(source) end end @@ -66,11 +66,7 @@ local function asValue(source, title) or type == 'nil') then type = nil end - if class then - pack[#pack+1] = class - else - pack[#pack+1] = type - end + pack[#pack+1] = type if literal then pack[#pack+1] = '=' pack[#pack+1] = literal @@ -123,30 +119,21 @@ local function asField(source) return asValue(source, 'field') end -local function asDocField(source) - local name = source.field[1] +local function asDocFieldName(source) + local name = source[1] + local docField = source.parent local class - for _, doc in ipairs(source.bindGroup) do + for _, doc in ipairs(docField.bindGroup) do if doc.type == 'doc.class' then class = doc break end end - local infers = {} - for _, infer in ipairs(vm.getInfers(source.extends) or {}) do - infers[#infers+1] = infer - end + local view = infer.searchAndViewInfers(docField.extends) if not class then - return ('field ?.%s: %s'):format( - name, - guide.viewInferType(infers) - ) - end - return ('field %s.%s: %s'):format( - class.class[1], - name, - guide.viewInferType(infers) - ) + return ('field ?.%s: %s'):format(name, view) + end + return ('field %s.%s: %s'):format(class.class[1], name, view) end local function asString(source) @@ -177,7 +164,7 @@ local function asNumber(source) if type(num) ~= 'number' then return nil end - local uri = guide.getUri(source) + local uri = searcher.getUri(source) local text = files.getText(uri) if not text then return nil @@ -215,7 +202,7 @@ return function (source, oop) return asDocFunction(source) elseif source.type == 'doc.type.name' then return asDocTypeName(source) - elseif source.type == 'doc.field' then - return asDocField(source) + elseif source.type == 'doc.field.name' then + return asDocFieldName(source) end end diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua index d583f1e1..d2b9d30b 100644 --- a/script/core/hover/name.lua +++ b/script/core/hover/name.lua @@ -1,4 +1,6 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' +local infer = require 'core.infer' +local guide = require 'parser.guide' local vm = require 'vm' local buildName @@ -19,7 +21,7 @@ end local function asField(source, oop) local class if source.node.type ~= 'getglobal' then - class = vm.getClass(source.node, 0) + class = infer.getClass(source.node) end local node = class or guide.getKeyName(source.node) or '?' local method = guide.getKeyName(source) diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index c3e9656d..0f0d85e0 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,12 +1,4 @@ -local guide = require 'core.guide' -local vm = require 'vm' - -local function mergeTypes(returns) - if type(returns) == 'string' then - return returns - end - return guide.mergeTypes(returns) -end +local infer = require 'core.infer' local function getReturnDualByDoc(source) local docs = source.bindDocs @@ -55,24 +47,20 @@ local function asFunction(source) local returns = {} for i, rtn in ipairs(dual) do local line = {} - local types = {} + local infers = {} if i == 1 then line[#line+1] = ' -> ' else line[#line+1] = ('% 3d. '):format(i) end for n = 1, #rtn do - local values = vm.getInfers(rtn[n]) - for _, value in ipairs(values) do - if value.type then - for tp in value.type:gmatch '[^|]+' do - types[tp] = true - end - end + local values = infer.searchInfers(rtn[n]) + for tp in pairs(values) do + infers[tp] = true end end - if next(types) or rtn[1] then - local tp = mergeTypes(types) or 'any' + if next(infers) or rtn[1] then + local tp = infer.viewInfers(infers) if rtn[1].name then line[#line+1] = ('%s%s: %s'):format( rtn[1].name[1], @@ -103,7 +91,7 @@ local function asDocFunction(source) local returns = {} for i, rtn in ipairs(source.returns) do local rtnText = ('%s%s'):format( - vm.getInferType(rtn), + infer.searchAndViewInfers(rtn), rtn.optional and '?' or '' ) if i == 1 then diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index edb7751b..159453e6 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -1,26 +1,12 @@ local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local config = require 'config' local lang = require 'language' +local infer = require 'core.infer' -local function getKey(src) - local key = vm.getKeyName(src) - if not key or #key <= 0 then - if not src.index then - return '[any]' - end - local class = vm.getClass(src.index) - if class then - return ('[%s]'):format(class) - end - local tp = vm.getInferType(src.index) - if tp then - return ('[%s]'):format(tp) - end - return '[any]' - end - if guide.getKeyType(src) == 'string' then +local function formatKey(key) + if type(key) == 'string' then if key:match '^[%a_][%w_]*$' then return key else @@ -30,104 +16,16 @@ local function getKey(src) return ('[%s]'):format(key) end -local function getFieldFull(src) - local value = guide.getObjectValue(src) or src - local tp = vm.getInferType(value, 0) - --local class = vm.getClass(src) - local literal = vm.getInferLiteral(value) - if type(literal) == 'string' and #literal >= 50 then - literal = literal:sub(1, 47) .. '...' - end - return tp, literal -end - -local function getFieldFast(src) - if src.bindDocs then - return getFieldFull(src) - end - local value = guide.getObjectValue(src) or src - if not value then - return 'any' - end - if value.type == 'boolean' then - return value.type, util.viewLiteral(value[1]) - end - if value.type == 'number' - or value.type == 'integer' then - if math.tointeger(value[1]) then - if config.config.runtime.version == 'Lua 5.3' - or config.config.runtime.version == 'Lua 5.4' then - return 'integer', util.viewLiteral(value[1]) - end - end - return value.type, util.viewLiteral(value[1]) - end - if value.type == 'table' - or value.type == 'function' then - return value.type - end - if value.type == 'string' then - local literal = value[1] - if type(literal) == 'string' and #literal >= 50 then - literal = literal:sub(1, 47) .. '...' - end - return value.type, util.viewLiteral(literal) - end - if value.type == 'doc.field' then - return vm.getInferType(value) - end -end - -local function getField(src, timeUp, mark, key) - if src.type == 'table' - or src.type == 'function' then - return nil - end - if src.parent then - if src.type == 'string' - or src.type == 'boolean' - or src.type == 'number' - or src.type == 'integer' then - if src.parent.type == 'tableindex' - or src.parent.type == 'setindex' - or src.parent.type == 'getindex' then - if src.parent.index == src then - src = src.parent - end - end - end - end - local tp, literal - tp, literal = getFieldFast(src) - if tp then - return tp, literal - end - if timeUp or mark[key] then - return nil - end - mark[key] = true - tp, literal = getFieldFull(src) - if tp then - return tp, literal - end - return nil -end - -local function buildAsHash(classes, literals, reachMax) - local keys = {} - for k in pairs(classes) do - keys[#keys+1] = k - end - table.sort(keys) +local function buildAsHash(keys, inferMap, literalMap, reachMax) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local class = classes[key] - local literal = literals[key] - if literal then - lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + local inferView = inferMap[key] + local literalView = literalMap[key] + if literalView then + lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView) else - lines[#lines+1] = (' %s: %s,'):format(key, class) + lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView) end end if reachMax then @@ -137,23 +35,19 @@ local function buildAsHash(classes, literals, reachMax) return table.concat(lines, '\n') end -local function buildAsConst(classes, literals, reachMax) - local keys = {} - for k in pairs(classes) do - keys[#keys+1] = k - end +local function buildAsConst(keys, inferMap, literalMap, reachMax) table.sort(keys, function (a, b) - return tonumber(literals[a]) < tonumber(literals[b]) + return tonumber(literalMap[a]) < tonumber(literalMap[b]) end) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local class = classes[key] - local literal = literals[key] - if literal then - lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + local inferView = inferMap[key] + local literalView = literalMap[key] + if literalView then + lines[#lines+1] = (' %s: %s = %s,'):format(formatKey(key), inferView, literalView) else - lines[#lines+1] = (' %s: %s,'):format(key, class) + lines[#lines+1] = (' %s: %s,'):format(formatKey(key), inferView) end end if reachMax then @@ -163,111 +57,79 @@ local function buildAsConst(classes, literals, reachMax) return table.concat(lines, '\n') end -local function mergeLiteral(literals) - local results = {} +local typeSorter = { + ['string'] = 1, + ['number'] = 2, + ['boolean'] = 3, +} + +local function getKeyMap(fields) + local keys = {} local mark = {} - for _, value in ipairs(literals) do - if not mark[value] then - mark[value] = true - results[#results+1] = value + for _, field in ipairs(fields) do + local key = vm.getKeyName(field) + local tp = vm.getKeyType(field) + if tp == 'number' then + key = tonumber(key) + elseif tp == 'boolean' then + key = key == 'true' end - end - if #results == 0 then - return nil - end - table.sort(results) - return table.concat(results, '|') -end - -local function mergeTypes(types) - local results = {} - local mark = { - -- 讲道理table的keyvalue不会是nil - ['nil'] = true, - } - for _, tv in ipairs(types) do - for tp in tv:gmatch '[^|]+' do - if not mark[tp] then - mark[tp] = true - results[tp] = true - end + if key and not mark[key] then + mark[key] = true + keys[#keys+1] = key end end - return guide.mergeTypes(results) -end - -local function clearClasses(classes) - classes['[nil]'] = nil - classes['[any]'] = nil - classes['[string]'] = nil + table.sort(keys, function (a, b) + local ta = typeSorter[type(a)] + local tb = typeSorter[type(b)] + if ta == tb then + return tostring(a) < tostring(b) + else + return ta < tb + end + end) + return keys end return function (source) - if config.config.hover.previewFields <= 0 then + local maxFields = config.config.hover.previewFields + if maxFields <= 0 then return 'table' end - local literals = {} - local classes = {} - local clock = os.clock() - local timeUp - local mark = {} - local fields = vm.getFields(source, 0) - local keyCount = 0 - local reachMax - for _, src in ipairs(fields) do - local key = getKey(src) - if not key then - goto CONTINUE - end - if not classes[key] then - classes[key] = {} - keyCount = keyCount + 1 - end - if not literals[key] then - literals[key] = {} - end - if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then - timeUp = true - end - local class, literal = getField(src, timeUp, mark, key) - if literal == 'nil' then - literal = nil - end - classes[key][#classes[key]+1] = class - literals[key][#literals[key]+1] = literal - if keyCount >= config.config.hover.previewFields then - reachMax = true - break - end - ::CONTINUE:: - end - - clearClasses(classes) - for key, class in pairs(classes) do - literals[key] = mergeLiteral(literals[key]) - classes[key] = mergeTypes(class) - end + local fields = vm.getRefs(source, '*') + local keys = getKeyMap(fields) - if not next(classes) then + if #keys == 0 then return '{}' end - local intValue = true - for key, class in pairs(classes) do - if class ~= 'integer' or not tonumber(literals[key]) then - intValue = false - break + local inferMap = {} + local literalMap = {} + + local reachMax = maxFields < #keys + + local isConsts = true + for i = 1, math.min(maxFields, #keys) do + local key = keys[i] + inferMap[key] = infer.searchAndViewInfers(source, key) + literalMap[key] = infer.searchAndViewLiterals(source, key) + if not tonumber(literalMap[key]) then + isConsts = false end end + local result - if intValue then - result = buildAsConst(classes, literals, reachMax) + + if isConsts then + result = buildAsConst(keys, inferMap, literalMap, reachMax) else - result = buildAsHash(classes, literals, reachMax) - end - if timeUp then - result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + result = buildAsHash(keys, inferMap, literalMap, reachMax) end + + --if timeUp then + -- result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + --end + return result end diff --git a/script/core/infer.lua b/script/core/infer.lua new file mode 100644 index 00000000..77236811 --- /dev/null +++ b/script/core/infer.lua @@ -0,0 +1,634 @@ +local searcher = require 'core.searcher' +local config = require 'config' +local noder = require 'core.noder' +local util = require 'utility' + +local STRING_OR_TABLE = {'STRING_OR_TABLE'} +local BE_RETURN = {'BE_RETURN'} +local BE_CONNACT = {'BE_CONNACT'} +local CLASS = {'CLASS'} +local TABLE = {'TABLE'} + +local TypeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['true'] = 101, + ['false'] = 102, + ['nil'] = 999, +} + +local m = {} + +local function mergeTable(a, b) + if not b then + return + end + for v in pairs(b) do + a[v] = true + end +end + +local function searchInferOfUnary(value, infers) + local op = value.op.type + if op == 'not' then + infers['boolean'] = true + return + end + if op == '#' then + infers['integer'] = true + return + end + if op == '-' then + if m.hasType(value[1], 'integer') then + infers['integer'] = true + else + infers['number'] = true + end + return + end + if op == '~' then + infers['integer'] = true + return + end +end + +local function searchInferOfBinary(value, infers) + local op = value.op.type + if op == 'and' then + if m.isTrue(value[1]) then + mergeTable(infers, m.searchInfers(value[2])) + else + mergeTable(infers, m.searchInfers(value[1])) + end + return + end + if op == 'or' then + if m.isTrue(value[1]) then + mergeTable(infers, m.searchInfers(value[1])) + else + mergeTable(infers, m.searchInfers(value[2])) + end + return + end + if op == '==' + or op == '~=' + or op == '<' + or op == '>' + or op == '<=' + or op == '>=' then + infers['boolean'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '&' + or op == '|' then + infers['integer'] = true + return + end + if op == '..' then + infers['string'] = true + return + end + if op == '^' + or op == '/' then + infers['number'] = true + return + end + if op == '+' + or op == '-' + or op == '*' + or op == '%' + or op == '//' then + if m.hasType(value[1], 'integer') + and m.hasType(value[2], 'integer') then + infers['integer'] = true + else + infers['number'] = true + end + return + end +end + +local function searchInferOfValue(value, infers) + if value.type == 'string' then + infers['string'] = true + return true + end + if value.type == 'boolean' then + infers['boolean'] = true + return true + end + if value.type == 'table' then + if value.array then + local node = m.searchAndViewInfers(value.array) + local infer = node .. '[]' + infers[infer] = true + else + infers['table'] = true + end + return true + end + if value.type == 'number' then + if math.type(value[1]) == 'integer' then + infers['integer'] = true + else + infers['number'] = true + end + return true + end + if value.type == 'nil' then + infers['nil'] = true + return true + end + if value.type == 'function' then + infers['function'] = true + return true + end + if value.type == 'unary' then + searchInferOfUnary(value, infers) + return true + end + if value.type == 'binary' then + searchInferOfBinary(value, infers) + return true + end + return false +end + +local function searchLiteralOfValue(value, literals) + if value.type == 'string' + or value.type == 'boolean' + or value.type == 'number' + or value.type == 'integer' then + local v = value[1] + if v ~= nil then + literals[v] = true + end + return + end + if value.type == 'unary' then + local op = value.op.type + if op == '-' then + local subLiterals = m.searchLiterals(value[1]) + if subLiterals then + for subLiteral in pairs(subLiterals) do + local num = tonumber(subLiteral) + if num then + literals[-num] = true + end + end + end + end + if op == '~' then + local subLiterals = m.searchLiterals(value[1]) + if subLiterals then + for subLiteral in pairs(subLiterals) do + local num = math.tointeger(subLiteral) + if num then + literals[~num] = true + end + end + end + end + end + return +end + +local function bindClassOrType(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' + or doc.type == 'doc.type' then + return true + end + end + return false +end + +local function cleanInfers(infers) + local version = config.config.runtime.version + local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4' + infers['unknown'] = nil + if infers['any'] and infers['nil'] then + infers['nil'] = nil + end + if infers['number'] then + enableInteger = false + end + if not enableInteger and infers['integer'] then + infers['integer'] = nil + infers['number'] = true + end + -- stringlib 就是 string + if infers['stringlib'] and infers['string'] then + infers['stringlib'] = nil + end + -- 如果是通过 .. 来推测的,且结果里没有 number 与 integer,则推测为string + if infers[BE_CONNACT] then + infers[BE_CONNACT] = nil + if not infers['number'] and not infers['integer'] then + infers['string'] = true + end + end + -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型 + if infers[STRING_OR_TABLE] then + infers[STRING_OR_TABLE] = nil + if not infers['table'] and not infers['string'] then + infers['table'] = true + infers['string'] = true + end + end + -- 如果有doc标记,则先移除table类型 + if infers[CLASS] then + infers[CLASS] = nil + infers['table'] = nil + end + -- 用doc标记的table,加入table类型 + if infers[TABLE] then + infers[TABLE] = nil + infers['table'] = true + end + if infers[BE_RETURN] then + infers[BE_RETURN] = nil + infers['nil'] = nil + end + infers['any'] = nil +end + +---合并对象的推断类型 +---@param infers string[] +---@return string +function m.viewInfers(infers) + if infers[0] then + return infers[0] + end + -- 如果有显性的 any ,则直接显示为 any + if infers['any'] then + infers[0] = 'any' + return 'any' + end + local result = {} + local count = 0 + for infer in pairs(infers) do + count = count + 1 + result[count] = infer + end + -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any + if count == 0 then + infers[0] = 'any' + return 'any' + end + table.sort(result, function (a, b) + local sa = TypeSort[a] or 100 + local sb = TypeSort[b] or 100 + if sa == sb then + return a < b + else + return sa < sb + end + end) + infers[0] = table.concat(result, '|') + return infers[0] +end + +---合并对象的值 +---@param literals string[] +---@return string +function m.viewLiterals(literals) + local result = {} + local count = 0 + for infer in pairs(literals) do + count = count + 1 + result[count] = util.viewLiteral(infer) + end + if count == 0 then + return nil + end + table.sort(result) + local view = table.concat(result, '|') + return view +end + +function m.viewDocName(doc) + if not doc then + return nil + end + if doc.type == 'doc.type' then + local list = {} + for _, tp in ipairs(doc.types) do + list[#list+1] = m.getDocName(tp) + end + for _, enum in ipairs(doc.enums) do + list[#list+1] = m.getDocName(enum) + end + return table.concat(list, '|') + end + return m.getDocName(doc) +end + +function m.getDocName(doc) + if not doc then + return nil + end + if doc.type == 'doc.class.name' + or doc.type == 'doc.type.name' then + local name = doc[1] or '?' + if doc.typeGeneric then + return '<' .. name .. '>' + else + return name + end + end + if doc.type == 'doc.type.array' then + local nodeName = m.viewDocName(doc.node) or '?' + return nodeName .. '[]' + end + if doc.type == 'doc.type.table' then + local key = m.viewDocName(doc.tkey) or '?' + local value = m.viewDocName(doc.tvalue) or '?' + return ('table<%s, %s>'):format(key, value) + end + if doc.type == 'doc.type.function' then + return 'function' + end + if doc.type == 'doc.type.enum' + or doc.type == 'doc.resume' then + local value = doc[1] or '?' + return value + end +end + +function m.viewDocFunction(doc) + if doc.type ~= 'doc.type.function' then + return '' + end + local args = {} + for i, arg in ipairs(doc.args) do + args[i] = ('%s: %s'):format(arg.name[1], m.viewDocName(arg.extends)) + end + local label = ('fun(%s)'):format(table.concat(args, ', ')) + if #doc.returns > 0 then + local returns = {} + for i, rtn in ipairs(doc.returns) do + returns[i] = m.viewDocName(rtn) + end + label = ('%s:%s'):format(label, table.concat(returns)) + end + return label +end + +---显示对象的推断类型 +---@param source parser.guide.object +---@return string +local function searchInfer(source, infers) + if bindClassOrType(source) then + return + end + if searchInferOfValue(source, infers) then + return + end + local value = searcher.getObjectValue(source) + if value then + searchInferOfValue(value, infers) + return + end + -- check LuaDoc + local docName = m.getDocName(source) + if docName then + infers[docName] = true + if docName ~= 'unknown' then + infers[CLASS] = true + end + if docName == 'table' then + infers[TABLE] = true + end + end + if source.parent.type == 'unary' then + local op = source.parent.op.type + -- # XX -> string | table + if op == '#' then + infers[STRING_OR_TABLE] = true + return + end + if op == '-' then + infers['number'] = true + return + end + if op == '~' then + infers['integer'] = true + return + end + return + end + if source.parent.type == 'binary' then + local op = source.parent.op.type + if op == '+' + or op == '-' + or op == '*' + or op == '/' + or op == '//' + or op == '^' + or op == '%' then + infers['number'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '|' + or op == '&' then + infers['integer'] = true + return + end + if op == '..' then + infers[BE_CONNACT] = true + return + end + end + -- X.a -> table + if source.next and source.next.node == source then + if source.next.type == 'setfield' + or source.next.type == 'setindex' + or source.next.type == 'setmethod' + or source.next.type == 'getfield' + or source.next.type == 'getindex' then + infers['table'] = true + end + if source.next.type == 'getmethod' then + infers[STRING_OR_TABLE] = true + end + end + -- return XX + if source.parent.type == 'return' then + infers[BE_RETURN] = true + end +end + +local function searchLiteral(source, literals) + local value = searcher.getObjectValue(source) + if value then + searchLiteralOfValue(value, literals) + return + end +end + +---搜索对象的推断类型 +---@param source parser.guide.object +---@param field? string +---@return string[] +function m.searchInfers(source, field) + if not source then + return nil + end + local defs = searcher.requestDefinition(source, field) + local infers = {} + local mark = {} + if not field then + mark[source] = true + searchInfer(source, infers) + local id = noder.getID(source) + if id then + local node = noder.getNodeByID(source, id) + if node and node.sources then + for _, src in ipairs(node.sources) do + if not mark[src] then + mark[src] = true + searchInfer(src, infers) + end + end + end + end + end + if source.type == 'field' or source.type == 'method' then + mark[source.parent] = true + searchInfer(source.parent, infers) + end + for _, def in ipairs(defs) do + if not mark[def] then + mark[def] = true + searchInfer(def, infers) + end + end + if source.docParam then + local docType = source.docParam.extends + if docType.type == 'doc.type' then + for _, def in ipairs(docType.types) do + if def.typeGeneric and not mark[def] then + mark[def] = true + searchInfer(def, infers) + end + end + end + end + if source.type == 'doc.type' then + if source.type == 'doc.type' then + for _, def in ipairs(source.types) do + if def.typeGeneric and not mark[def] then + mark[def] = true + searchInfer(def, infers) + end + end + end + end + cleanInfers(infers) + return infers +end + +---搜索对象的字面量值 +---@param source parser.guide.object +---@param field? string +---@return table +function m.searchLiterals(source, field) + local defs = searcher.requestDefinition(source, field) + local literals = {} + local mark = {} + if not field then + mark[source] = true + searchLiteral(source, literals) + end + for _, def in ipairs(defs) do + if not mark[def] then + mark[def] = true + searchLiteral(def, literals) + end + end + return literals +end + +---搜索并显示推断值 +---@param source parser.guide.object +---@param field? string +---@return string +function m.searchAndViewLiterals(source, field) + if not source then + return nil + end + local literals = m.searchLiterals(source, field) + local view = m.viewLiterals(literals) + return view +end + +---判断对象的推断值是否是 true +---@param source parser.guide.object +function m.isTrue(source) + if not source then + return false + end + local literals = m.searchLiterals(source) + for literal in pairs(literals) do + if literal ~= false then + return true + end + end + return false +end + +---判断对象的推断类型是否包含某个类型 +function m.hasType(source, tp) + local infers = m.searchInfers(source) + return infers[tp] or false +end + +---搜索并显示推断类型 +---@param source parser.guide.object +---@param field? string +---@return string +function m.searchAndViewInfers(source, field) + if not source then + return 'any' + end + local infers = m.searchInfers(source, field) + local view = m.viewInfers(infers) + return view +end + +---搜索并显示推断的class +---@param source parser.guide.object +---@return string? +function m.getClass(source) + if not source then + return nil + end + local infers = {} + local defs = searcher.requestDefinition(source) + for _, def in ipairs(defs) do + if def.type == 'doc.class.name' then + infers[def[1]] = true + end + end + local view = m.viewInfers(infers) + if view == 'any' then + return nil + end + return view +end + +return m diff --git a/script/core/keyword.lua b/script/core/keyword.lua index 71ea4969..73892f18 100644 --- a/script/core/keyword.lua +++ b/script/core/keyword.lua @@ -1,6 +1,6 @@ local define = require 'proto.define' -local guide = require 'core.guide' local files = require 'files' +local guide = require 'parser.guide' local keyWordMap = { {'do', function (info, results) diff --git a/script/core/noder.lua b/script/core/noder.lua new file mode 100644 index 00000000..c3679612 --- /dev/null +++ b/script/core/noder.lua @@ -0,0 +1,926 @@ +local util = require 'utility' +local guide = require 'parser.guide' + +local LastIDCache = {} +local FirstIDCache = {} +local SPLIT_CHAR = '\x1F' +local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$' +local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*' +local ANY_FIELD_CHAR = '*' +local RETURN_INDEX = SPLIT_CHAR .. '#' +local PARAM_INDEX = SPLIT_CHAR .. '&' +local TABLE_KEY = SPLIT_CHAR .. '<' +local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR +local URI_CHAR = '@' +local URI_REGEX = URI_CHAR .. '([^' .. URI_CHAR .. ']*)' .. URI_CHAR .. '(.*)' + +---@class node +-- 当前节点的id +---@field id string +-- 使用该ID的单元 +---@field sources parser.guide.object[] +-- 前进的关联ID +---@field forward string[] +-- 后退的关联ID +---@field backward string[] +-- 函数调用参数信息(用于泛型) +---@field call parser.guide.object + +---@alias noders table<string, node[]> + +---创建source的链接信息 +---@param noders noders +---@param id string +---@return node +local function getNode(noders, id) + if not noders[id] then + noders[id] = { + id = id, + } + end + return noders[id] +end + +---获取语法树单元的key +---@param source parser.guide.object +---@return string? key +---@return parser.guide.object? node +local function getKey(source) + if source.type == 'local' then + return tostring(source.start), nil + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return tostring(source.node.start), nil + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + local node = source.node + if node.tag == '_ENV' then + return ('%q'):format(source[1] or ''), nil + else + return ('%q'):format(source[1] or ''), node + end + elseif source.type == 'getfield' + or source.type == 'setfield' then + return ('%q'):format(source.field and source.field[1] or ''), source.node + elseif source.type == 'tablefield' then + return ('%q'):format(source.field and source.field[1] or ''), source.parent + elseif source.type == 'getmethod' + or source.type == 'setmethod' then + return ('%q'):format(source.method and source.method[1] or ''), source.node + elseif source.type == 'setindex' + or source.type == 'getindex' then + local index = source.index + if not index then + return ANY_FIELD_CHAR, source.node + end + if index.type == 'string' + or index.type == 'boolean' + or index.type == 'number' then + return ('%q'):format(index[1] or ''), source.node + elseif index.type ~= 'function' + and index.type ~= 'table' then + return ANY_FIELD_CHAR, source.node + end + elseif source.type == 'tableindex' then + local index = source.index + if not index then + return ANY_FIELD_CHAR, source.parent + end + if index.type == 'string' + or index.type == 'boolean' + or index.type == 'number' then + return ('%q'):format(index[1] or ''), source.parent + elseif index.type ~= 'function' + and index.type ~= 'table' then + return ANY_FIELD_CHAR, source.parent + end + elseif source.type == 'table' then + return source.start, nil + elseif source.type == 'label' then + return source.start, nil + elseif source.type == 'goto' then + if source.node then + return source.node.start, nil + end + return nil, nil + elseif source.type == 'function' then + return source.start, nil + elseif source.type == 'string' then + return '', nil + elseif source.type == 'integer' + or source.type == 'number' + or source.type == 'boolean' + or source.type == 'nil' then + return source.start, nil + elseif source.type == '...' then + return source.start, nil + elseif source.type == 'varargs' then + if source.node then + return source.node.start, nil + end + elseif source.type == 'select' then + return ('%d%s%d'):format(source.start, RETURN_INDEX, source.sindex) + elseif source.type == 'call' then + local node = source.node + if node.special == 'rawget' + or node.special == 'rawset' then + if not source.args then + return nil, nil + end + local tbl, key = source.args[1], source.args[2] + if not tbl or not key then + return nil, nil + end + if key.type == 'string' then + return ('%q'):format(key[1] or ''), tbl + else + return '', tbl + end + end + return source.finish, nil + elseif source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.see.name' then + local name = source[1] + return name, nil + elseif source.type == 'doc.type.name' then + local name = source[1] + if source.typeGeneric then + return source.typeGeneric[name][1].start, nil + else + return name, nil + end + elseif source.type == 'doc.class' + or source.type == 'doc.type' + or source.type == 'doc.param' + or source.type == 'doc.vararg' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.table' + or source.type == 'doc.type.array' + or source.type == 'doc.type.function' then + return source.start, nil + elseif source.type == 'doc.see.field' then + return ('%q'):format(source[1]), source.parent.name + elseif source.type == 'generic.closure' then + return source.call.start, nil + elseif source.type == 'generic.value' then + return ('%s|%s'):format( + source.closure.call.start, + getKey(source.proto) + ) + end + return nil, nil +end + +local function checkMode(source) + if source.type == 'table' then + return 't:' + end + if source.type == 'select' then + return 's:' + end + if source.type == 'function' then + return 'f:' + end + if source.type == 'string' then + return 'str:' + end + if source.type == 'number' + or source.type == 'integer' + or source.type == 'boolean' + or source.type == 'nil' then + return 'i:' + end + if source.type == 'call' then + return 'c:' + end + if source.type == '...' + or source.type == 'varargs' then + return 'va:' + end + if source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' then + if source.typeGeneric then + return 'dg:' + end + return 'dn:' + end + if source.type == 'doc.field.name' then + return 'dfn:' + end + if source.type == 'doc.see.name' then + return 'dsn:' + end + if source.type == 'doc.class' then + return 'dc:' + end + if source.type == 'doc.type' then + return 'dt:' + end + if source.type == 'doc.param' then + return 'dp:' + end + if source.type == 'doc.type.function' then + return 'dfun:' + end + if source.type == 'doc.type.table' then + return 'dtable:' + end + if source.type == 'doc.type.array' then + return 'darray:' + end + if source.type == 'doc.vararg' then + return 'dv:' + end + if source.type == 'doc.type.enum' + or source.type == 'doc.resume' then + return 'de:' + end + if source.type == 'generic.closure' then + return 'gc:' + end + if source.type == 'generic.value' then + local id = 'gv:' + if guide.getUri(source.closure.call) ~= guide.getUri(source.proto) then + id = id .. URI_CHAR .. guide.getUri(source.closure.call) + end + return id + end + if guide.isGlobal(source) then + return 'g:' + end + if source.type == 'getlocal' + or source.type == 'setlocal' then + source = source.node + end + if source.parent.type == 'funcargs' then + return 'p:' + end + return 'l:' +end + +local IDList = {} +---获取语法树单元的字符串ID +---@param source parser.guide.object +---@return string? id +local function getID(source) + if not source then + return nil + end + if source._id ~= nil then + return source._id or nil + end + if source.type == 'field' + or source.type == 'method' then + source._id = false + return nil + end + local current = source + local index = 0 + while true do + if current.type == 'paren' then + current = current.exp + goto CONTINUE + end + local id, node = getKey(current) + if not id then + break + end + index = index + 1 + IDList[index] = id + if not node then + break + end + current = node + if current.special == '_G' then + for i = index, 2, -1 do + if IDList[i] == '"_G"' then + IDList[i] = nil + end + end + break + end + ::CONTINUE:: + end + if index == 0 then + source._id = false + return nil + end + for i = index + 1, #IDList do + IDList[i] = nil + end + local mode = checkMode(current) + if not mode then + source._id = false + return nil + end + util.revertTable(IDList) + local id = mode .. table.concat(IDList, SPLIT_CHAR) + source._id = id + return id +end + +---添加关联的前进ID +---@param noders noders +---@param id string +---@param forwardID string +local function pushForward(noders, id, forwardID, tag) + if not id + or not forwardID + or forwardID == '' + or id == forwardID then + return + end + local node = getNode(noders, id) + if not node.forward then + node.forward = {} + end + if node.forward[forwardID] ~= nil then + return + end + node.forward[forwardID] = tag or false + node.forward[#node.forward+1] = forwardID +end + +---添加关联的后退ID +---@param noders noders +---@param id string +---@param backwardID string +local function pushBackward(noders, id, backwardID, tag) + if not id + or not backwardID + or backwardID == '' + or id == backwardID then + return + end + local node = getNode(noders, id) + if not node.backward then + node.backward = {} + end + if node.backward[backwardID] ~= nil then + return + end + node.backward[backwardID] = tag or false + node.backward[#node.backward+1] = backwardID +end + +local m = {} + +m.SPLIT_CHAR = SPLIT_CHAR +m.RETURN_INDEX = RETURN_INDEX +m.PARAM_INDEX = PARAM_INDEX +m.TABLE_KEY = TABLE_KEY +m.ANY_FIELD = ANY_FIELD +m.URI_CHAR = URI_CHAR + +--- 寻找doc的主体 +---@param obj parser.guide.object +---@return parser.guide.object +local function getDocStateWithoutCrossFunction(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return obj + end + if parent.type == 'doc' then + return obj + end + if parent.type == 'doc.type.function' then + return nil + end + obj = parent + end + error('guide.getDocState overstack') +end + +---添加关联单元 +---@param noders noders +---@param source parser.guide.object +function m.pushSource(noders, source) + local id = m.getID(source) + if not id then + return + end + local node = getNode(noders, id) + if not node.sources then + node.sources = {} + end + node.sources[#node.sources+1] = source +end + +---@param noders noders +---@param source parser.guide.object +---@return parser.guide.object[] +function m.compileNode(noders, source) + local id = getID(source) + local value = source.value + if value then + local valueID = getID(value) + if valueID then + -- x = y : x -> y + pushForward(noders, id, valueID, 'set') + -- 参数禁止反向查找赋值 + if valueID:sub(1, 2) ~= 'p:' then + pushBackward(noders, valueID, id, 'set') + end + end + end + -- self -> mt:xx + if source.type == 'local' and source[1] == 'self' then + local func = guide.getParentFunction(source) + if func.isGeneric then + return + end + local setmethod = func.parent + -- guess `self` + if setmethod and ( setmethod.type == 'setmethod' + or setmethod.type == 'setfield' + or setmethod.type == 'setindex') then + pushForward(noders, id, getID(setmethod.node), 'method') + pushBackward(noders, getID(setmethod.node), id, 'method') + end + end + -- 分解 @type + if source.type == 'doc.type' then + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) + end + end + for _, enumUnit in ipairs(source.enums) do + pushForward(noders, id, getID(enumUnit)) + end + for _, resumeUnit in ipairs(source.resumes) do + pushForward(noders, id, getID(resumeUnit)) + end + for _, typeUnit in ipairs(source.types) do + local unitID = getID(typeUnit) + pushForward(noders, id, unitID) + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushBackward(noders, unitID, getID(src)) + end + end + end + end + -- 分解 @alias + if source.type == 'doc.alias' then + pushForward(noders, getID(source.alias), getID(source.extends)) + end + -- 分解 @class + if source.type == 'doc.class' then + pushForward(noders, id, getID(source.class)) + pushForward(noders, getID(source.class), id) + if source.extends then + for _, ext in ipairs(source.extends) do + pushBackward(noders, id, getID(ext)) + end + end + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) + end + end + for _, field in ipairs(source.fields) do + local key = field.field[1] + if key then + local keyID = ('%s%s%q'):format( + id, + SPLIT_CHAR, + key + ) + pushForward(noders, keyID, getID(field.field)) + pushForward(noders, getID(field.field), keyID) + pushForward(noders, keyID, getID(field.extends)) + pushBackward(noders, getID(field.extends), keyID) + end + end + end + if source.type == 'doc.param' then + pushForward(noders, id, getID(source.extends)) + for _, src in ipairs(source.bindSources) do + if src.type == 'local' and src.parent.type == 'in' then + pushForward(noders, getID(src), id) + end + end + end + if source.type == 'doc.vararg' then + pushForward(noders, getID(source), getID(source.vararg)) + end + if source.type == 'doc.see' then + local nameID = getID(source.name) + local classID = nameID:gsub('^dsn:', 'dn:') + pushForward(noders, nameID, classID) + if source.field then + local fieldID = getID(source.field) + local fieldClassID = fieldID:gsub('^dsn:', 'dn:') + pushForward(noders, fieldID, fieldClassID) + end + end + if source.type == 'call' then + local node = source.node + local nodeID = getID(node) + if not nodeID then + return + end + getNode(noders, id).call = source + -- 将 call 映射到 node#1 上 + local callID = ('%s%s%s'):format( + nodeID, + RETURN_INDEX, + 1 + ) + pushForward(noders, id, callID) + -- 将setmetatable映射到 param1 以及 param2.__index 上 + if node.special == 'setmetatable' then + local tblID = getID(source.args and source.args[1]) + local metaID = getID(source.args and source.args[2]) + local indexID + if metaID then + indexID = ('%s%s%q'):format( + metaID, + SPLIT_CHAR, + '__index' + ) + end + pushForward(noders, id, callID) + pushBackward(noders, callID, id) + pushForward(noders, callID, tblID) + pushForward(noders, callID, indexID) + pushBackward(noders, tblID, callID) + --pushBackward(noders, indexID, callID) + end + if node.special == 'require' then + local arg1 = source.args and source.args[1] + if arg1 and arg1.type == 'string' then + getNode(noders, callID).require = arg1[1] + end + end + end + if source.type == 'select' then + if source.vararg.type == 'call' then + local call = source.vararg + local node = call.node + local nodeID = getID(node) + if not nodeID then + return + end + -- 将call的返回值接收映射到函数返回值上 + local callXID = ('%s%s%s'):format( + nodeID, + RETURN_INDEX, + source.sindex + ) + pushForward(noders, id, callXID) + pushBackward(noders, callXID, id) + getNode(noders, id).call = call + if node.special == 'pcall' + or node.special == 'xpcall' then + local index = source.sindex - 1 + if index <= 0 then + return + end + local funcID = call.args and getID(call.args[1]) + if not funcID then + return + end + local funcXID = ('%s%s%s'):format( + funcID, + RETURN_INDEX, + index + ) + pushForward(noders, id, funcXID) + pushBackward(noders, funcXID, id) + end + end + if source.vararg.type == 'varargs' then + pushForward(noders, id, getID(source.vararg)) + end + end + if source.type == 'doc.type.function' then + if source.returns then + for index, rtn in ipairs(source.returns) do + local returnID = ('%s%s%s'):format( + id, + RETURN_INDEX, + index + ) + pushForward(noders, returnID, getID(rtn)) + end + end + -- @type fun(x: T):T 的情况 + local docType = getDocStateWithoutCrossFunction(source) + if docType and docType.type == 'doc.type' then + guide.eachSourceType(source, 'doc.type.name', function (typeName) + if typeName.typeGeneric then + source.isGeneric = true + return false + end + end) + end + end + if source.type == 'doc.type.table' then + if source.tkey then + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, getID(source.tkey)) + end + if source.tvalue then + local valueID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, valueID, getID(source.tvalue)) + end + end + if source.type == 'doc.type.array' then + if source.node then + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source.node)) + end + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, 'dn:integer') + end + -- 将函数的返回值映射到具体的返回值上 + if source.type == 'function' then + local hasDocReturn = {} + -- 检查 luadoc + if source.bindDocs then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + local fullID = ('%s%s%s'):format( + id, + RETURN_INDEX, + rtn.returnIndex + ) + pushForward(noders, fullID, getID(rtn)) + hasDocReturn[rtn.returnIndex] = true + end + end + if doc.type == 'doc.param' then + local paramName = doc.param[1] + if source.docParamMap then + local paramIndex = source.docParamMap[paramName] + local param = source.args[paramIndex] + if param then + pushForward(noders, getID(param), getID(doc)) + param.docParam = doc + end + end + end + if doc.type == 'doc.vararg' then + for _, param in ipairs(source.args) do + if param.type == '...' then + pushForward(noders, getID(param), getID(doc)) + end + end + end + if doc.type == 'doc.generic' then + source.isGeneric = true + end + if doc.type == 'doc.overload' then + pushForward(noders, id, getID(doc.overload)) + end + end + end + -- 检查实体返回值 + if source.returns then + local returns = {} + for _, rtn in ipairs(source.returns) do + for index, rtnObj in ipairs(rtn) do + if not hasDocReturn[index] then + if not returns[index] then + returns[index] = {} + end + returns[index][#returns[index]+1] = rtnObj + end + end + end + for index, rtnObjs in ipairs(returns) do + local returnID = ('%s%s%s'):format( + id, + RETURN_INDEX, + index + ) + for _, rtnObj in ipairs(rtnObjs) do + pushForward(noders, returnID, getID(rtnObj)) + if rtnObj.type == 'function' + or rtnObj.type == 'call' then + pushBackward(noders, getID(rtnObj), returnID) + end + end + end + end + end + if source.type == 'table' then + if #source == 1 and source[1].type == 'varargs' then + source.array = source[1] + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source[1])) + end + end + if source.type == 'main' then + if source.returns then + for _, rtn in ipairs(source.returns) do + local rtnObj = rtn[1] + if rtnObj then + pushForward(noders, 'mainreturn', getID(rtnObj)) + pushBackward(noders, getID(rtnObj), 'mainreturn') + end + end + end + end + if source.type == 'generic.closure' then + for i, rtn in ipairs(source.returns) do + local closureID = ('%s%s%s'):format( + id, + RETURN_INDEX, + i + ) + local returnID = getID(rtn) + pushForward(noders, closureID, returnID) + end + end + if source.type == 'generic.value' then + local proto = source.proto + local closure = source.closure + local upvalues = closure.upvalues + if proto.type == 'doc.type.name' then + local key = proto[1] + if upvalues[key] then + for _, paramID in ipairs(upvalues[key]) do + pushForward(noders, id, paramID) + pushBackward(noders, paramID, id) + end + end + end + if proto.type == 'doc.type' then + for _, tp in ipairs(source.types) do + pushForward(noders, id, getID(tp)) + pushBackward(noders, getID(tp), id) + end + end + if proto.type == 'doc.type.array' then + local nodeID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, nodeID, getID(source.node)) + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, 'dn:integer') + end + if proto.type == 'doc.type.table' then + if source.tkey then + local keyID = ('%s%s'):format( + id, + TABLE_KEY + ) + pushForward(noders, keyID, getID(source.tkey)) + end + if source.tvalue then + local valueID = ('%s%s'):format( + id, + ANY_FIELD + ) + pushForward(noders, valueID, getID(source.tvalue)) + end + end + end +end + +---根据ID来获取所有的node +---@param root parser.guide.object +---@param id string +---@return node? +function m.getNodeByID(root, id) + root = guide.getRoot(root) + local noders = root._noders + if not noders then + return nil + end + return noders[id] +end + +---根据ID来获取第一个节点的ID +---@param id string +---@return string +function m.getFirstID(id) + if FirstIDCache[id] then + return FirstIDCache[id] or nil + end + local firstID, count = id:match(FIRST_REGEX) + if count == 0 then + FirstIDCache[id] = false + return nil + end + FirstIDCache[id] = firstID + return firstID +end + +---根据ID来获取上个节点的ID +---@param id string +---@return string +function m.getLastID(id) + if LastIDCache[id] then + return LastIDCache[id] or nil + end + local lastID, count = id:gsub(LAST_REGEX, '') + if count == 0 then + LastIDCache[id] = false + return nil + end + LastIDCache[id] = lastID + return lastID +end + +---把形如 `@file:\\\XXXXX@gv:1|1`拆分成uri与id +---@param id string +---@return uri? string +---@return string id +function m.getUriAndID(id) + local uri, newID = id:match(URI_REGEX) + return uri, newID +end + +---获取source的ID +---@param source parser.guide.object +---@return string +function m.getID(source) + return getID(source) +end + +---获取source的key +---@param source parser.guide.object +---@return string +function m.getKey(source) + return getKey(source) +end + +---清除临时id(用于泛型的临时对象) +---@param root parser.guide.object +---@param id string +function m.removeID(root, id) + root = guide.getRoot(root) + local noders = root._noders + noders[id] = nil +end + +---寻找doc的主体 +---@param doc parser.guide.object +function m.getDocState(doc) + return getDocStateWithoutCrossFunction(doc) +end + +---获取对象的noders +---@param source parser.guide.object +---@return noders +function m.getNoders(source) + local root = guide.getRoot(source) + if not root._noders then + root._noders = {} + end + return root._noders +end + +---编译整个文件的node +---@param source parser.guide.object +---@return table +function m.compileNodes(source) + local root = guide.getRoot(source) + local noders = m.getNoders(source) + if next(noders) then + return noders + end + guide.eachSource(root, function (src) + m.pushSource(noders, src) + m.compileNode(noders, src) + end) + -- Special rule: ('').XX -> stringlib.XX + pushBackward(noders, 'str:', 'dn:stringlib') + pushBackward(noders, 'dn:string', 'dn:stringlib') + return noders +end + +return m diff --git a/script/core/reference.lua b/script/core/reference.lua index 7620b09e..c3f3b349 100644 --- a/script/core/reference.lua +++ b/script/core/reference.lua @@ -1,4 +1,5 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' @@ -6,8 +7,8 @@ local findSource = require 'core.find-source' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -19,7 +20,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else @@ -64,11 +65,23 @@ return function (uri, offset) local metaSource = vm.isMetaFile(uri) + local refs = vm.getRefs(source) + local values = {} + for _, src in ipairs(refs) do + local value = searcher.getObjectValue(src) + if value and value ~= src and guide.isLiteral(value) then + values[value] = true + end + end + local results = {} - for _, src in ipairs(vm.getRefs(source, 5)) do + for _, src in ipairs(refs) do if src.dummy then goto CONTINUE end + if values[src] then + goto CONTINUE + end local root = guide.getRoot(src) if not root then goto CONTINUE diff --git a/script/core/rename.lua b/script/core/rename.lua index da82b0a6..6b67d4be 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -1,11 +1,11 @@ local files = require 'files' local vm = require 'vm' -local guide = require 'core.guide' local proto = require 'proto' local define = require 'proto.define' local util = require 'utility' local findSource = require 'core.find-source' -local ws = require 'workspace' +local guide = require 'parser.guide' +local noder = require 'core.noder' local Forcing @@ -185,7 +185,7 @@ local function renameField(source, newname, callback) end callback(source, source.start, source.finish, newname) elseif parent.type == 'setmethod' then - local uri = guide.getUri(source) + local uri = guide.getUri(source) local text = files.getText(uri) local func = parent.value -- function mt:name () end --> mt['newname'] = function (self) end @@ -292,14 +292,14 @@ local function ofField(source, newname, callback) else node = source.node end - for _, src in ipairs(vm.getFields(node, 5)) do + for _, src in ipairs(vm.getRefs(node, '*')) do ofFieldThen(key, src, newname, callback) end end local function ofGlobal(source, newname, callback) local key = guide.getKeyName(source) - for _, src in ipairs(vm.getRefs(source, 0)) do + for _, src in ipairs(vm.getRefs(source)) do ofFieldThen(key, src, newname, callback) end end @@ -308,24 +308,27 @@ local function ofLabel(source, newname, callback) if not isValidName(newname) and not askForcing(newname)then return false end - for _, src in ipairs(vm.getRefs(source, 0)) do + for _, src in ipairs(vm.getRefs(source)) do callback(src, src.start, src.finish, newname) end end local function ofDocTypeName(source, newname, callback) - for _, doc in ipairs(vm.getDocTypes(source[1])) do + local oldname = source[1] + for _, doc in ipairs(vm.getRefs(source)) do if doc.type == 'doc.class.name' or doc.type == 'doc.type.name' or doc.type == 'doc.alias.name' then - callback(doc, doc.start, doc.finish, newname) + if oldname == doc[1] then + callback(doc, doc.start, doc.finish, newname) + end end end end local function ofDocParamName(source, newname, callback) callback(source, source.start, source.finish, newname) - local doc = guide.getDocState(source) + local doc = noder.getDocState(source) if doc.bindSources then for _, src in ipairs(doc.bindSources) do if src.type == 'local' diff --git a/script/core/searcher.lua b/script/core/searcher.lua new file mode 100644 index 00000000..11e00378 --- /dev/null +++ b/script/core/searcher.lua @@ -0,0 +1,728 @@ +local noder = require 'core.noder' +local guide = require 'parser.guide' +local files = require 'files' +local generic = require 'core.generic' +local ws = require 'workspace' +local vm = require 'vm.vm' + +local NONE = {'NONE'} +local LAST = {'LAST'} + +local ignoredIDs = { + ['dn:unknown'] = true, + ['dn:nil'] = true, + ['dn:any'] = true, + ['dn:boolean'] = true, + ['dn:string'] = true, + ['dn:table'] = true, + ['dn:number'] = true, + ['dn:integer'] = true, + ['dn:userdata'] = true, + ['dn:lightuserdata'] = true, + ['dn:function'] = true, + ['dn:thread'] = true, +} + +local m = {} + +---@alias guide.searchmode '"ref"'|'"def"' + +---添加结果 +---@param status guide.status +---@param mode guide.searchmode +---@param source parser.guide.object +---@param force boolean +function m.pushResult(status, mode, source, force) + if not source then + return + end + local results = status.results + if results[source] then + return + end + results[source] = true + if force then + results[#results+1] = source + return + end + local parent = source.parent + if mode == 'def' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'label' + or source.type == 'setfield' + or source.type == 'setmethod' + or source.type == 'setindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if noder.getID(source) ~= status.id then + results[#results+1] = source + end + end + elseif mode == 'ref' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'label' + or source.type == 'goto' + or source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' + or source.type == 'setindex' + or source.type == 'getindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' + or source.type == 'nil' + or source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.resume' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if noder.getID(source) ~= status.id then + results[#results+1] = source + end + end + end +end + +---获取uri +---@param obj parser.guide.object +---@return uri +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = guide.getRoot(obj) + if root then + return root.uri + end + return '' +end + +---@param obj parser.guide.object +---@return parser.guide.object? +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent and obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args and obj.args[3] + else + return obj + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +local function crossSearch(status, uri, expect, mode) + m.searchRefsByID(status, uri, expect, mode) +end + +local function getLock(status, uri, expect, mode) + local slock = status.lock + local ulock = slock[uri] + if not ulock then + ulock = {} + slock[uri] = ulock + end + local mlock = ulock[mode] + if not mlock then + mlock = {} + ulock[mode] = mlock + end + if mlock[expect] then + return false + end + mlock[expect] = true + return true +end + +function m.searchRefsByID(status, uri, expect, mode) + local ast = files.getAst(uri) + if not ast then + return + end + if not getLock(status, uri, expect, mode) then + return + end + local root = ast.ast + local searchStep + noder.compileNodes(root) + + status.id = expect + + local callStack = status.callStack + + local mark = {} + + local function search(id, field) + local firstID = noder.getFirstID(id) + if ignoredIDs[firstID] and (field or firstID ~= id) then + return + end + local cmark = mark[id] + if not cmark then + cmark = {} + mark[id] = cmark + end + log.debug('search:', id, field) + if field then + if cmark[field] then + return + end + cmark[field] = true + searchStep(id, field) + cmark[field] = nil + else + if cmark[NONE] then + return + end + cmark[NONE] = true + searchStep(id, nil) + cmark[NONE] = nil + end + log.debug('pop:', id, field) + end + + local function checkLastID(id, field) + local cmark = mark[id] + if not cmark then + cmark = {} + mark[id] = cmark + end + if cmark[LAST] then + return + end + local lastID = noder.getLastID(id) + if not lastID then + return + end + local newField = id:sub(#lastID + 1) + if field then + newField = newField .. field + end + cmark[LAST] = true + search(lastID, newField) + cmark[LAST] = nil + return lastID + end + + local function searchID(id, field) + if not id then + return + end + if field then + id = id .. field + end + search(id, nil) + end + + local function isCallID(field) + if not field then + return false + end + if field:sub(1, 2) == noder.RETURN_INDEX then + return true + end + return false + end + + local function findLastCall() + for i = #callStack, 1, -1 do + local call = callStack[i] + if call then + -- 标记此处的call失效,等待在堆栈平衡时弹出 + callStack[i] = false + return call + end + end + return nil + end + + local genericCallArgs = {} + local closureCache = {} + local function checkGeneric(source, field) + if not source.isGeneric then + return + end + if not isCallID(field) then + return + end + local call = findLastCall() + if not call then + return + end + + if call.args then + for _, arg in ipairs(call.args) do + genericCallArgs[arg] = true + end + end + + local cacheID = noder.getID(source) .. noder.getID(call) + local closure = closureCache[cacheID] + if closure == false then + return + end + if not closure then + closure = generic.createClosure(source, call) + closureCache[cacheID] = closure or false + if not closure then + return + end + end + local id = noder.getID(closure) + searchID(id, field) + end + + local function checkENV(source, field) + if not field then + return + end + if source.special ~= '_G' then + return + end + local newID = 'g:' .. field:sub(2) + searchID(newID) + end + + local forwardTag = {} + local backwardTag = {} + local function checkForward(id, node, field) + for _, forwardID in ipairs(node.forward) do + local tag = node.forward[forwardID] + if tag then + if backwardTag[tag] and backwardTag[tag] > 0 then + goto CONTINUE + end + forwardTag[tag] = (forwardTag[tag] or 0) + 1 + end + local targetUri, targetID = noder.getUriAndID(forwardID) + if targetUri and not files.eq(targetUri, uri) then + crossSearch(status, targetUri, targetID .. (field or ''), mode) + else + searchID(targetID or forwardID, field) + end + if tag then + forwardTag[tag] = forwardTag[tag] - 1 + end + ::CONTINUE:: + end + end + + local function checkBackward(id, node, field) + if mode ~= 'ref' and not field then + return + end + for _, backwardID in ipairs(node.backward) do + local tag = node.backward[backwardID] + if tag then + if forwardTag[tag] and forwardTag[tag] > 0 then + goto CONTINUE + end + backwardTag[tag] = (backwardTag[tag] or 0) + 1 + end + local targetUri, targetID = noder.getUriAndID(backwardID) + if targetUri and not files.eq(targetUri, uri) then + crossSearch(status, targetUri, targetID .. (field or ''), mode) + else + searchID(targetID or backwardID, field) + end + if tag then + backwardTag[tag] = backwardTag[tag] - 1 + end + ::CONTINUE:: + end + end + + local function checkRequire(requireName, field) + local tid = 'mainreturn' .. (field or '') + local uris = ws.findUrisByRequirePath(requireName) + for _, ruri in ipairs(uris) do + if not files.eq(uri, ruri) then + crossSearch(status, ruri, tid, mode) + end + end + end + + local function checkGlobal(id, node, field) + if id:sub(1, 2) ~= 'g:' then + return + end + local firstID = noder.getFirstID(id) + if status.crossed[firstID] then + return + end + status.crossed[firstID] = true + local tid = id .. (field or '') + for guri in files.eachFile() do + if not files.eq(uri, guri) then + crossSearch(status, guri, tid, mode) + end + end + end + + local function checkClass(id, node, field) + if id:sub(1, 3) ~= 'dn:' then + return + end + local firstID = noder.getFirstID(id) + if status.crossed[firstID] then + return + end + status.crossed[firstID] = true + local tid = id .. (field or '') + for guri in files.eachFile() do + if not files.eq(uri, guri) then + crossSearch(status, guri, tid, mode) + end + end + end + + local function checkMainReturn(id, node, field) + if id ~= 'mainreturn' then + return + end + if mode ~= 'ref' and not field then + return + end + local calls = vm.getLinksTo(uri) + for _, call in ipairs(calls) do + local turi = guide.getUri(call) + if not files.eq(turi, uri) then + local tid = noder.getID(call) .. (field or '') + crossSearch(status, turi, tid, mode) + end + end + end + + local function searchNode(id, node, field) + if node.call then + callStack[#callStack+1] = node.call + end + if field == nil and node.sources then + for _, source in ipairs(node.sources) do + local force = genericCallArgs[source] + m.pushResult(status, mode, source, force) + end + end + if node.forward then + checkForward(id, node, field) + end + if node.backward then + checkBackward(id, node, field) + end + + if node.sources then + checkGeneric(node.sources[1], field) + checkENV(node.sources[1], field) + end + + if node.require then + checkRequire(node.require, field) + end + + checkMainReturn(id, node, field) + + if node.call then + callStack[#callStack] = nil + end + end + + local function checkCrossUri(id, field) + local targetUri, newID = noder.getUriAndID(id) + if not targetUri then + return false + end + crossSearch(status, targetUri, newID .. (field or ''), mode) + return true + end + + local stepCount = 0 + function searchStep(id, field) + stepCount = stepCount + 1 + if stepCount > 1000 then + error('too large') + end + local node = noder.getNodeByID(root, id) + if node then + searchNode(id, node, field) + end + checkGlobal(id, node, field) + checkClass(id, node, field) + local lastID = checkLastID(id, field) + if not lastID then + return + end + local originField = id:sub(#lastID + 1) + if originField == noder.TABLE_KEY then + return + end + local anyFieldID = lastID .. noder.ANY_FIELD + local anyFieldNode = noder.getNodeByID(root, anyFieldID) + if anyFieldNode then + searchNode(anyFieldID, anyFieldNode, field) + end + end + + search(expect) + + --清除来自泛型的临时对象 + for _, closure in pairs(closureCache) do + noder.removeID(root, noder.getID(closure)) + if closure then + for _, value in ipairs(closure.values) do + noder.removeID(root, noder.getID(value)) + end + end + end +end + +local function prepareSearch(source) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local root = guide.getRoot(source) + noder.compileNodes(root) + local uri = guide.getUri(source) + local id = noder.getID(source) + return uri, id +end + +local function getField(status, source, mode) + if source.type == 'table' then + for _, field in ipairs(source) do + m.pushResult(status, mode, field) + end + return + end + if source.type == 'doc.class.name' then + local class = source.parent + for _, field in ipairs(class.fields) do + m.pushResult(status, mode, field.field) + end + return + end + local field = source.next + if field then + if field.type == 'getmethod' + or field.type == 'setmethod' + or field.type == 'getfield' + or field.type == 'setfield' + or field.type == 'getindex' + or field.type == 'setindex' then + m.pushResult(status, mode, field) + end + return + end +end + +local function searchAllGlobalByUri(status, mode, uri, fullID) + local ast = files.getAst(uri) + if not ast then + return + end + local root = ast.ast + noder.compileNodes(root) + local noders = noder.getNoders(root) + if fullID then + for id, node in pairs(noders) do + if node.sources + and id == fullID then + for _, source in ipairs(node.sources) do + m.pushResult(status, mode, source) + end + end + end + else + for id, node in pairs(noders) do + if node.sources + and id:sub(1, 2) == 'g:' + and not id:find(noder.SPLIT_CHAR) then + for _, source in ipairs(node.sources) do + m.pushResult(status, mode, source) + end + end + end + end +end + +local function searchAllGlobals(status, mode, fullID) + for uri in files.eachFile() do + searchAllGlobalByUri(status, mode, uri, fullID) + end +end + +---搜索对象的引用 +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +function m.searchRefs(status, source, mode) + local uri, id = prepareSearch(source) + if not id then + return + end + log.debug('searchRefs:', id) + m.searchRefsByID(status, uri, id, mode) +end + +function m.findGlobals(uri, mode, name) + local status = m.status() + + if name then + local fullID = ('g:%q'):format(name) + searchAllGlobalByUri(status, mode, uri, fullID) + else + searchAllGlobalByUri(status, mode, uri) + end + + return status.results +end + +---搜索对象的field +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +---@param field string +function m.searchFields(status, source, mode, field) + local uri, id = prepareSearch(source) + if not id then + return + end + log.debug('searchFields:', id, field) + if field == '*' then + if source.special == '_G' then + searchAllGlobals(status, mode) + else + local newStatus = m.status(status) + m.searchRefsByID(newStatus, uri, id, mode) + for _, def in ipairs(newStatus.results) do + getField(status, def, mode) + end + end + else + if source.special == '_G' then + local fullID = ('g:%q'):format(field) + m.searchRefsByID(status, uri, fullID, mode) + else + local fullID = ('%s%s%q'):format(id, noder.SPLIT_CHAR, field) + m.searchRefsByID(status, uri, fullID, mode) + end + end +end + +---@class guide.status +---搜索结果 +---@field results parser.guide.object[] + +---创建搜索状态 +---@param parentStatus guide.status +---@return guide.status +function m.status(parentStatus) + local status = { + --mark = parentStatus and parentStatus.mark or {}, + callStack = {}, + crossed = {}, + lock = {}, + results = {}, + } + return status +end + +--- 请求对象的引用 +---@param obj parser.guide.object +---@param field? string +---@return parser.guide.object[] +function m.requestReference(obj, field) + local status = m.status() + + if field then + m.searchFields(status, obj, 'ref', field) + else + m.searchRefs(status, obj, 'ref') + end + + return status.results +end + +--- 请求对象的定义 +---@param obj parser.guide.object +---@param field? string +---@return parser.guide.object[] +function m.requestDefinition(obj, field) + local status = m.status() + + if field then + m.searchFields(status, obj, 'def', field) + else + m.searchRefs(status, obj, 'def') + end + + return status.results +end + +return m diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index f8feaa09..5e9ee9b1 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local await = require 'await' local define = require 'proto.define' local vm = require 'vm' @@ -221,7 +221,7 @@ return function (uri, start, finish) local results = {} local count = 0 - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) local method = Care[source.type] if not method then return diff --git a/script/core/signature.lua b/script/core/signature.lua index eb740784..915310c0 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -1,8 +1,9 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local hoverLabel = require 'core.hover.label' local hoverDesc = require 'core.hover.description' +local guide = require 'parser.guide' local function findNearCall(uri, ast, pos) local text = files.getText(uri) @@ -96,10 +97,10 @@ local function makeSignatures(call, pos) index = 1 end local signs = {} - local defs = vm.getDefs(node, 0) + local defs = vm.getDefs(node) local mark = {} for _, src in ipairs(defs) do - src = guide.getObjectValue(src) or src + src = searcher.getObjectValue(src) or src if src.type == 'function' or src.type == 'doc.type.function' then if not mark[src] then diff --git a/script/core/type-formatting.lua b/script/core/type-formatting.lua index c2290ef3..49a721e5 100644 --- a/script/core/type-formatting.lua +++ b/script/core/type-formatting.lua @@ -1,6 +1,6 @@ local files = require 'files' local lookBackward = require 'core.look-backward' -local guide = require 'core.guide' +local guide = require "parser.guide" local function insertIndentation(uri, offset, edits) local lines = files.getLines(uri) diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua index ae420d32..2df23a4d 100644 --- a/script/core/workspace-symbol.lua +++ b/script/core/workspace-symbol.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local matchKey = require 'core.matchkey' local define = require 'proto.define' local await = require 'await' @@ -52,7 +52,7 @@ local function searchFile(uri, key, results) return end - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) buildSource(uri, source, key, results) end) end |