diff options
Diffstat (limited to 'script')
73 files changed, 2183 insertions, 442 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua index bae3df81..1ec2aa8b 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -1,10 +1,10 @@ -local files = require 'files' -local lang = require 'language' -local define = require 'proto.define' -local guide = require 'core.guide' -local util = require 'utility' -local sp = require 'bee.subprocess' -local vm = require 'vm' +local files = require 'files' +local lang = require 'language' +local define = require 'proto.define' +local searcher = require 'core.searcher' +local util = require 'utility' +local sp = require 'bee.subprocess' +local vm = require 'vm' local function checkDisableByLuaDocExits(uri, row, mode, code) local lines = files.getLines(uri) @@ -59,7 +59,7 @@ end local function disableDiagnostic(uri, code, start, results) local lines = files.getLines(uri) - local row = guide.positionOf(lines, start) + local row = searcher.positionOf(lines, start) results[#results+1] = { title = lang.script('ACTION_DISABLE_DIAG', code), kind = 'quickfix', @@ -137,12 +137,12 @@ end local function solveUndefinedGlobal(uri, diag, results) local ast = files.getAst(uri) local offset = files.offsetOfWord(uri, diag.range.start) - guide.eachSourceContain(ast.ast, offset, function (source) + searcher.eachSourceContain(ast.ast, offset, function (source) if source.type ~= 'getglobal' then return end - local name = guide.getKeyName(source) + local name = searcher.getKeyName(source) markGlobal(uri, name, results) end) @@ -156,12 +156,12 @@ end local function solveLowercaseGlobal(uri, diag, results) local ast = files.getAst(uri) local offset = files.offsetOfWord(uri, diag.range.start) - guide.eachSourceContain(ast.ast, offset, function (source) + searcher.eachSourceContain(ast.ast, offset, function (source) if source.type ~= 'setglobal' then return end - local name = guide.getKeyName(source) + local name = searcher.getKeyName(source) markGlobal(uri, name, results) end) end @@ -357,7 +357,7 @@ local function checkSwapParams(results, uri, start, finish) return end local args = {} - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) if source.type == 'callargs' or source.type == 'funcargs' then local targetIndex diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua index 527af8d5..ba1ee8eb 100644 --- a/script/core/command/removeSpace.lua +++ b/script/core/command/removeSpace.lua @@ -1,11 +1,10 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local proto = require 'proto' +local lang = require 'language' local function isInString(ast, offset) - return guide.eachSourceContain(ast.ast, offset, function (source) + return searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'string' then return true end @@ -23,10 +22,10 @@ return function (data) local textEdit = {} for i = 1, #lines do - local line = guide.lineContent(lines, text, i, true) + local line = searcher.lineContent(lines, text, i, true) local pos = line:find '[ \t]+$' if pos then - local start, finish = guide.lineRange(lines, i, true) + local start, finish = searcher.lineRange(lines, i, true) start = start + pos - 1 if isInString(ast, start) then goto NEXT_LINE diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua index 995a2109..dc23e7af 100644 --- a/script/core/command/solve.lua +++ b/script/core/command/solve.lua @@ -1,8 +1,7 @@ -local files = require 'files' -local define = require 'proto.define' -local guide = require 'core.guide' -local proto = require 'proto' -local lang = require 'language' +local files = require 'files' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' local opMap = { ['+'] = true, diff --git a/script/core/completion.lua b/script/core/completion.lua index ee61029d..ef8f220f 100644 --- a/script/core/completion.lua +++ b/script/core/completion.lua @@ -1,6 +1,6 @@ local define = require 'proto.define' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local matchKey = require 'core.matchkey' local vm = require 'vm' local getLabel = require 'core.hover.label' @@ -57,7 +57,7 @@ end local function findNearestSource(ast, offset) local source - guide.eachSourceContain(ast.ast, offset, function (src) + searcher.eachSourceContain(ast.ast, offset, function (src) source = src end) return source @@ -85,7 +85,7 @@ local function findParent(ast, text, offset) if not anyPos then return nil, nil end - local parent = guide.eachSourceContain(ast.ast, anyPos, function (source) + local parent = searcher.eachSourceContain(ast.ast, anyPos, function (source) if source.finish == anyPos then return source end @@ -100,8 +100,8 @@ end local function findParentInStringIndex(ast, text, offset) local near, nearStart - guide.eachSourceContain(ast.ast, offset, function (source) - local start = guide.getStartFinish(source) + searcher.eachSourceContain(ast.ast, offset, function (source) + local start = searcher.getStartFinish(source) if not start then return end @@ -151,9 +151,9 @@ local function getSnip(source) end local defs = vm.getRefs(source, 0) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def ~= source and def.type == 'function' then - local uri = guide.getUri(def) + local uri = searcher.getUri(def) local text = files.getText(uri) local lines = files.getLines(uri) if not text then @@ -162,7 +162,7 @@ local function getSnip(source) if vm.isMetaFile(uri) then goto CONTINUE end - local row = guide.positionOf(lines, def.start) + local row = searcher.positionOf(lines, def.start) local firstRow = lines[row] local lastRow = lines[math.min(row + context - 1, #lines)] local snip = text:sub(firstRow.start, lastRow.finish) @@ -204,7 +204,7 @@ local function buildFunction(results, source, value, oop, data) end local function buildInsertRequire(ast, targetUri, stemName) - local uri = guide.getUri(ast.ast) + local uri = searcher.getUri(ast.ast) local lines = files.getLines(uri) local text = files.getText(uri) local start = 1 @@ -234,7 +234,7 @@ local function buildInsertRequire(ast, targetUri, stemName) end local function isSameSource(ast, source, pos) - if not files.eq(guide.getUri(source), guide.getUri(ast.ast)) then + if not files.eq(searcher.getUri(source), searcher.getUri(ast.ast)) then return false end if source.type == 'field' @@ -265,7 +265,7 @@ local function getParams(func, oop) end local function checkLocal(ast, word, offset, results) - local locals = guide.getVisibleLocals(ast.ast, offset) + local locals = searcher.getVisibleLocals(ast.ast, offset) for name, source in pairs(locals) do if isSameSource(ast, source, offset) then goto CONTINUE @@ -311,9 +311,9 @@ local function checkModule(ast, word, offset, results) if not config.config.completion.autoRequire then return end - local locals = guide.getVisibleLocals(ast.ast, offset) + local locals = searcher.getVisibleLocals(ast.ast, offset) for uri in files.eachFile() do - if files.eq(uri, guide.getUri(ast.ast)) then + if files.eq(uri, searcher.getUri(ast.ast)) then goto CONTINUE end local originUri = files.getOriginUri(uri) @@ -372,7 +372,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset) return nil end local textEdit, additionalTextEdits - local uri = guide.getUri(parent) + local uri = searcher.getUri(parent) local text = files.getText(uri) local wordStart if word == '' then @@ -417,7 +417,7 @@ local function checkFieldFromFieldToIndex(name, parent, word, start, offset) end local function checkFieldThen(name, src, word, start, offset, parent, oop, results) - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src local kind = define.CompletionItemKind.Field if value.type == 'function' or value.type == 'doc.type.function' then @@ -443,7 +443,7 @@ local function checkFieldThen(name, src, word, start, offset, parent, oop, resul if oop then return end - local literal = guide.getLiteral(value) + local literal = searcher.getLiteral(value) if literal ~= nil then kind = define.CompletionItemKind.Enum end @@ -492,7 +492,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res end local funcLabel if config.config.completion.showParams then - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src if value.type == 'function' or value.type == 'doc.type.function' then funcLabel = name .. getParams(value, oop) @@ -538,7 +538,7 @@ local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, res end local function checkGlobal(ast, word, start, offset, parent, oop, results) - local locals = guide.getVisibleLocals(ast.ast, offset) + local locals = searcher.getVisibleLocals(ast.ast, offset) local refs = vm.getGlobalSets '*' checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global') end @@ -554,7 +554,7 @@ local function checkField(ast, word, start, offset, parent, oop, results) end local function checkTableField(ast, word, start, results) - local source = guide.eachSourceContain(ast.ast, start, function (source) + local source = searcher.eachSourceContain(ast.ast, start, function (source) if source.start == start and source.parent and source.parent.type == 'table' then @@ -565,7 +565,7 @@ local function checkTableField(ast, word, start, results) return end local used = {} - guide.eachSourceType(ast.ast, 'tablefield', function (src) + searcher.eachSourceType(ast.ast, 'tablefield', function (src) if not src.field then return end @@ -670,7 +670,7 @@ local function checkCommon(myUri, word, text, offset, results) end local function isInString(ast, offset) - return guide.eachSourceContain(ast.ast, offset, function (source) + return searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'string' then return true end @@ -686,7 +686,7 @@ local function checkKeyWord(ast, text, start, offset, word, hasSpace, afterLocal isExp = isExp, text = text, start = start, - uri = guide.getUri(ast.ast), + uri = searcher.getUri(ast.ast), offset = offset, ast = ast, } @@ -750,7 +750,7 @@ end local function checkProvideLocal(ast, word, start, results) local block - guide.eachSourceContain(ast.ast, start, function (source) + searcher.eachSourceContain(ast.ast, start, function (source) if source.type == 'function' or source.type == 'main' then block = source @@ -760,7 +760,7 @@ local function checkProvideLocal(ast, word, start, results) return end local used = {} - guide.eachSourceType(block, 'getglobal', function (source) + searcher.eachSourceType(block, 'getglobal', function (source) if source.start > start and not used[source[1]] and matchKey(word, source[1]) then @@ -771,7 +771,7 @@ local function checkProvideLocal(ast, word, start, results) } end end) - guide.eachSourceType(block, 'getlocal', function (source) + searcher.eachSourceType(block, 'getlocal', function (source) if source.start > start and not used[source[1]] and matchKey(word, source[1]) then @@ -785,7 +785,7 @@ local function checkProvideLocal(ast, word, start, results) end local function checkFunctionArgByDocParam(ast, word, start, results) - local func = guide.eachSourceContain(ast.ast, start, function (source) + local func = searcher.eachSourceContain(ast.ast, start, function (source) if source.type == 'function' then return source end @@ -836,8 +836,8 @@ end local function checkUri(ast, text, offset, results) local collect = {} - local myUri = guide.getUri(ast.ast) - guide.eachSourceContain(ast.ast, offset, function (source) + local myUri = searcher.getUri(ast.ast) + searcher.eachSourceContain(ast.ast, offset, function (source) if source.type ~= 'string' then return end @@ -850,7 +850,7 @@ local function checkUri(ast, text, offset, results) end local call = callargs.parent local func = call.node - local literal = guide.getLiteral(source) + local literal = searcher.getLiteral(source) local libName = vm.getLibraryName(func) if not libName then return @@ -957,7 +957,7 @@ local function checkUri(ast, text, offset, results) end local function checkLenPlusOne(ast, text, offset, results) - guide.eachSourceContain(ast.ast, offset, function (source) + searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'getindex' or source.type == 'setindex' then local _, pos = text:find('%s*%[%s*%#', source.node.finish) @@ -969,7 +969,7 @@ local function checkLenPlusOne(ast, text, offset, results) if not matchKey(writingText, nodeText) then return end - if source.parent == guide.getParentBlock(source) then + if source.parent == searcher.getParentBlock(source) then -- state local label = text:match('%#[ \t]*', pos) .. nodeText .. '+1' local eq = text:find('^%s*%]?%s*%=', source.finish) @@ -1069,7 +1069,7 @@ local function checkEqualEnumLeft(ast, text, offset, source, results) if not source then return end - local str = guide.eachSourceContain(ast.ast, offset, function (src) + local str = searcher.eachSourceContain(ast.ast, offset, function (src) if src.type == 'string' then return src end @@ -1104,7 +1104,7 @@ local function checkEqualEnum(ast, text, offset, results) end local function checkEqualEnumInString(ast, text, offset, results) - local source = guide.eachSourceContain(ast.ast, offset, function (source) + local source = searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'binary' then if source.op.type == '==' or source.op.type == '~=' then @@ -1135,7 +1135,7 @@ local function checkEqualEnumInString(ast, text, offset, results) end local function isFuncArg(ast, offset) - return guide.eachSourceContain(ast.ast, offset, function (source) + return searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'funcargs' then return true end @@ -1197,7 +1197,7 @@ local function tryWord(ast, text, offset, results) else checkLocal(ast, word, start, results) checkTableField(ast, word, start, results) - local env = guide.getENV(ast.ast, start) + local env = searcher.getENV(ast.ast, start) checkGlobal(ast, word, start, offset, env, false, results) checkModule(ast, word, start, results) end @@ -1271,7 +1271,7 @@ local function getCallEnums(source, index) end for _, unit in ipairs(vm.getDocTypeUnits(doc.extends) or {}) do if unit.type == 'doc.type.function' then - local text = files.getText(guide.getUri(unit)) + local text = files.getText(searcher.getUri(unit)) enums[#enums+1] = { label = text:sub(unit.start, unit.finish), description = doc.comment, @@ -1299,7 +1299,7 @@ end local function findCall(ast, text, offset) local call - guide.eachSourceContain(ast.ast, offset, function (src) + searcher.eachSourceContain(ast.ast, offset, function (src) if src.type == 'call' then if not call or call.start < src.start then call = src @@ -1338,14 +1338,14 @@ local function checkTableLiteralField(ast, text, offset, tbl, fields, results) for _, field in ipairs(tbl) do if field.type == 'tablefield' or field.type == 'tableindex' then - local name = guide.getKeyName(field) + local name = searcher.getKeyName(field) if name then mark[name] = true end end end table.sort(fields, function (a, b) - return guide.getKeyName(a) < guide.getKeyName(b) + return searcher.getKeyName(a) < searcher.getKeyName(b) end) -- {$} local left = lookBackward.findWord(text, offset) @@ -1358,12 +1358,12 @@ local function checkTableLiteralField(ast, text, offset, tbl, fields, results) end if left then for _, field in ipairs(fields) do - local name = guide.getKeyName(field) - if not mark[name] and matchKey(left, guide.getKeyName(field)) then + local name = searcher.getKeyName(field) + if not mark[name] and matchKey(left, searcher.getKeyName(field)) then results[#results+1] = { - label = guide.getKeyName(field), + label = searcher.getKeyName(field), kind = define.CompletionItemKind.Property, - insertText = ('%s = $0'):format(guide.getKeyName(field)), + insertText = ('%s = $0'):format(searcher.getKeyName(field)), id = stack(function () return { detail = buildDetail(field), @@ -1398,14 +1398,14 @@ local function checkTableLiteralFieldByCall(ast, text, offset, call, defs, index return end for _, def in ipairs(defs) do - local func = guide.getObjectValue(def) or def + local func = searcher.getObjectValue(def) or def local param = getFuncParamByCallIndex(func, index) if not param then goto CONTINUE end local defs = vm.getDefFields(param, 0) for _, field in ipairs(defs) do - local name = guide.getKeyName(field) + local name = searcher.getKeyName(field) if name and not mark[name] then mark[name] = true fields[#fields+1] = field @@ -1428,7 +1428,7 @@ local function tryCallArg(ast, text, offset, results) end local defs = vm.getDefs(call.node, 0) for _, def in ipairs(defs) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def local enums = getCallEnums(def, argIndex) if enums then mergeEnums(myResults, enums, arg) @@ -1458,7 +1458,7 @@ local function tryTable(ast, text, offset, results) end local defs = vm.getDefFields(tbl, 0) for _, field in ipairs(defs) do - local name = guide.getKeyName(field) + local name = searcher.getKeyName(field) if name and not mark[name] then mark[name] = true fields[#fields+1] = field @@ -1514,7 +1514,7 @@ end local function getLuaDocByContain(ast, offset) local result local range = math.huge - guide.eachSourceContain(ast.ast.docs, offset, function (src) + searcher.eachSourceContain(ast.ast.docs, offset, function (src) if not src.start then return end @@ -1591,7 +1591,7 @@ local function tryLuaDocBySource(ast, offset, source, results) return true elseif source.type == 'doc.param.name' then local funcs = {} - guide.eachSourceBetween(ast.ast, offset, math.huge, function (src) + searcher.eachSourceBetween(ast.ast, offset, math.huge, function (src) if src.type == 'function' and src.start > offset then funcs[#funcs+1] = src end @@ -1667,7 +1667,7 @@ local function tryLuaDocByErr(ast, offset, err, docState, results) end elseif err.type == 'LUADOC_MISS_PARAM_NAME' then local funcs = {} - guide.eachSourceBetween(ast.ast, offset, math.huge, function (src) + searcher.eachSourceBetween(ast.ast, offset, math.huge, function (src) if src.type == 'function' and src.start > offset then funcs[#funcs+1] = src end diff --git a/script/core/definition.lua b/script/core/definition.lua index b26bb922..5d996a88 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -1,14 +1,15 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local workspace = require 'workspace' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' +local guide = require 'parser.guide' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -20,7 +21,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else @@ -66,7 +67,7 @@ local function checkRequire(source, offset) end local call = callargs.parent local func = call.node - local literal = guide.getLiteral(source) + local literal = searcher.getLiteral(source) local libName = vm.getLibraryName(func) if not libName then return nil @@ -130,7 +131,7 @@ return function (uri, offset) local defs = vm.getDefs(source, 0) local values = {} for _, src in ipairs(defs) do - local value = guide.getObjectValue(src) + local value = searcher.getObjectValue(src) if value and value ~= src then values[value] = true end @@ -148,9 +149,6 @@ return function (uri, offset) goto CONTINUE end src = src.field or src.method or src.index or src - if src.type == 'table' and src.parent.type ~= 'return' then - goto CONTINUE - end if src.type == 'doc.class.name' and source.type ~= 'doc.type.name' and source.type ~= 'doc.extends.name' diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua index 19bb4f97..0b8385d8 100644 --- a/script/core/diagnostics/ambiguity-1.lua +++ b/script/core/diagnostics/ambiguity-1.lua @@ -1,6 +1,6 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' local opMap = { ['+'] = true, @@ -30,7 +30,7 @@ return function (uri, callback) return end local text = files.getText(uri) - guide.eachSourceType(ast.ast, 'binary', function (source) + searcher.eachSourceType(ast.ast, 'binary', function (source) if source.op.type ~= 'or' then return end diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua index 702cd904..3d32324d 100644 --- a/script/core/diagnostics/circle-doc-class.lua +++ b/script/core/diagnostics/circle-doc-class.lua @@ -1,8 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' return function (uri, callback) local state = files.getAst(uri) @@ -19,7 +18,7 @@ return function (uri, callback) if not doc.extends then goto CONTINUE end - local myName = guide.getKeyName(doc) + local myName = searcher.getKeyName(doc) local list = { doc } local mark = {} for i = 1, 999 do diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index d1983c42..11c2b820 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -1,7 +1,6 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' return function (uri, callback) local state = files.getAst(uri) @@ -9,7 +8,7 @@ return function (uri, callback) return end - guide.eachSourceType(state.ast, 'local', function (source) + searcher.eachSourceType(state.ast, 'local', function (source) if not source.attrs then return end diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua index f23755ea..dc7226ae 100644 --- a/script/core/diagnostics/code-after-break.lua +++ b/script/core/diagnostics/code-after-break.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local define = require 'proto.define' return function (uri, callback) local state = files.getAst(uri) @@ -10,7 +10,7 @@ return function (uri, callback) end local mark = {} - guide.eachSourceType(state.ast, 'break', function (source) + searcher.eachSourceType(state.ast, 'break', function (source) local list = source.parent if mark[list] then return diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua index 65099af8..f682ec3c 100644 --- a/script/core/diagnostics/count-down-loop.lua +++ b/script/core/diagnostics/count-down-loop.lua @@ -1,6 +1,6 @@ -local files = require "files" -local guide = require "core.guide" -local lang = require 'language' +local files = require "files" +local searcher = require "core.searcher" +local lang = require 'language' return function (uri, callback) local state = files.getAst(uri) @@ -9,7 +9,7 @@ return function (uri, callback) return end - guide.eachSourceType(state.ast, 'loop', function (source) + searcher.eachSourceType(state.ast, 'loop', function (source) if not source.loc or not source.loc.value then return end diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua index 60d60946..d6bccc10 100644 --- a/script/core/diagnostics/deprecated.lua +++ b/script/core/diagnostics/deprecated.lua @@ -1,10 +1,10 @@ -local files = require 'files' -local vm = require 'vm' -local lang = require 'language' -local guide = require 'core.guide' -local config = require 'config' -local define = require 'proto.define' -local await = require 'await' +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local searcher = require 'core.searcher' +local config = require 'config' +local define = require 'proto.define' +local await = require 'await' return function (uri, callback) local ast = files.getAst(uri) @@ -12,7 +12,7 @@ return function (uri, callback) return end - guide.eachSource(ast.ast, function (src) + searcher.eachSource(ast.ast, function (src) if src.type ~= 'getglobal' and src.type ~= 'getfield' and src.type ~= 'getindex' @@ -20,7 +20,7 @@ return function (uri, callback) return end if src.type == 'getglobal' then - local key = guide.getKeyName(src) + local key = searcher.getKeyName(src) if not key then return end diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua index 8c6696a9..c625d234 100644 --- a/script/core/diagnostics/duplicate-doc-class.lua +++ b/script/core/diagnostics/duplicate-doc-class.lua @@ -1,8 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local vm = require 'vm' return function (uri, callback) local state = files.getAst(uri) @@ -18,7 +17,7 @@ return function (uri, callback) for _, doc in ipairs(state.ast.docs) do if doc.type == 'doc.class' or doc.type == 'doc.alias' then - local name = guide.getKeyName(doc) + local name = searcher.getKeyName(doc) if not cache[name] then local docs = vm.getDocTypes(name) cache[name] = {} @@ -28,7 +27,7 @@ return function (uri, callback) cache[name][#cache[name]+1] = { start = otherDoc.start, finish = otherDoc.finish, - uri = guide.getUri(otherDoc), + uri = searcher.getUri(otherDoc), } end end diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua index 5e63d39e..65fb00cd 100644 --- a/script/core/diagnostics/duplicate-index.lua +++ b/script/core/diagnostics/duplicate-index.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'core.guide' -local lang = require 'language' -local define = require 'proto.define' -local vm = require 'vm' +local files = require 'files' +local searcher = require 'core.searcher' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' return function (uri, callback) local ast = files.getAst(uri) @@ -10,7 +10,7 @@ return function (uri, callback) return end - guide.eachSourceType(ast.ast, 'table', function (source) + searcher.eachSourceType(ast.ast, 'table', function (source) local mark = {} for _, obj in ipairs(source) do if obj.type == 'tablefield' diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua index c1e2285a..ff915217 100644 --- a/script/core/diagnostics/duplicate-set-field.lua +++ b/script/core/diagnostics/duplicate-set-field.lua @@ -1,8 +1,7 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' -local vm = require 'vm' return function (uri, callback) local ast = files.getAst(uri) @@ -10,7 +9,7 @@ return function (uri, callback) return end - guide.eachSourceType(ast.ast, 'local', function (source) + searcher.eachSourceType(ast.ast, 'local', function (source) if not source.ref then return end @@ -26,11 +25,11 @@ return function (uri, callback) if nxt.type == 'setfield' or nxt.type == 'setmethod' or nxt.type == 'setindex' then - local name = guide.getKeyName(nxt) + local name = searcher.getKeyName(nxt) if not name then goto CONTINUE end - local value = guide.getObjectValue(nxt) + local value = searcher.getObjectValue(nxt) if not value or value.type ~= 'function' then goto CONTINUE end @@ -47,7 +46,7 @@ return function (uri, callback) end local blocks = {} for _, value in ipairs(values) do - local block = guide.getBlock(value) + local block = searcher.getBlock(value) if not blocks[block] then blocks[block] = {} end diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua index 690a4ca2..abd20bde 100644 --- a/script/core/diagnostics/empty-block.lua +++ b/script/core/diagnostics/empty-block.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' @@ -11,7 +11,7 @@ return function (uri, callback) return end - guide.eachSourceType(ast.ast, 'if', function (source) + searcher.eachSourceType(ast.ast, 'if', function (source) for _, block in ipairs(source) do if #block > 0 then return @@ -24,7 +24,7 @@ return function (uri, callback) message = lang.script.DIAG_EMPTY_BLOCK, } end) - guide.eachSourceType(ast.ast, 'loop', function (source) + searcher.eachSourceType(ast.ast, 'loop', function (source) if #source > 0 then return end @@ -35,7 +35,7 @@ return function (uri, callback) message = lang.script.DIAG_EMPTY_BLOCK, } end) - guide.eachSourceType(ast.ast, 'in', function (source) + searcher.eachSourceType(ast.ast, 'in', function (source) if #source > 0 then return end diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index de23bc76..02cd4f3f 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' -- TODO: 检查路径是否可达 @@ -12,8 +12,8 @@ return function (uri, callback) if not ast then return end - local root = guide.getRoot(ast.ast) - local env = guide.getENV(root) + local root = searcher.getRoot(ast.ast) + local env = searcher.getENV(root) local nilDefs = {} if not env.ref then @@ -36,7 +36,7 @@ return function (uri, callback) if node.tag == '_ENV' then local ok for _, nilDef in ipairs(nilDefs) do - local mode, pathA = guide.getPath(nilDef, source) + local mode, pathA = searcher.getPath(nilDef, source) if mode == 'before' and mayRun(pathA) then ok = nilDef @@ -61,6 +61,6 @@ return function (uri, callback) end end - guide.eachSourceType(ast.ast, 'getglobal', check) - guide.eachSourceType(ast.ast, 'setglobal', check) + searcher.eachSourceType(ast.ast, 'getglobal', check) + searcher.eachSourceType(ast.ast, 'setglobal', check) end diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua index 9c094701..aaeb2c94 100644 --- a/script/core/diagnostics/lowercase-global.lua +++ b/script/core/diagnostics/lowercase-global.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local config = require 'config' local vm = require 'vm' @@ -28,8 +28,8 @@ return function (uri, callback) definedGlobal[name] = true end - guide.eachSourceType(ast.ast, 'setglobal', function (source) - local name = guide.getKeyName(source) + searcher.eachSourceType(ast.ast, 'setglobal', function (source) + local name = searcher.getKeyName(source) if definedGlobal[name] then return end diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua index 0727c2fd..a71ae3e1 100644 --- a/script/core/diagnostics/newfield-call.lua +++ b/script/core/diagnostics/newfield-call.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' return function (uri, callback) @@ -11,15 +11,15 @@ return function (uri, callback) local lines = files.getLines(uri) local text = files.getText(uri) - guide.eachSourceType(ast.ast, 'table', function (source) + searcher.eachSourceType(ast.ast, 'table', function (source) for i = 1, #source do local field = source[i] if field.type == 'call' then local func = field.node local args = field.args if args then - local funcLine = guide.positionOf(lines, func.finish) - local argsLine = guide.positionOf(lines, args.start) + local funcLine = searcher.positionOf(lines, func.finish) + local argsLine = searcher.positionOf(lines, args.start) if argsLine > funcLine then callback { start = field.start, diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua index 807f76a2..31a8d09f 100644 --- a/script/core/diagnostics/newline-call.lua +++ b/script/core/diagnostics/newline-call.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' return function (uri, callback) @@ -10,7 +10,7 @@ return function (uri, callback) return end - guide.eachSourceType(ast.ast, 'call', function (source) + searcher.eachSourceType(ast.ast, 'call', function (source) local node = source.node local args = source.args if not args then @@ -26,8 +26,8 @@ return function (uri, callback) return end - local nodeRow = guide.positionOf(lines, node.finish) - local argRow = guide.positionOf(lines, args.start) + local nodeRow = searcher.positionOf(lines, node.finish) + local argRow = searcher.positionOf(lines, args.start) if nodeRow == argRow then return end diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua index ffaab821..23af570a 100644 --- a/script/core/diagnostics/no-implicit-any.lua +++ b/script/core/diagnostics/no-implicit-any.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -10,7 +10,7 @@ return function (uri, callback) return end - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) if source.type ~= 'local' and source.type ~= 'setlocal' and source.type ~= 'setglobal' diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua index 857d80d2..4922831b 100644 --- a/script/core/diagnostics/redefined-local.lua +++ b/script/core/diagnostics/redefined-local.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' return function (uri, callback) @@ -7,13 +7,13 @@ return function (uri, callback) if not ast then return end - guide.eachSourceType(ast.ast, 'local', function (source) + searcher.eachSourceType(ast.ast, 'local', function (source) local name = source[1] if name == '_' or name == ast.ENVMode then return end - local exist = guide.getLocal(source, name, source.start-1) + local exist = searcher.getLocal(source, name, source.start-1) if exist then callback { start = source.start, diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index c5bcd5a5..a6907bda 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local lang = require 'language' local define = require 'proto.define' @@ -74,7 +74,7 @@ return function (uri, callback) local cache = vm.getCache 'redundant-parameter' - guide.eachSourceType(ast.ast, 'call', function (source) + searcher.eachSourceType(ast.ast, 'call', function (source) local callArgs = countCallArgs(source) if callArgs == 0 then return diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua index 0a4b1d57..6ee92d26 100644 --- a/script/core/diagnostics/trailing-space.lua +++ b/script/core/diagnostics/trailing-space.lua @@ -1,10 +1,10 @@ local files = require 'files' local lang = require 'language' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local function isInString(ast, offset) local result = false - guide.eachSourceType(ast, 'string', function (source) + searcher.eachSourceType(ast, 'string', function (source) if offset >= source.start and offset <= source.finish then result = true end diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua index b2b2800c..006223db 100644 --- a/script/core/diagnostics/unbalanced-assignments.lua +++ b/script/core/diagnostics/unbalanced-assignments.lua @@ -1,7 +1,7 @@ local files = require 'files' local define = require 'proto.define' local lang = require 'language' -local guide = require 'core.guide' +local searcher = require 'core.searcher' return function (uri, callback, code) local ast = files.getAst(uri) @@ -31,7 +31,7 @@ return function (uri, callback, code) end end - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) if source.type == 'local' or source.type == 'setlocal' or source.type == 'setglobal' diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua index a91cfa7f..991b5849 100644 --- a/script/core/diagnostics/undefined-doc-class.lua +++ b/script/core/diagnostics/undefined-doc-class.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index d8a4363b..54f08ee6 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' @@ -44,7 +44,7 @@ return function (uri, callback) return true end - guide.eachSource(state.ast.docs, function (source) + searcher.eachSource(state.ast.docs, function (source) if source.type ~= 'doc.extends.name' and source.type ~= 'doc.type.name' then return diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua index 0bf371e5..4a97947d 100644 --- a/script/core/diagnostics/undefined-doc-param.lua +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local define = require 'proto.define' local vm = require 'vm' diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua index 89efb8c7..a32ad306 100644 --- a/script/core/diagnostics/undefined-env-child.lua +++ b/script/core/diagnostics/undefined-env-child.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local lang = require 'language' @@ -8,12 +8,12 @@ return function (uri, callback) if not ast then return end - guide.eachSourceType(ast.ast, 'getglobal', function (source) + searcher.eachSourceType(ast.ast, 'getglobal', function (source) -- 单独验证自己是否在重载过的 _ENV 中有定义 if source.node.tag == '_ENV' then return end - local defs = guide.requestDefinition(source) + local defs = searcher.requestDefinition(source) if #defs > 0 then return end diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index b10c9ab0..1f88740e 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -2,7 +2,7 @@ local files = require 'files' local vm = require 'vm' local lang = require 'language' local config = require 'config' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local define = require 'proto.define' return function (uri, callback) @@ -87,7 +87,7 @@ return function (uri, callback) end local function checkUndefinedField(src) - local fieldName = guide.getKeyName(src) + local fieldName = searcher.getKeyName(src) local allDocClass = getAllDocClassFromInfer(src.node) if (not allDocClass) or (#allDocClass == 0) then @@ -118,6 +118,6 @@ return function (uri, callback) end end end - guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField); - guide.eachSourceType(ast.ast, 'getmethod', checkUndefinedField); + searcher.eachSourceType(ast.ast, 'getfield', checkUndefinedField); + searcher.eachSourceType(ast.ast, 'getmethod', checkUndefinedField); end diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index 161d8856..3c7f02d1 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -2,7 +2,7 @@ local files = require 'files' local vm = require 'vm' local lang = require 'language' local config = require 'config' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local define = require 'proto.define' local requireLike = { @@ -19,8 +19,8 @@ return function (uri, callback) end -- 遍历全局变量,检查所有没有 set 模式的全局变量 - guide.eachSourceType(ast.ast, 'getglobal', function (src) - local key = guide.getKeyName(src) + searcher.eachSourceType(ast.ast, 'getglobal', function (src) + local key = searcher.getKeyName(src) if not key then return end diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua index b6f92e60..2d224e5e 100644 --- a/script/core/diagnostics/unused-function.lua +++ b/script/core/diagnostics/unused-function.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local define = require 'proto.define' local lang = require 'language' @@ -45,7 +45,7 @@ return function (uri, callback) local refs = vm.getRefs(source) for _, src in ipairs(refs) do if vm.isGet(src) then - local func = guide.getParentFunction(src) + local func = searcher.getParentFunction(src) if not checkFunction(func) then hasGet = true break @@ -75,7 +75,7 @@ return function (uri, callback) end -- 只检查局部函数 - guide.eachSourceType(ast.ast, 'function', function (source) + searcher.eachSourceType(ast.ast, 'function', function (source) checkFunction(source) end) end diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua index e2d5e49a..5d9488a1 100644 --- a/script/core/diagnostics/unused-label.lua +++ b/script/core/diagnostics/unused-label.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local define = require 'proto.define' local lang = require 'language' @@ -9,7 +9,7 @@ return function (uri, callback) return end - guide.eachSourceType(ast.ast, 'label', function (source) + searcher.eachSourceType(ast.ast, 'label', function (source) if not source.ref then callback { start = source.start, diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua index fde90cb8..4e3c8217 100644 --- a/script/core/diagnostics/unused-local.lua +++ b/script/core/diagnostics/unused-local.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local define = require 'proto.define' local lang = require 'language' @@ -81,7 +81,7 @@ return function (uri, callback) if not ast then return end - guide.eachSourceType(ast.ast, 'local', function (source) + searcher.eachSourceType(ast.ast, 'local', function (source) local name = source[1] if name == '_' or name == ast.ENVMode then diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua index ec0a05fb..301394c3 100644 --- a/script/core/diagnostics/unused-vararg.lua +++ b/script/core/diagnostics/unused-vararg.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local define = require 'proto.define' local lang = require 'language' @@ -9,7 +9,7 @@ return function (uri, callback) return end - guide.eachSourceType(ast.ast, 'function', function (source) + searcher.eachSourceType(ast.ast, 'function', function (source) local args = source.args if not args then return diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua index cc87e3ca..e36ba29b 100644 --- a/script/core/document-symbol.lua +++ b/script/core/document-symbol.lua @@ -1,8 +1,8 @@ -local await = require 'await' -local files = require 'files' -local guide = require 'core.guide' -local define = require 'proto.define' -local util = require 'utility' +local await = require 'await' +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local util = require 'utility' local function buildName(source, text) if source.type == 'setmethod' diff --git a/script/core/find-source.lua b/script/core/find-source.lua index b36306b6..edbb1e2c 100644 --- a/script/core/find-source.lua +++ b/script/core/find-source.lua @@ -1,4 +1,4 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local function isValidFunctionPos(source, offset) for i = 1, #source.keyword // 2 do diff --git a/script/core/folding.lua b/script/core/folding.lua index 15678995..1bbae944 100644 --- a/script/core/folding.lua +++ b/script/core/folding.lua @@ -1,5 +1,5 @@ local files = require "files" -local guide = require "core.guide" +local searcher = require "core.searcher" local util = require 'utility' local Care = { @@ -153,7 +153,7 @@ return function (uri) local regions = {} local status = {} - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) local tp = source.type if Care[tp] then Care[tp](source, text, regions) diff --git a/script/core/generic.lua b/script/core/generic.lua new file mode 100644 index 00000000..53ced59c --- /dev/null +++ b/script/core/generic.lua @@ -0,0 +1,220 @@ +local linker = require "core.linker" +---@class generic.value +---@field type string +---@field closure generic.closure +---@field proto parser.guide.object +---@field parent parser.guide.object + +---@class generic.closure +---@field type string +---@field proto parser.guide.object +---@field upvalues table<string, generic.value[]> +---@field params generic.value[] +---@field returns generic.value[] + +local m = {} + +---@param closure generic.closure +---@param proto parser.guide.object +local function instantValue(closure, proto) + ---@type generic.value + local value = { + type = 'generic.value', + closure = closure, + proto = proto, + parent = proto.parent, + } + return value +end + +---递归实例化对象 +---@param proto parser.guide.object +---@return generic.value +local function createValue(closure, proto, callback, road) + if callback then + road = road or {} + end + if proto.type == 'doc.type' then + local types = {} + local hasGeneric + for i, tp in ipairs(proto.types) do + local genericValue = createValue(closure, tp, callback, road) + if genericValue then + hasGeneric = true + types[i] = genericValue + else + types[i] = tp + end + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.types = types + linker.compileLink(value) + return value + end + if proto.type == 'doc.type.name' then + if not proto.typeGeneric then + return nil + end + local key = proto[1] + local value = instantValue(closure, proto) + if callback then + callback(road, key, proto) + end + linker.compileLink(value) + return value + end + if proto.type == 'doc.type.function' then + local hasGeneric + local args = {} + local returns = {} + for i, arg in ipairs(proto.args) do + local value = createValue(closure, arg, callback, road) + if value then + hasGeneric = true + end + args[i] = value or arg + end + for i, rtn in ipairs(proto.returns) do + local value = createValue(closure, rtn, callback, road) + if value then + hasGeneric = true + end + returns[i] = value or rtn + end + if not hasGeneric then + return nil + end + local value = instantValue(closure, proto) + value.args = args + value.returns = returns + value.isGeneric = true + linker.compileLink(value) + linker.pushSource(value) + return value + end + if proto.type == 'doc.type.array' then + if road then + road[#road+1] = linker.SPLIT_CHAR + end + local node = createValue(closure, proto.node, callback, road) + if road then + road[#road] = nil + end + if not node then + return nil + end + local value = instantValue(closure, proto) + value.node = node + return value + end + if proto.type == 'doc.type.table' then + local tkey = createValue(closure, proto.key, callback, road) + road[#road+1] = linker.SPLIT_CHAR + local tvalue = createValue(closure, proto.value, callback, road) + road[#road] = nil + if not tkey and not tvalue then + return nil + end + local value = instantValue(closure, proto) + value.key = tkey or proto.key + value.value = tvalue or proto.value + return value + end +end + +local function buildValue(road, key, proto, param, upvalues) + local paramID + if proto.literal then + local str = param.type == 'string' and param[1] + if not str then + return + end + paramID = 'dn:' .. str + else + paramID = linker.getID(param) + end + if not paramID then + return + end + if not upvalues[key] then + upvalues[key] = {} + end + upvalues[key][#upvalues[key]+1] = paramID .. table.concat(road) +end + +-- 为所有的 param 与 return 创建副本 +---@param closure generic.closure +local function buildValues(closure) + local protoFunction = closure.proto + local upvalues = closure.upvalues + local params = closure.call.args + + if protoFunction.type == 'function' then + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.param' then + local extends = doc.extends + local index = extends.paramIndex + local param = params and params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + end + for _, doc in ipairs(protoFunction.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + closure.returns[rtn.returnIndex] = createValue(closure, rtn) or rtn + end + end + end + end + if protoFunction.type == 'doc.type.function' then + for index, arg in ipairs(protoFunction.args) do + local extends = arg.extends + local param = params[index] + closure.params[index] = param and createValue(closure, extends, function (road, key, proto) + buildValue(road, key, proto, param, upvalues) + end) or extends + end + for index, rtn in ipairs(protoFunction.returns) do + closure.returns[index] = createValue(closure, rtn) or rtn + end + end +end + +---创建一个闭包 +---@param proto parser.guide.object|generic.value # 原型函数|泛型值 +---@return generic.closure +function m.createClosure(proto, call) + local protoFunction, parentClosure + if proto.type == 'function' then + protoFunction = proto + elseif proto.type == 'generic.value' then + protoFunction = proto.proto + parentClosure = proto.closure + end + ---@type generic.closure + local closure = { + type = 'generic.closure', + parent = protoFunction.parent, + proto = protoFunction, + call = call, + upvalues = parentClosure and parentClosure.upvalues or {}, + params = {}, + returns = {}, + } + buildValues(closure) + + if #closure.returns == 0 then + return nil + end + + linker.compileLink(closure) + + return closure +end + +return m diff --git a/script/core/guide.lua b/script/core/guide2.lua index e4871060..576c0c20 100644 --- a/script/core/guide.lua +++ b/script/core/guide2.lua @@ -292,8 +292,13 @@ end ---@param obj parser.guide.object ---@return parser.guide.object function m.getRoot(obj) + local source = obj + if source._root then + return source._root + end for _ = 1, 1000 do if obj.type == 'main' then + source._root = obj return obj end local parent = obj.parent diff --git a/script/core/highlight.lua b/script/core/highlight.lua index 12ec114f..b070c77e 100644 --- a/script/core/highlight.lua +++ b/script/core/highlight.lua @@ -1,4 +1,4 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local files = require 'files' local vm = require 'vm' local define = require 'proto.define' @@ -6,7 +6,7 @@ local findSource = require 'core.find-source' local util = require 'utility' local function eachRef(source, callback) - local results = guide.requestReference(source) + local results = searcher.requestReference(source) for i = 1, #results do callback(results[i]) end @@ -16,11 +16,11 @@ local function eachField(source, callback) if not source then return end - local isGlobal = guide.isGlobal(source) - local results = guide.requestReference(source) + local isGlobal = searcher.isGlobal(source) + local results = searcher.requestReference(source) for i = 1, #results do local res = results[i] - if isGlobal == guide.isGlobal(res) then + if isGlobal == searcher.isGlobal(res) then callback(res) end end @@ -107,7 +107,7 @@ local function makeIf(source, text, callback) end local function findKeyWord(ast, text, offset, callback) - guide.eachSourceContain(ast.ast, offset, function (source) + searcher.eachSourceContain(ast.ast, offset, function (source) if source.type == 'do' or source.type == 'function' or source.type == 'loop' diff --git a/script/core/hint.lua b/script/core/hint.lua index 13d01dc7..9c0d9cf0 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -1,7 +1,7 @@ -local files = require 'files' -local guide = require 'core.guide' -local vm = require 'vm' -local config = require 'config' +local files = require 'files' +local searcher = require 'core.searcher' +local vm = require 'vm' +local config = require 'config' local function typeHint(uri, edits, start, finish) local ast = files.getAst(uri) @@ -9,7 +9,7 @@ local function typeHint(uri, edits, start, finish) return end local mark = {} - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) if source.type ~= 'local' and source.type ~= 'setglobal' and source.type ~= 'tablefield' @@ -21,7 +21,7 @@ local function typeHint(uri, edits, start, finish) if source[1] == '_' then return end - if source.value and guide.isLiteral(source.value) then + if source.value and searcher.isLiteral(source.value) then return end if source.parent.type == 'funcargs' then @@ -84,7 +84,7 @@ local function hasLiteralArgInCall(call) return false end for _, arg in ipairs(call.args) do - if guide.isLiteral(arg) then + if searcher.isLiteral(arg) then return true end end @@ -100,7 +100,7 @@ local function paramName(uri, edits, start, finish) return end local mark = {} - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) if source.type ~= 'call' then return end @@ -130,7 +130,7 @@ local function paramName(uri, edits, start, finish) table.remove(args, 1) end for i, arg in ipairs(source.args) do - if not mark[arg] and guide.isLiteral(arg) then + if not mark[arg] and searcher.isLiteral(arg) then mark[arg] = true if args[i] and args[i] ~= '' then edits[#edits+1] = { diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua index 324d28af..b8c9eba0 100644 --- a/script/core/hover/arg.lua +++ b/script/core/hover/arg.lua @@ -1,4 +1,4 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local function optionalArg(arg) @@ -29,7 +29,7 @@ local function asFunction(source, oop) if arg.dummy then goto CONTINUE end - local name = arg.name or guide.getKeyName(arg) + local name = arg.name or searcher.getKeyName(arg) if name then args[#args+1] = ('%s%s: %s'):format( name, diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index 401ca5a7..85224c66 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -2,7 +2,7 @@ local vm = require 'vm' local ws = require 'workspace' local furi = require 'file-uri' local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local markdown = require 'provider.markdown' local config = require 'config' local lang = require 'language' @@ -72,7 +72,7 @@ local function asStringView(source, literal) end local function asString(source) - local literal = guide.getLiteral(source) + local literal = searcher.getLiteral(source) if type(literal) ~= 'string' then return nil end @@ -127,7 +127,7 @@ local function tryDocClassComment(source) for _, def in ipairs(vm.getDefs(source, 0)) do if def.type == 'doc.class.name' or def.type == 'doc.alias.name' then - local class = guide.getDocState(def) + local class = searcher.getDocState(def) local comment = getBindComment(class, class.bindGroup, class) if comment then return comment @@ -180,7 +180,7 @@ local function isFunction(source) if source.type == 'function' then return true end - local value = guide.getObjectValue(source) + local value = searcher.getObjectValue(source) if not value then return false end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index 81285ef2..86c5b992 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local getLabel = require 'core.hover.label' local getDesc = require 'core.hover.description' @@ -48,7 +48,7 @@ local function getHoverAsFunction(source) local other = 0 local mark = {} for _, def in ipairs(values) do - def = guide.getObjectValue(def) or def + def = searcher.getObjectValue(def) or def if def.type == 'function' or def.type == 'doc.type.function' then eachFunctionAndOverload(def, function (value) diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index d93b14e3..da07200f 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -4,7 +4,7 @@ local buildReturn = require 'core.hover.return' local buildTable = require 'core.hover.table' local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local lang = require 'language' local config = require 'config' local files = require 'files' @@ -139,13 +139,13 @@ local function asDocField(source) if not class then return ('field ?.%s: %s'):format( name, - guide.viewInferType(infers) + searcher.viewInferType(infers) ) end return ('field %s.%s: %s'):format( class.class[1], name, - guide.viewInferType(infers) + searcher.viewInferType(infers) ) end @@ -177,7 +177,7 @@ local function asNumber(source) if type(num) ~= 'number' then return nil end - local uri = guide.getUri(source) + local uri = searcher.getUri(source) local text = files.getText(uri) if not text then return nil diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua index d583f1e1..fe0f2ffb 100644 --- a/script/core/hover/name.lua +++ b/script/core/hover/name.lua @@ -1,10 +1,10 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local buildName local function asLocal(source) - local name = guide.getKeyName(source) + local name = searcher.getKeyName(source) if not source.attrs then return name end @@ -21,8 +21,8 @@ local function asField(source, oop) if source.node.type ~= 'getglobal' then class = vm.getClass(source.node, 0) end - local node = class or guide.getKeyName(source.node) or '?' - local method = guide.getKeyName(source) + local node = class or searcher.getKeyName(source.node) or '?' + local method = searcher.getKeyName(source) if oop then return ('%s:%s'):format(node, method) else @@ -34,16 +34,16 @@ local function asTableField(source) if not source.field then return end - return guide.getKeyName(source.field) + return searcher.getKeyName(source.field) end local function asGlobal(source) - return guide.getKeyName(source) + return searcher.getKeyName(source) end local function asDocFunction(source) - local doc = guide.getParentType(source, 'doc.type') - or guide.getParentType(source, 'doc.overload') + local doc = searcher.getParentType(source, 'doc.type') + or searcher.getParentType(source, 'doc.overload') if not doc or not doc.bindSources then return '' end diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index c3e9656d..0825e77d 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,11 +1,11 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local function mergeTypes(returns) if type(returns) == 'string' then return returns end - return guide.mergeTypes(returns) + return searcher.mergeTypes(returns) end local function getReturnDualByDoc(source) diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index edb7751b..137c4f6b 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -1,6 +1,6 @@ local vm = require 'vm' local util = require 'utility' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local config = require 'config' local lang = require 'language' @@ -20,7 +20,7 @@ local function getKey(src) end return '[any]' end - if guide.getKeyType(src) == 'string' then + if searcher.getKeyType(src) == 'string' then if key:match '^[%a_][%w_]*$' then return key else @@ -31,7 +31,7 @@ local function getKey(src) end local function getFieldFull(src) - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src local tp = vm.getInferType(value, 0) --local class = vm.getClass(src) local literal = vm.getInferLiteral(value) @@ -45,7 +45,7 @@ local function getFieldFast(src) if src.bindDocs then return getFieldFull(src) end - local value = guide.getObjectValue(src) or src + local value = searcher.getObjectValue(src) or src if not value then return 'any' end @@ -193,7 +193,7 @@ local function mergeTypes(types) end end end - return guide.mergeTypes(results) + return searcher.mergeTypes(results) end local function clearClasses(classes) diff --git a/script/core/infer.lua b/script/core/infer.lua new file mode 100644 index 00000000..14ec6be2 --- /dev/null +++ b/script/core/infer.lua @@ -0,0 +1,385 @@ +local searcher = require 'core.searcher' +local config = require 'config' +local linker = require 'core.linker' + +local BE_LEN = {'#'} +local CLASS = {'CLASS'} + +local m = {} + +local function mergeTable(a, b) + if not b then + return + end + for v in pairs(b) do + a[v] = true + end +end + +local function searchInferOfUnary(value, infers) + local op = value.op.type + if op == 'not' then + infers['boolean'] = true + return + end + if op == '#' then + infers['integer'] = true + return + end + if op == '-' then + if m.hasType(value[1], 'integer') then + infers['integer'] = true + else + infers['number'] = true + end + return + end + if op == '~' then + infers['integer'] = true + return + end +end + +local function searchInferOfBinary(value, infers) + local op = value.op.type + if op == 'and' then + if m.isTrue(value[1]) then + mergeTable(infers, m.searchInfers(value[2])) + else + mergeTable(infers, m.searchInfers(value[1])) + end + return + end + if op == 'or' then + if m.isTrue(value[1]) then + mergeTable(infers, m.searchInfers(value[1])) + else + mergeTable(infers, m.searchInfers(value[2])) + end + return + end + if op == '==' + or op == '~=' + or op == '<' + or op == '>' + or op == '<=' + or op == '>=' then + infers['boolean'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '&' + or op == '|' then + infers['integer'] = true + return + end + if op == '..' then + infers['string'] = true + return + end + if op == '^' + or op == '/' then + infers['number'] = true + return + end + if op == '+' + or op == '-' + or op == '*' + or op == '%' + or op == '//' then + if m.hasType(value[1], 'integer') + and m.hasType(value[2], 'integer') then + infers['integer'] = true + else + infers['number'] = true + end + return + end +end + +local function searchInferOfValue(value, infers) + if value.type == 'string' then + infers['string'] = true + return true + end + if value.type == 'boolean' then + infers['boolean'] = true + return true + end + if value.type == 'table' then + infers['table'] = true + return true + end + if value.type == 'number' then + if math.type(value[1]) == 'integer' then + infers['integer'] = true + else + infers['number'] = true + end + return true + end + if value.type == 'nil' then + infers['nil'] = true + return true + end + if value.type == 'function' then + infers['function'] = true + return true + end + if value.type == 'unary' then + searchInferOfUnary(value, infers) + return true + end + if value.type == 'binary' then + searchInferOfBinary(value, infers) + return true + end + return false +end + +local function searchLiteralOfValue(value, literals) + if value.type == 'string' + or value.type == 'boolean' + or value.tyoe == 'number' + or value.type == 'integer' then + local v = value[1] + if v ~= nil then + literals[v] = true + end + return + end + if value.type == 'unary' then + local op = value.op.type + if op == '-' then + end + if op == '~' then + end + end + return +end + +local function bindClassOrType(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' + or doc.type == 'doc.type' then + return true + end + end + return false +end + +local function cleanInfers(infers) + local version = config.config.runtime.version + local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4' + if infers['number'] then + enableInteger = false + end + if not enableInteger and infers['integer'] then + infers['integer'] = nil + infers['number'] = true + end + if infers[BE_LEN] then + infers[BE_LEN] = nil + if not infers['table'] and not infers['string'] then + infers['table'] = true + infers['string'] = true + end + end + if infers[CLASS] then + infers[CLASS] = nil + infers['table'] = nil + end +end + +---合并对象的推断类型 +---@param infers string[] +---@return string +function m.viewInfers(infers) + if infers[0] then + return infers[0] + end + -- 如果有显性的 any ,则直接显示为 any + if infers['any'] then + infers[0] = 'any' + return 'any' + end + local result = {} + local count = 0 + for infer in pairs(infers) do + count = count + 1 + result[count] = infer + end + -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any + if count == 0 then + infers[0] = 'any' + return 'any' + end + table.sort(infers) + infers[0] = table.concat(result, '|') + return infers[0] +end + +---显示对象的推断类型 +---@param source parser.guide.object +---@return string +local function searchInfer(source, infers) + if bindClassOrType(source) then + return + end + if searchInferOfValue(source, infers) then + return + end + local value = searcher.getObjectValue(source) + if value then + searchInferOfValue(value, infers) + return + end + if source.type == 'doc.class.name' then + local name = source[1] + if name then + infers[name] = true + infers[CLASS] = true + end + return + end + -- X.a -> table + if source.next and source.next.node == source then + if source.next.type == 'setfield' + or source.next.type == 'setindex' + or source.next.type == 'setmethod' then + infers['table'] = true + end + return + end + -- return XX + if source.parent.type == 'return' then + infers['any'] = true + return + end + if source.parent.type == 'unary' then + local op = source.parent.op.type + -- # XX -> string | table + if op == '#' then + infers[BE_LEN] = true + return + end + if op == '-' then + infers['number'] = true + return + end + if op == '~' then + infers['integer'] = true + return + end + return + end + if source.parent.type == 'binary' then + local op = source.parent.op.type + if op == '+' + or op == '-' + or op == '*' + or op == '/' + or op == '//' + or op == '^' + or op == '%' then + infers['number'] = true + return + end + if op == '<<' + or op == '>>' + or op == '~' + or op == '|' + or op == '&' then + infers['integer'] = true + return + end + return + end +end + +local function searchLiteral(source, literals) + local value = searcher.getObjectValue(source) + if value then + searchLiteralOfValue(value, literals) + return + end +end + +---搜索对象的推断类型 +---@param source parser.guide.object +---@return string[] +function m.searchInfers(source) + if not source then + return nil + end + local defs = searcher.requestDefinition(source) + local infers = {} + searchInfer(source, infers) + for _, def in ipairs(defs) do + searchInfer(def, infers) + end + local id = linker.getID(source) + if id then + local link = linker.getLinkByID(source, id) + if link and link.sources then + for _, src in ipairs(link.sources) do + searchInfer(src, infers) + end + end + end + cleanInfers(infers) + return infers +end + +---搜索对象的字面量值 +---@param source parser.guide.object +---@return table +function m.searchLiterals(source) + local defs = searcher.requestDefinition(source) + local literals = {} + searchLiteral(source, literals) + for _, def in ipairs(defs) do + searchLiteral(def, literals) + end + return literals +end + +---判断对象的推断值是否是 true +---@param source parser.guide.object +function m.isTrue(source) + if not source then + return false + end + local literals = m.searchLiterals(source) + for literal in pairs(literals) do + if literal ~= false then + return true + end + end + return false +end + +---判断对象的推断类型是否包含某个类型 +function m.hasType(source, tp) + local infers = m.searchInfers(source) + return infers[tp] +end + +---搜索并显示推断类型 +---@param source parser.guide.object +---@return string +function m.searchAndViewInfers(source) + if not source then + return 'any' + end + local infers = m.searchInfers(source) + local view = m.viewInfers(infers) + return view +end + +return m diff --git a/script/core/keyword.lua b/script/core/keyword.lua index 71ea4969..538936f0 100644 --- a/script/core/keyword.lua +++ b/script/core/keyword.lua @@ -1,5 +1,5 @@ local define = require 'proto.define' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local files = require 'files' local keyWordMap = { @@ -24,7 +24,7 @@ end", end return true end, function (info) - return guide.eachSourceContain(info.ast.ast, info.start, function (source) + return searcher.eachSourceContain(info.ast.ast, info.start, function (source) if source.type == 'while' or source.type == 'in' or source.type == 'loop' then @@ -275,8 +275,8 @@ until $1" if first == 'end' or first == 'else' or first == 'elseif' then - local startRow = guide.positionOf(lines, info.start) - local finishRow = guide.positionOf(lines, pos) + local startRow = searcher.positionOf(lines, info.start) + local finishRow = searcher.positionOf(lines, pos) local startSp = info.text:match('^%s*', lines[startRow].start) local finishSp = info.text:match('^%s*', lines[finishRow].start) if startSp == finishSp then diff --git a/script/core/linker.lua b/script/core/linker.lua new file mode 100644 index 00000000..d9f3630a --- /dev/null +++ b/script/core/linker.lua @@ -0,0 +1,742 @@ +local util = require 'utility' +local guide = require 'parser.guide' + +local Linkers +local LastIDCache = {} +local FirstIDCache = {} +local SPLIT_CHAR = '\x1F' +local LAST_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']*$' +local FIRST_REGEX = '^[^' .. SPLIT_CHAR .. ']*' +local RETURN_INDEX_CHAR = '#' +local PARAM_INDEX_CHAR = '@' + +---创建source的链接信息 +---@param id string +---@return link +local function getLink(id) + if not Linkers[id] then + Linkers[id] = { + id = id, + } + end + return Linkers[id] +end + +---是否是全局变量(包括 _G.XXX 形式) +---@param source parser.guide.object +---@return boolean +local function isGlobal(source) + if source.type == 'setglobal' + or source.type == 'getglobal' then + if source.node and source.node.tag == '_ENV' then + return true + end + end + if source.type == 'field' then + source = source.parent + end + if source.special == '_G' then + return true + end + return false +end + +---获取语法树单元的key +---@param source parser.guide.object +---@return string? key +---@return parser.guide.object? node +local function getKey(source) + if source.type == 'local' then + return tostring(source.start), nil + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return tostring(source.node.start), nil + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + return ('%q'):format(source[1] or ''), nil + elseif source.type == 'getfield' + or source.type == 'setfield' then + return ('%q'):format(source.field and source.field[1] or ''), source.node + elseif source.type == 'tablefield' then + return ('%q'):format(source.field and source.field[1] or ''), source.parent + elseif source.type == 'getmethod' + or source.type == 'setmethod' then + return ('%q'):format(source.method and source.method[1] or ''), source.node + elseif source.type == 'setindex' + or source.type == 'getindex' then + local index = source.index + if not index then + return '', source.node + end + if index.type == 'string' then + return ('%q'):format(index[1] or ''), source.node + else + return '', source.node + end + elseif source.type == 'tableindex' then + local index = source.index + if not index then + return '', source.parent + end + if index.type == 'string' then + return ('%q'):format(index[1] or ''), source.parent + else + return '', source.parent + end + elseif source.type == 'table' then + return source.start, nil + elseif source.type == 'label' then + return source.start, nil + elseif source.type == 'goto' then + if source.node then + return source.node.start, nil + end + return nil, nil + elseif source.type == 'function' then + return source.start, nil + elseif source.type == 'string' then + return '', nil + elseif source.type == 'integer' + or source.type == 'number' + or source.type == 'boolean' + or source.type == 'nil' then + return source.start, nil + elseif source.type == '...' then + return source.start, nil + elseif source.type == 'select' then + return ('%d%s%s%d'):format(source.start, SPLIT_CHAR, RETURN_INDEX_CHAR, source.index) + elseif source.type == 'call' then + local node = source.node + if node.special == 'rawget' + or node.special == 'rawset' then + if not source.args then + return nil, nil + end + local tbl, key = source.args[1], source.args[2] + if not tbl or not key then + return nil, nil + end + if key.type == 'string' then + return ('%q'):format(key[1] or ''), tbl + else + return '', tbl + end + end + return source.start, nil + elseif source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.see.name' then + local name = source[1] + return name, nil + elseif source.type == 'doc.type.name' then + if source.typeGeneric then + return source.start, nil + else + local name = source[1] + return name, nil + end + elseif source.type == 'doc.class' + or source.type == 'doc.type' + or source.type == 'doc.alias' + or source.type == 'doc.param' + or source.type == 'doc.vararg' + or source.type == 'doc.field.name' + or source.type == 'doc.type.table' + or source.type == 'doc.type.array' + or source.type == 'doc.type.function' then + return source.start, nil + elseif source.type == 'doc.see.field' then + return ('%q'):format(source[1]), source.parent.name + elseif source.type == 'generic.closure' then + return source.call.start, nil + elseif source.type == 'generic.value' then + return ('%s|%s'):format( + source.closure.call.start, + getKey(source.proto) + ) + end + return nil, nil +end + +local function checkMode(source) + if source.type == 'table' then + return 't:' + end + if source.type == 'select' then + return 's:' + end + if source.type == 'function' then + return 'f:' + end + if source.type == 'string' then + return 'str:' + end + if source.type == 'number' + or source.type == 'integer' + or source.type == 'boolean' + or source.type == 'nil' then + return 'l:' + end + if source.type == 'call' then + return 'c:' + end + if source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' then + return 'dn:' + end + if source.type == 'doc.field.name' then + return 'dfn:' + end + if source.type == 'doc.see.name' then + return 'dsn:' + end + if source.type == 'doc.class' then + return 'dc:' + end + if source.type == 'doc.type' then + return 'dt:' + end + if source.type == 'doc.param' then + return 'dp:' + end + if source.type == 'doc.alias' then + return 'da:' + end + if source.type == 'doc.type.function' then + return 'dfun:' + end + if source.type == 'doc.type.table' then + return 'dtable:' + end + if source.type == 'doc.type.array' then + return 'darray:' + end + if source.type == 'doc.vararg' then + return 'dv:' + end + if source.type == 'generic.closure' then + return 'gc:' + end + if source.type == 'generic.value' then + return 'gv:' + end + if isGlobal(source) then + return 'g:' + end + return 'l:' +end + +local IDList = {} +---获取语法树单元的字符串ID +---@param source parser.guide.object +---@return string? id +local function getID(source) + if not source then + return nil + end + if source._id ~= nil then + return source._id or nil + end + if source.type == 'field' + or source.type == 'method' then + source._id = false + return nil + end + local current = source + local index = 0 + while true do + if current.type == 'paren' then + current = current.exp + goto CONTINUE + end + local id, node = getKey(current) + if not id then + break + end + index = index + 1 + IDList[index] = id + if not node then + break + end + current = node + if current.special == '_G' then + break + end + ::CONTINUE:: + end + if index == 0 then + source._id = false + return nil + end + for i = index + 1, #IDList do + IDList[i] = nil + end + local mode = checkMode(current) + if not mode then + source._id = false + return nil + end + util.revertTable(IDList) + local id = mode .. table.concat(IDList, SPLIT_CHAR) + source._id = id + return id +end + +---添加关联的前进ID +---@param id string +---@param forwardID string +local function pushForward(id, forwardID) + if not id + or not forwardID + or forwardID == '' + or id == forwardID then + return + end + local link = getLink(id) + if not link.forward then + link.forward = {} + end + link.forward[#link.forward+1] = forwardID +end + +---添加关联的后退ID +---@param id string +---@param backwardID string +local function pushBackward(id, backwardID) + if not id + or not backwardID + or backwardID == '' + or id == backwardID then + return + end + local link = getLink(id) + if not link.backward then + link.backward = {} + end + link.backward[#link.backward+1] = backwardID +end + +---@class link +-- 当前节点的id +---@field id string +-- 使用该ID的单元 +---@field sources parser.guide.object[] +-- 前进的关联ID +---@field forward string[] +-- 后退的关联ID +---@field backward string[] +-- 函数调用参数信息(用于泛型) +---@field call parser.guide.object + +local m = {} + +m.SPLIT_CHAR = SPLIT_CHAR +m.RETURN_INDEX_CHAR = RETURN_INDEX_CHAR +m.PARAM_INDEX_CHAR = PARAM_INDEX_CHAR + +---添加关联单元 +---@param source parser.guide.object +function m.pushSource(source) + local id = m.getID(source) + if not id then + return + end + local link = getLink(id) + if not link.sources then + link.sources = {} + end + link.sources[#link.sources+1] = source +end + +---@param source parser.guide.object +---@return parser.guide.object[] +function m.compileLink(source) + local id = getID(source) + local parent = source.parent + if not parent then + return + end + if source.value then + -- x = y : x -> y + pushForward(id, getID(source.value)) + pushBackward(getID(source.value), id) + end + -- self -> mt:xx + if source.type == 'local' and source[1] == 'self' then + local func = guide.getParentFunction(source) + local setmethod = func.parent + -- guess `self` + if setmethod and ( setmethod.type == 'setmethod' + or setmethod.type == 'setfield' + or setmethod.type == 'setindex') then + pushForward(id, getID(setmethod.node)) + pushBackward(getID(setmethod.node), id) + end + end + -- 分解 @type + if source.type == 'doc.type' then + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(getID(src), id) + pushForward(id, getID(src)) + end + end + for _, typeUnit in ipairs(source.types) do + pushForward(id, getID(typeUnit)) + pushBackward(getID(typeUnit), id) + end + end + -- 分解 @class + if source.type == 'doc.class' then + pushForward(id, getID(source.class)) + pushForward(getID(source.class), id) + if source.extends then + for _, ext in ipairs(source.extends) do + pushForward(id, getID(ext)) + pushBackward(getID(ext), id) + end + end + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(getID(src), id) + pushForward(id, getID(src)) + end + end + do + local start + for _, doc in ipairs(source.bindGroup) do + if doc.type == 'doc.class' then + start = doc == source + end + if start and doc.type == 'doc.field' then + local key = doc.field[1] + if key then + local keyID = ('%s%s%q'):format( + id, + SPLIT_CHAR, + key + ) + pushForward(keyID, getID(doc.field)) + pushBackward(getID(doc.field), keyID) + pushForward(keyID, getID(doc.extends)) + pushBackward(getID(doc.extends), keyID) + end + end + end + end + end + if source.type == 'doc.param' then + pushForward(getID(source), getID(source.extends)) + end + if source.type == 'doc.vararg' then + pushForward(getID(source), getID(source.vararg)) + end + if source.type == 'doc.see' then + local nameID = getID(source.name) + local classID = nameID:gsub('^dsn:', 'dn:') + pushForward(nameID, classID) + if source.field then + local fieldID = getID(source.field) + local fieldClassID = fieldID:gsub('^dsn:', 'dn:') + pushForward(fieldID, fieldClassID) + end + end + if source.type == 'call' then + local node = source.node + local nodeID = getID(node) + if not nodeID then + return + end + getLink(id).call = source + -- 将 call 映射到 node#1 上 + local callID = ('%s%s%s%s'):format( + nodeID, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + 1 + ) + pushForward(id, callID) + -- 将setmetatable映射到 param1 以及 param2.__index 上 + if node.special == 'setmetatable' then + local tblID = getID(source.args and source.args[1]) + local metaID = getID(source.args and source.args[2]) + local indexID + if metaID then + indexID = ('%s%s%q'):format( + metaID, + SPLIT_CHAR, + '__index' + ) + end + pushForward(id, callID) + pushBackward(callID, id) + pushForward(callID, tblID) + pushForward(callID, indexID) + pushBackward(tblID, callID) + --pushBackward(indexID, callID) + end + end + if source.type == 'select' then + if source.vararg.type == 'call' then + local call = source.vararg + local node = call.node + local nodeID = getID(node) + if not nodeID then + return + end + -- 将call的返回值接收映射到函数返回值上 + local callXID = ('%s%s%s%s'):format( + nodeID, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + source.index + ) + pushForward(id, callXID) + pushBackward(callXID, id) + getLink(id).call = call + if node.special == 'pcall' + or node.special == 'xpcall' then + local index = source.index - 1 + if index <= 0 then + return + end + local funcID = call.args and getID(call.args[1]) + if not funcID then + return + end + local funcXID = ('%s%s%s%s'):format( + funcID, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + index + ) + pushForward(id, funcXID) + pushBackward(funcXID, id) + end + end + end + if source.type == 'doc.type.function' then + if source.returns then + for index, rtn in ipairs(source.returns) do + local returnID = ('%s%s%s%s'):format( + id, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + index + ) + pushForward(returnID, getID(rtn)) + end + end + end + if source.type == 'doc.type.table' then + if source.value then + local valueID = ('%s%s'):format( + id, + SPLIT_CHAR + ) + pushForward(valueID, getID(source.value)) + end + end + if source.type == 'doc.type.array' then + if source.node then + local nodeID = ('%s%s'):format( + id, + SPLIT_CHAR + ) + pushForward(nodeID, getID(source.node)) + end + end + -- 将函数的返回值映射到具体的返回值上 + if source.type == 'function' then + -- 检查实体返回值 + if source.returns then + local returns = {} + for _, rtn in ipairs(source.returns) do + for index, rtnObj in ipairs(rtn) do + if not returns[index] then + returns[index] = {} + end + returns[index][#returns[index]+1] = rtnObj + end + end + for index, rtnObjs in ipairs(returns) do + local returnID = ('%s%s%s%s'):format( + id, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + index + ) + for _, rtnObj in ipairs(rtnObjs) do + pushForward(returnID, getID(rtnObj)) + if rtnObj.type == 'function' + or rtnObj.type == 'call' then + pushBackward(getID(rtnObj), returnID) + end + end + end + end + -- 检查 luadoc + if source.bindDocs then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + local fullID = ('%s%s%s%s'):format( + id, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + rtn.returnIndex + ) + pushForward(fullID, getID(rtn)) + pushBackward(getID(rtn), fullID) + end + end + if doc.type == 'doc.param' then + local paramName = doc.param[1] + if source.docParamMap then + local paramIndex = source.docParamMap[paramName] + local param = source.args[paramIndex] + if param then + pushForward(getID(param), getID(doc)) + param.docParam = doc + end + end + end + if doc.type == 'doc.vararg' then + for _, param in ipairs(source.args) do + if param.type == '...' then + pushForward(getID(param), getID(doc)) + end + end + end + if doc.type == 'doc.generic' then + source.isGeneric = true + end + end + end + end + if source.type == 'generic.closure' then + for i, rtn in ipairs(source.returns) do + local closureID = ('%s%s%s%s'):format( + id, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + i + ) + local returnID = getID(rtn) + pushForward(closureID, returnID) + pushBackward(returnID, closureID) + end + end + if source.type == 'generic.value' then + local proto = source.proto + local closure = source.closure + local upvalues = closure.upvalues + if proto.type == 'doc.type.name' then + local key = proto[1] + if upvalues[key] then + for _, paramID in ipairs(upvalues[key]) do + pushForward(id, paramID) + pushBackward(paramID, id) + end + end + end + if proto.type == 'doc.type' then + for _, tp in ipairs(source.types) do + pushForward(id, getID(tp)) + pushBackward(getID(tp), id) + end + end + if proto.type == 'doc.type.array' then + local nodeID = ('%s%s'):format( + id, + SPLIT_CHAR + ) + pushForward(nodeID, getID(source.node)) + end + if proto.type == 'doc.type.table' then + if source.value then + local valueID = ('%s%s'):format( + id, + SPLIT_CHAR + ) + pushForward(valueID, getID(source.value)) + end + end + end +end + +---根据ID来获取所有的link +---@param root parser.guide.object +---@param id string +---@return link? +function m.getLinkByID(root, id) + root = guide.getRoot(root) + local linkers = root._linkers + if not linkers then + return nil + end + return linkers[id] +end + +---根据ID来获取第一个节点的ID +---@param id string +---@return string +function m.getFirstID(id) + if FirstIDCache[id] then + return FirstIDCache[id] or nil + end + local firstID, count = id:match(FIRST_REGEX) + if count == 0 then + FirstIDCache[id] = false + return nil + end + FirstIDCache[id] = firstID + return firstID +end + +---根据ID来获取上个节点的ID +---@param id string +---@return string +function m.getLastID(id) + if LastIDCache[id] then + return LastIDCache[id] or nil + end + local lastID, count = id:gsub(LAST_REGEX, '') + if count == 0 then + LastIDCache[id] = false + return nil + end + LastIDCache[id] = lastID + return lastID +end + +---获取source的ID +---@param source parser.guide.object +---@return string +function m.getID(source) + return getID(source) +end + +---编译整个文件的link +---@param source parser.guide.object +---@return table +function m.compileLinks(source) + local root = guide.getRoot(source) + if root._linkers then + return root._linkers + end + Linkers = {} + root._linkers = Linkers + guide.eachSource(root, function (src) + m.pushSource(src) + m.compileLink(src) + end) + -- Special rule: ('').XX -> stringlib.XX + pushForward('str:', 'dn:stringlib') + return Linkers +end + +return m diff --git a/script/core/reference.lua b/script/core/reference.lua index 7620b09e..efeb28b6 100644 --- a/script/core/reference.lua +++ b/script/core/reference.lua @@ -1,4 +1,5 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' +local guide = require 'parser.guide' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' @@ -6,8 +7,8 @@ local findSource = require 'core.find-source' local function sortResults(results) -- 先按照顺序排序 table.sort(results, function (a, b) - local u1 = guide.getUri(a.target) - local u2 = guide.getUri(b.target) + local u1 = searcher.getUri(a.target) + local u2 = searcher.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else @@ -19,7 +20,7 @@ local function sortResults(results) for i = #results, 1, -1 do local res = results[i].target local f = res.finish - local uri = guide.getUri(res) + local uri = searcher.getUri(res) if lf and f > lf and uri == lu then table.remove(results, i) else diff --git a/script/core/rename.lua b/script/core/rename.lua index da82b0a6..0851f191 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -1,6 +1,6 @@ local files = require 'files' local vm = require 'vm' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local proto = require 'proto' local define = require 'proto.define' local util = require 'utility' @@ -185,7 +185,7 @@ local function renameField(source, newname, callback) end callback(source, source.start, source.finish, newname) elseif parent.type == 'setmethod' then - local uri = guide.getUri(source) + local uri = searcher.getUri(source) local text = files.getText(uri) local func = parent.value -- function mt:name () end --> mt['newname'] = function (self) end @@ -284,7 +284,7 @@ local function ofFieldThen(key, src, newname, callback) end local function ofField(source, newname, callback) - local key = guide.getKeyName(source) + local key = searcher.getKeyName(source) local node if source.type == 'tablefield' or source.type == 'tableindex' then @@ -298,7 +298,7 @@ local function ofField(source, newname, callback) end local function ofGlobal(source, newname, callback) - local key = guide.getKeyName(source) + local key = searcher.getKeyName(source) for _, src in ipairs(vm.getRefs(source, 0)) do ofFieldThen(key, src, newname, callback) end @@ -325,7 +325,7 @@ end local function ofDocParamName(source, newname, callback) callback(source, source.start, source.finish, newname) - local doc = guide.getDocState(source) + local doc = searcher.getDocState(source) if doc.bindSources then for _, src in ipairs(doc.bindSources) do if src.type == 'local' @@ -452,7 +452,7 @@ function m.rename(uri, pos, newname) local mark = {} rename(source, newname, function (target, start, finish, text) - local turi = files.getOriginUri(guide.getUri(target)) + local turi = files.getOriginUri(searcher.getUri(target)) if not turi then return end diff --git a/script/core/searcher.lua b/script/core/searcher.lua new file mode 100644 index 00000000..c869f456 --- /dev/null +++ b/script/core/searcher.lua @@ -0,0 +1,394 @@ +local linker = require 'core.linker' +local guide = require 'parser.guide' +local files = require 'files' +local generic = require 'core.generic' + +local function checkFunctionReturn(source) + if source.parent + and source.parent.type == 'return' then + if source.parent.parent.type == 'main' then + return 0 + elseif source.parent.parent.type == 'function' then + for i = 1, #source.parent do + if source.parent[i] == source then + return i + end + end + end + end + return nil +end + +local m = {} + +---@alias guide.searchmode '"ref"'|'"def"'|'"field"' + +---添加结果 +---@param status guide.status +---@param mode guide.searchmode +---@param source parser.guide.object +function m.pushResult(status, mode, source) + if not source then + return + end + local results = status.results + local parent = source.parent + if mode == 'def' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'label' + or source.type == 'setfield' + or source.type == 'setmethod' + or source.type == 'setindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if linker.getID(source) ~= status.id then + results[#results+1] = source + end + end + elseif mode == 'ref' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'label' + or source.type == 'goto' + or source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' + or source.type == 'setindex' + or source.type == 'getindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if linker.getID(source) ~= status.id then + results[#results+1] = source + end + end + elseif mode == 'field' then + end +end + +---获取uri +---@param obj parser.guide.object +---@return uri +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = guide.getRoot(obj) + if root then + return root.uri + end + return '' +end + +-- TODO +function m.findGlobals(root) + linker.compileLinks(root) + -- TODO + return {} +end + +-- TODO +function m.isGlobal(source) + return false +end + +---@param obj parser.guide.object +---@return parser.guide.object? +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' + or obj.type == 'doc.type.table' + or obj.type == 'doc.type.arrary' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent and obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args and obj.args[3] + else + return obj + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +function m.searchRefsByID(status, uri, expect, mode) + local ast = files.getAst(uri) + if not ast then + return + end + local root = ast.ast + local searchStep + linker.compileLinks(root) + + status.id = expect + + local mark = status.mark + + local callStack = {} + + local function search(id, field) + local fieldLen + if field then + local _, len = field:gsub(linker.SPLIT_CHAR, '') + fieldLen = len + else + fieldLen = 0 + end + if mark[id] and ((mark[id] < fieldLen) or fieldLen == 0) then + return + end + mark[id] = fieldLen + searchStep(id, field) + end + + local function checkLastID(id, field) + local lastID = linker.getLastID(id) + if lastID then + local newField = id:sub(#lastID + 1) + if field then + newField = newField .. field + end + search(lastID, newField) + end + end + + local function searchID(id, field) + if not id then + return + end + if field then + id = id .. field + end + search(id, nil) + end + + local function searchFunction(id) + local link = linker.getLinkByID(root, id) + if not link or not link.sources then + return + end + local obj = link.sources[1] + if not obj or obj.type ~= 'function' then + return + end + local returnIndex = checkFunctionReturn(obj) + if not returnIndex then + return + end + local func = guide.getParentFunction(obj) + if not func or func.type ~= 'function' then + return + end + local parentID = linker.getID(func) + if not parentID then + return + end + search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex) + end + + local function isCallID(field) + if not field then + return false + end + if field:sub(1, 1) == linker.SPLIT_CHAR + and field:sub(2, 2) == linker.RETURN_INDEX_CHAR then + return true + end + return false + end + + local function findLastCall() + for i = #callStack, 1, -1 do + local call = callStack[i] + if call then + -- 标记此处的call失效,等待在堆栈平衡时弹出 + callStack[i] = false + return call + end + end + return nil + end + + local function checkGeneric(source, field) + if not source.isGeneric then + return + end + if not isCallID(field) then + return + end + local call = findLastCall() + if not call then + return + end + local closure = generic.createClosure(source, call) + if not closure then + return + end + local id = linker.getID(closure) + searchID(id, field) + end + + local stepCount = 0 + function searchStep(id, field) + stepCount = stepCount + 1 + if stepCount > 1000 then + error('too large') + end + local link = linker.getLinkByID(root, id) + if link then + if link.call then + callStack[#callStack+1] = link.call + end + if field == nil and link.sources then + for _, source in ipairs(link.sources) do + m.pushResult(status, mode, source) + end + end + if link.forward then + for _, forwardID in ipairs(link.forward) do + searchID(forwardID, field) + end + end + if link.backward and (mode == 'ref' or field) then + for _, backwardID in ipairs(link.backward) do + searchID(backwardID, field) + end + end + + if link.sources then + checkGeneric(link.sources[1], field) + end + + if link.call then + callStack[#callStack] = nil + end + end + checkLastID(id, field) + end + + search(expect) + searchFunction(expect) +end + +---搜索对象的引用 +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +function m.searchRefs(status, source, mode) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local root = guide.getRoot(source) + linker.compileLinks(root) + local uri = guide.getUri(source) + local id = linker.getID(source) + if not id then + return + end + + m.searchRefsByID(status, uri, id, mode) +end + +---@class guide.status +---搜索结果 +---@field results parser.guide.object[] + +---创建搜索状态 +---@param parentStatus guide.status +---@param interface table +---@param deep integer +---@return guide.status +function m.status(parentStatus, interface, deep) + local status = { + mark = parentStatus and parentStatus.mark or {}, + results = {}, + } + return status +end + +--- 请求对象的引用 +---@param obj parser.guide.object +---@param interface table +---@param deep integer +---@return parser.guide.object[] +---@return integer +function m.requestReference(obj, interface, deep) + local status = m.status(nil, interface, deep) + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'ref') + + return status.results, 0 +end + +--- 请求对象的定义 +---@param obj parser.guide.object +---@param interface table +---@param deep integer +---@return parser.guide.object[] +---@return integer +function m.requestDefinition(obj, interface, deep) + local status = m.status(nil, interface, deep) + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'def') + + return status.results, 0 +end + +return m diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index f8feaa09..5e9ee9b1 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local await = require 'await' local define = require 'proto.define' local vm = require 'vm' @@ -221,7 +221,7 @@ return function (uri, start, finish) local results = {} local count = 0 - guide.eachSourceBetween(ast.ast, start, finish, function (source) + searcher.eachSourceBetween(ast.ast, start, finish, function (source) local method = Care[source.type] if not method then return diff --git a/script/core/signature.lua b/script/core/signature.lua index a35f3593..7d391c94 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local vm = require 'vm' local hoverLabel = require 'core.hover.label' local hoverDesc = require 'core.hover.description' @@ -7,7 +7,7 @@ local hoverDesc = require 'core.hover.description' local function findNearCall(uri, ast, pos) local text = files.getText(uri) local nearCall - guide.eachSourceContain(ast.ast, pos, function (src) + searcher.eachSourceContain(ast.ast, pos, function (src) if src.type == 'call' or src.type == 'table' or src.type == 'function' then @@ -96,7 +96,7 @@ local function makeSignatures(call, pos) local defs = vm.getDefs(node, 0) local mark = {} for _, src in ipairs(defs) do - src = guide.getObjectValue(src) or src + src = searcher.getObjectValue(src) or src if src.type == 'function' or src.type == 'doc.type.function' then if not mark[src] then diff --git a/script/core/type-formatting.lua b/script/core/type-formatting.lua index c2290ef3..79dccc8f 100644 --- a/script/core/type-formatting.lua +++ b/script/core/type-formatting.lua @@ -1,11 +1,11 @@ local files = require 'files' local lookBackward = require 'core.look-backward' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local function insertIndentation(uri, offset, edits) local lines = files.getLines(uri) local text = files.getOriginText(uri) - local row = guide.positionOf(lines, offset) + local row = searcher.positionOf(lines, offset) local line = lines[row] local indent = text:sub(line.start, line.finish):match '^%s*' for _, edit in ipairs(edits) do diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua index ae420d32..2df23a4d 100644 --- a/script/core/workspace-symbol.lua +++ b/script/core/workspace-symbol.lua @@ -1,5 +1,5 @@ local files = require 'files' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local matchKey = require 'core.matchkey' local define = require 'proto.define' local await = require 'await' @@ -52,7 +52,7 @@ local function searchFile(uri, key, results) return end - guide.eachSource(ast.ast, function (source) + searcher.eachSource(ast.ast, function (source) buildSource(uri, source, key, results) end) end diff --git a/script/files.lua b/script/files.lua index 9cc6b549..3f3d633e 100644 --- a/script/files.lua +++ b/script/files.lua @@ -9,7 +9,7 @@ local await = require 'await' local timer = require 'timer' local plugin = require 'plugin' local util = require 'utility' -local guide = require 'core.guide' +local guide = require 'parser.guide' local smerger = require 'string-merger' local progress = require "progress" diff --git a/script/parser/ast.lua b/script/parser/ast.lua index 45d77631..b2a9fa37 100644 --- a/script/parser/ast.lua +++ b/script/parser/ast.lua @@ -1460,8 +1460,13 @@ local Defs = { local values if func then local call = createCall(exp, func.finish + 1, exp.finish) + if #exp == 0 then + exp[1] = getSelect(func, 2) + exp[2] = getSelect(func, 3) + exp[3] = getSelect(func, 4) + end call.node = func - call.start = func.start + call.start = inA func.next = call func.iterator = true values = { call } diff --git a/script/parser/compile.lua b/script/parser/compile.lua index a7e0dc1f..ae4808df 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -150,8 +150,8 @@ local vmMap = { local value = obj.value local localself = { type = 'local', - start = 0, - finish = 0, + start = value.start, + finish = value.finish, method = obj, effect = obj.finish, tag = 'self', diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 2369e84f..a838a42e 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -1,34 +1,8 @@ -local util = require 'utility' local error = error local type = type -local next = next -local tostring = tostring -local print = print -local ipairs = ipairs -local tableInsert = table.insert -local tableUnpack = table.unpack -local tableRemove = table.remove -local tableMove = table.move -local tableSort = table.sort -local tableConcat = table.concat -local mathType = math.type -local pairs = pairs -local setmetatable = setmetatable -local assert = assert -local select = select -local osClock = os.clock -local tonumber = tonumber -local tointeger = math.tointeger -local DEVELOP = _G.DEVELOP -local log = log -local _G = _G ---@class parser.guide.object -local function logWarn(...) - log.warn(...) -end - ---@class guide ---@field debugMode boolean local m = {} @@ -102,7 +76,7 @@ m.childMap = { ['doc.type.array'] = {'node'}, ['doc.type.table'] = {'node', 'key', 'value', 'comment'}, ['doc.type.function'] = {'#args', '#returns', 'comment'}, - ['doc.type.typeliteral'] = {'node'}, + ['doc.type.literal'] = {'node'}, ['doc.type.arg'] = {'extends'}, ['doc.overload'] = {'overload', 'comment'}, ['doc.see'] = {'name', 'field'}, @@ -123,20 +97,6 @@ m.actionMap = { ['funcargs'] = {'#'}, } -local TypeSort = { - ['boolean'] = 1, - ['string'] = 2, - ['integer'] = 3, - ['number'] = 4, - ['table'] = 5, - ['function'] = 6, - ['true'] = 101, - ['false'] = 102, - ['nil'] = 999, -} - -local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) - --- 是否是字面量 ---@param obj parser.guide.object ---@return boolean @@ -293,10 +253,19 @@ end ---@param obj parser.guide.object ---@return parser.guide.object function m.getRoot(obj) + local source = obj + if source._root then + return source._root + end for _ = 1, 1000 do if obj.type == 'main' then + source._root = obj return obj end + if obj._root then + source._root = obj._root + return source._root + end local parent = obj.parent if not parent then return nil @@ -501,8 +470,8 @@ function m.addChilds(list, obj, map) for i = 1, #keys do local key = keys[i] if key == '#' then - for i = 1, #obj do - list[#list+1] = obj[i] + for j = 1, #obj do + list[#list+1] = obj[j] end elseif obj[key] then list[#list+1] = obj[key] @@ -510,8 +479,8 @@ function m.addChilds(list, obj, map) and key:sub(1, 1) == '#' then key = key:sub(2) if obj[key] then - for i = 1, #obj[key] do - list[#list+1] = obj[key][i] + for j = 1, #obj[key] do + list[#list+1] = obj[key][j] end end end @@ -718,4 +687,50 @@ function m.lineData(lines, row) return lines[row] end +function m.isSet(source) + local tp = source.type + if tp == 'setglobal' + or tp == 'local' + or tp == 'setlocal' + or tp == 'setfield' + or tp == 'setmethod' + or tp == 'setindex' + or tp == 'tablefield' + or tp == 'tableindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawset' then + return true + end + end + return false +end + +function m.isGet(source) + local tp = source.type + if tp == 'getglobal' + or tp == 'getlocal' + or tp == 'getfield' + or tp == 'getmethod' + or tp == 'getindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawget' then + return true + end + end + return false +end + +function m.getSpecial(source) + if not source then + return nil + end + return source.special +end + return m diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index ae8e3f34..5f70a9e5 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -1,7 +1,7 @@ local m = require 'lpeglabel' local re = require 'parser.relabel' local lines = require 'parser.lines' -local guide = require 'core.guide' +local guide = require 'parser.guide' local grammar = require 'parser.grammar' local TokenTypes, TokenStarts, TokenFinishs, TokenContents @@ -484,13 +484,7 @@ function parseType(parent) break end -- TypeLiteral,指代类型的字面值。比如,对于类 Cat 来说,它的 TypeLiteral 是 "Cat" - typeLiteral = { - type = 'doc.type.typeliteral', - parent = result, - start = getStart(), - finish = nil, - node = nil, - } + typeLiteral = true end if tp == 'name' then @@ -501,10 +495,7 @@ function parseType(parent) end if typeLiteral then nextToken() - typeLiteral.finish = getFinish() - typeLiteral.node = typeUnit - typeUnit.parent = typeLiteral - typeUnit = typeLiteral + typeUnit.literal = true end result.types[#result.types+1] = typeUnit if not result.start then @@ -1152,21 +1143,22 @@ local function bindParamAndReturnIndex(binded) if not func then return end - if not func.args then - return - end - local paramIndex = 0 - local paramMap = {} - for _, param in ipairs(func.args) do - paramIndex = paramIndex + 1 - if param[1] then - paramMap[param[1]] = paramIndex + local paramMap + if func.args then + local paramIndex = 0 + paramMap = {} + for _, param in ipairs(func.args) do + paramIndex = paramIndex + 1 + if param[1] then + paramMap[param[1]] = paramIndex + end end + func.docParamMap = paramMap end local returnIndex = 0 for _, doc in ipairs(binded) do if doc.type == 'doc.param' then - if doc.extends then + if paramMap and doc.extends then doc.extends.paramIndex = paramMap[doc.param[1]] end elseif doc.type == 'doc.return' then diff --git a/script/vm/eachDef.lua b/script/vm/eachDef.lua index d72c8f01..a00b61bb 100644 --- a/script/vm/eachDef.lua +++ b/script/vm/eachDef.lua @@ -1,10 +1,10 @@ ---@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local files = require 'files' -local util = require 'utility' -local await = require 'await' -local config = require 'config' +local vm = require 'vm.vm' +local searcher = require 'core.searcher' +local files = require 'files' +local util = require 'utility' +local await = require 'await' +local config = require 'config' local function getDefs(source, deep) local results = {} @@ -18,9 +18,9 @@ local function getDefs(source, deep) deep = config.config.intelliSense.searchDepth + (deep or 0) local clock = os.clock() - local myResults, count = guide.requestDefinition(source, vm.interface, deep) + local myResults, count = searcher.requestDefinition(source, vm.interface, deep) if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestDefinition', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) + log.warn('requestDefinition', count, os.clock() - clock, searcher.getUri(source), util.dump(source, { deep = 1 })) end vm.mergeResults(results, myResults) @@ -31,8 +31,8 @@ end function vm.getDefs(source, deep) deep = deep or -999 - if guide.isGlobal(source) then - local key = guide.getKeyName(source) + if searcher.isGlobal(source) then + local key = searcher.getKeyName(source) if not key then return {} end diff --git a/script/vm/eachField.lua b/script/vm/eachField.lua index 59f35f0c..9a40fb1c 100644 --- a/script/vm/eachField.lua +++ b/script/vm/eachField.lua @@ -1,6 +1,6 @@ ---@type vm local vm = require 'vm.vm' -local guide = require 'core.guide' +local searcher= require 'core.searcher' local await = require 'await' local config = require 'config' @@ -19,7 +19,7 @@ local function getFields(source, deep, filterKey) deep = config.config.intelliSense.searchDepth + (deep or 0) await.delay() - local results = guide.requestFields(source, vm.interface, deep, filterKey) + local results = searcher.requestFields(source, vm.interface, deep, filterKey) unlock() return results @@ -40,7 +40,7 @@ local function getDefFields(source, deep, filterKey) deep = config.config.intelliSense.searchDepth + (deep or 0) await.delay() - local results = guide.requestDefFields(source, vm.interface, deep, filterKey) + local results = searcher.requestDefFields(source, vm.interface, deep, filterKey) unlock() return results @@ -76,8 +76,8 @@ function vm.getFields(source, deep) if source.special == '_G' then return vm.getGlobals '*' end - if guide.isGlobal(source) then - local name = guide.getKeyName(source) + if searcher.isGlobal(source) then + local name = searcher.getKeyName(source) if not name then return {} end @@ -94,8 +94,8 @@ function vm.getDefFields(source, deep) if source.special == '_G' then return vm.getGlobalSets '*' end - if guide.isGlobal(source) then - local name = guide.getKeyName(source) + if searcher.isGlobal(source) then + local name = searcher.getKeyName(source) if not name then return {} end diff --git a/script/vm/eachRef.lua b/script/vm/eachRef.lua index 9d0f061c..0556a6a3 100644 --- a/script/vm/eachRef.lua +++ b/script/vm/eachRef.lua @@ -1,9 +1,9 @@ ---@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' -local util = require 'utility' -local await = require 'await' -local config = require 'config' +local vm = require 'vm.vm' +local searcher = require 'core.searcher' +local util = require 'utility' +local await = require 'await' +local config = require 'config' local function getRefs(source, deep) local results = {} @@ -17,9 +17,9 @@ local function getRefs(source, deep) deep = config.config.intelliSense.searchDepth + (deep or 0) local clock = os.clock() - local myResults, count = guide.requestReference(source, vm.interface, deep) + local myResults, count = searcher.requestReference(source, vm.interface, deep) if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestReference', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) + log.warn('requestReference', count, os.clock() - clock, searcher.getUri(source), util.dump(source, { deep = 1 })) end vm.mergeResults(results, myResults) @@ -30,8 +30,8 @@ end function vm.getRefs(source, deep) deep = deep or -999 - if guide.isGlobal(source) then - local key = guide.getKeyName(source) + if searcher.isGlobal(source) then + local key = searcher.getKeyName(source) if not key then return {} end diff --git a/script/vm/getClass.lua b/script/vm/getClass.lua index 5c68e0bb..fbd50fc8 100644 --- a/script/vm/getClass.lua +++ b/script/vm/getClass.lua @@ -1,13 +1,13 @@ ---@type vm -local vm = require 'vm.vm' -local guide = require 'core.guide' +local vm = require 'vm.vm' +local searcher = require 'core.searcher' local function lookUpDocClass(source) local infers = vm.getInfers(source, 0) for _, infer in ipairs(infers) do if infer.source.type == 'doc.class' or infer.source.type == 'doc.type' then - return guide.viewInferType(infers) + return searcher.viewInferType(infers) end end return nil @@ -22,7 +22,7 @@ local function getClass(source, classes, depth, deep) if depth > 3 then return end - local value = guide.getObjectValue(source) or source + local value = searcher.getObjectValue(source) or source if not deep then if value and value.type == 'string' then classes[value[1]] = true @@ -38,7 +38,7 @@ local function getClass(source, classes, depth, deep) or lkey == '__name' or lkey == 'name' or lkey == 'class' then - local value = guide.getObjectValue(src) + local value = searcher.getObjectValue(src) if value and value.type == 'string' then classes[value[1]] = true end @@ -60,5 +60,5 @@ function vm.getClass(source, deep) if not next(classes) then return nil end - return guide.mergeTypes(classes) + return searcher.mergeTypes(classes) end diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua index cfa9326f..a230a160 100644 --- a/script/vm/getDocs.lua +++ b/script/vm/getDocs.lua @@ -1,39 +1,52 @@ local files = require 'files' local util = require 'utility' -local guide = require 'core.guide' +local guide = require 'parser.guide' ---@type vm local vm = require 'vm.vm' local config = require 'config' -local function getTypesOfFile(uri) - local types = {} +local typeMap = { + ['doc.type.name'] = 'type', + ['doc.class.name'] = 'class', + ['doc.extends.name'] = 'extends', + ['doc.alias.name'] = 'alias', +} + +local function getNamesOfFile(uri) + local names = { + type = {}, + class = {}, + extends = {}, + alias = {}, + } local ast = files.getAst(uri) if not ast or not ast.ast.docs then - return types + return names end guide.eachSource(ast.ast.docs, function (src) - if src.type == 'doc.type.name' - or src.type == 'doc.class.name' - or src.type == 'doc.extends.name' - or src.type == 'doc.alias.name' then - if src.type == 'doc.type.name' then - if guide.getParentDocTypeTable(src) then - return - end - end - local name = src[1] - if name then - if not types[name] then - types[name] = {} - end - types[name][#types[name]+1] = src - end + local type = typeMap[src.type] + if not type then + return + end + --if src.type == 'doc.type.name' then + -- if guide.getParentDocTypeTable(src) then + -- return + -- end + --end + local name = src[1] + if not name then + return + end + local list = names[type] + if not list[name] then + list[name] = {} end + list[name][#list[name]+1] = src end) - return types + return names end -local function getDocTypes(name) +local function getDocNames(name, type) local results = {} if name == 'any' or name == 'nil' then @@ -41,16 +54,16 @@ local function getDocTypes(name) end for uri in files.eachFile() do local cache = files.getCache(uri) - cache.classes = cache.classes or getTypesOfFile(uri) + cache = cache or getNamesOfFile(uri) if name == '*' then - for _, sources in util.sortPairs(cache.classes) do + for _, sources in util.sortPairs(cache[type]) do for _, source in ipairs(sources) do results[#results+1] = source end end else - if cache.classes[name] then - for _, source in ipairs(cache.classes[name]) do + if cache[type][name] then + for _, source in ipairs(cache[type][name]) do results[#results+1] = source end end @@ -119,29 +132,14 @@ function vm.getDocTypeUnits(doc, mark, results) return results end -function vm.getDocTypes(name) - local cache = vm.getCache('getDocTypes')[name] +function vm.getDocNames(name, type) + local cacheName = 'docNames:' .. type + local cache = vm.getCache(cacheName)[name] if cache ~= nil then return cache end - cache = getDocTypes(name) - vm.getCache('getDocTypes')[name] = cache - return cache -end - -function vm.getDocClass(name) - local cache = vm.getCache('getDocClass')[name] - if cache ~= nil then - return cache - end - cache = {} - local results = getDocTypes(name) - for _, doc in ipairs(results) do - if doc.type == 'doc.class.name' then - cache[#cache+1] = doc - end - end - vm.getCache('getDocClass')[name] = cache + cache = getDocNames(name, type) + vm.getCache(cacheName)[name] = cache return cache end diff --git a/script/vm/getGlobals.lua b/script/vm/getGlobals.lua index 2752ce09..08f9d049 100644 --- a/script/vm/getGlobals.lua +++ b/script/vm/getGlobals.lua @@ -1,4 +1,4 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local await = require "await" ---@type vm local vm = require 'vm.vm' @@ -265,7 +265,7 @@ files.watch(function (ev, uri) end needUpdateGlobals[uri] = true elseif ev == 'create' then - getGlobalsOfFile(uri) - getGlobalSetsOfFile(uri) + --getGlobalsOfFile(uri) + --getGlobalSetsOfFile(uri) end end) diff --git a/script/vm/getInfer.lua b/script/vm/getInfer.lua index 5447ca23..be9a66ab 100644 --- a/script/vm/getInfer.lua +++ b/script/vm/getInfer.lua @@ -1,6 +1,6 @@ ---@type vm local vm = require 'vm.vm' -local guide = require 'core.guide' +local searcher= require 'core.searcher' local util = require 'utility' local await = require 'await' local config = require 'config' @@ -12,7 +12,7 @@ function vm.hasType(source, type, deep) local defs = vm.getDefs(source, deep) for i = 1, #defs do local def = defs[i] - local value = guide.getObjectValue(def) or def + local value = searcher.getObjectValue(def) or def if value.type == type then return true end @@ -34,7 +34,7 @@ end function vm.getInferType(source, deep) local infers = vm.getInfers(source, deep) - return guide.viewInferType(infers) + return searcher.viewInferType(infers) end function vm.getInferLiteral(source, deep) @@ -67,9 +67,9 @@ local function getInfers(source, deep) await.delay() local clock = os.clock() - local myResults, count = guide.requestInfer(source, vm.interface, deep) + local myResults, count = searcher.requestInfer(source, vm.interface, deep) if DEVELOP and os.clock() - clock > 0.1 then - log.warn('requestInfer', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) + log.warn('requestInfer', count, os.clock() - clock, searcher.getUri(source), util.dump(source, { deep = 1 })) end vm.mergeResults(results, myResults) @@ -92,8 +92,8 @@ end --- 获取对象的值 --- 会尝试穿透函数调用 function vm.getInfers(source, deep) - if guide.isGlobal(source) then - local name = guide.getKeyName(source) + if searcher.isGlobal(source) then + local name = searcher.getKeyName(source) local cache = vm.getCache('getInfersOfGlobal')[name] or getInfersBySource(source, deep) vm.getCache('getInfersOfGlobal')[name] = cache diff --git a/script/vm/getLinks.lua b/script/vm/getLinks.lua index 91a5f1a0..86a38cfc 100644 --- a/script/vm/getLinks.lua +++ b/script/vm/getLinks.lua @@ -1,7 +1,7 @@ -local guide = require 'core.guide' +local searcher = require 'core.searcher' ---@type vm -local vm = require 'vm.vm' -local files = require 'files' +local vm = require 'vm.vm' +local files = require 'files' local function getFileLinks(uri) local ws = require 'workspace' @@ -11,7 +11,7 @@ local function getFileLinks(uri) return links end tracy.ZoneBeginN('getFileLinks') - guide.eachSpecialOf(ast.ast, 'require', function (source) + searcher.eachSpecialOf(ast.ast, 'require', function (source) local call = source.parent if not call or call.type ~= 'call' then return diff --git a/script/vm/guideInterface.lua b/script/vm/guideInterface.lua index ae060481..e59fc6e3 100644 --- a/script/vm/guideInterface.lua +++ b/script/vm/guideInterface.lua @@ -2,7 +2,7 @@ local vm = require 'vm.vm' local files = require 'files' local ws = require 'workspace' -local guide = require 'core.guide' +local searcher = require 'core.searcher' local await = require 'await' local config = require 'config' @@ -27,7 +27,7 @@ function m.require(args, index) return nil end local results = {} - local myUri = guide.getUri(args[1]) + local myUri = searcher.getUri(args[1]) local uris = ws.findUrisByRequirePath(reqName) for _, uri in ipairs(uris) do if not files.eq(myUri, uri) then @@ -47,7 +47,7 @@ function m.dofile(args, index) return end local results = {} - local myUri = guide.getUri(args[1]) + local myUri = searcher.getUri(args[1]) local uris = ws.findUrisByFilePath(reqName) for _, uri in ipairs(uris) do if not files.eq(myUri, uri) then @@ -87,9 +87,9 @@ function vm.interface.global(name, onlyDef) end end -function vm.interface.docType(name) +function vm.interface.doc(name, type) await.delay() - return vm.getDocTypes(name) + return vm.getDocNames(name, type) end function vm.interface.link(uri) diff --git a/script/vm/vm.lua b/script/vm/vm.lua index 0248ad8c..ebd0102b 100644 --- a/script/vm/vm.lua +++ b/script/vm/vm.lua @@ -1,18 +1,14 @@ -local guide = require 'core.guide' +local guide = require 'parser.guide' local util = require 'utility' local files = require 'files' local timer = require 'timer' local setmetatable = setmetatable -local assert = assert -local require = require -local type = type local running = coroutine.running local ipairs = ipairs local log = log local xpcall = xpcall local mathHuge = math.huge -local collectgarbage = collectgarbage _ENV = nil @@ -63,10 +59,6 @@ function m.getArgInfo(source) return nil end -function m.getSpecial(source) - return guide.getSpecial(source) -end - function m.getKeyName(source) if not source then return nil |