diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2020-12-05 11:13:05 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2020-12-05 11:13:05 +0800 |
commit | becf78043079b7df7d9fba2941d6c134fbc0aa79 (patch) | |
tree | 035bed287c90f3b44029fa37a6ea708461cf10c2 | |
parent | de4463286505e59158f60f65466d07d61d21cd0b (diff) | |
download | lua-language-server-becf78043079b7df7d9fba2941d6c134fbc0aa79.zip |
update test
-rw-r--r-- | script/parser/compile.lua | 3 | ||||
-rw-r--r-- | script/parser/parse.lua | 2 | ||||
-rw-r--r-- | test/example/guide.txt | 4014 | ||||
-rw-r--r-- | test/full/example.lua | 14 |
4 files changed, 4032 insertions, 1 deletions
diff --git a/script/parser/compile.lua b/script/parser/compile.lua index 1ba111ed..d4129ab4 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -1,5 +1,6 @@ local guide = require 'parser.guide' local type = type +local os = os local specials = { ['_G'] = true, @@ -537,6 +538,7 @@ return function (self, lua, mode, version, options) if not state then return nil, err end + local clock = os.clock() pushError = state.pushError if version == 'Lua 5.1' or version == 'LuaJIT' then ENVMode = '@fenv' @@ -557,6 +559,7 @@ return function (self, lua, mode, version, options) Compile(state.ast) end PostCompile() + state.compileClock = os.clock() - clock Compiled = nil GoToTag = nil return state diff --git a/script/parser/parse.lua b/script/parser/parse.lua index 909ce315..9b8d5496 100644 --- a/script/parser/parse.lua +++ b/script/parser/parse.lua @@ -36,6 +36,7 @@ return function (self, lua, mode, version, options) comms[#comms+1] = comment end } + local clock = os.clock() ast.init(state) local suc, res, err = xpcall(self.grammar, debug.traceback, self, lua, mode) ast.close() @@ -46,5 +47,6 @@ return function (self, lua, mode, version, options) state.pushError(err) end state.ast = res + state.parseClock = os.clock() - clock return state end diff --git a/test/example/guide.txt b/test/example/guide.txt new file mode 100644 index 00000000..437e37b0 --- /dev/null +++ b/test/example/guide.txt @@ -0,0 +1,4014 @@ +local util = require 'utility' +local error = error +local type = type +local next = next +local tostring = tostring +local print = print +local ipairs = ipairs +local tableInsert = table.insert +local tableUnpack = table.unpack +local tableRemove = table.remove +local tableMove = table.move +local tableSort = table.sort +local tableConcat = table.concat +local mathType = math.type +local pairs = pairs +local setmetatable = setmetatable +local assert = assert +local select = select +local osClock = os.clock +local DEVELOP = _G.DEVELOP +local log = log +local _G = _G + +local function logWarn(...) + log.warn(...) +end + +_ENV = nil + +local m = {} + +m.ANY = {} + +local blockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, + ['do'] = true, + ['function'] = true, + ['ifblock'] = true, + ['elseblock'] = true, + ['elseifblock'] = true, + ['main'] = true, +} + +local breakBlockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, +} + +m.childMap = { + ['main'] = {'#', 'docs'}, + ['repeat'] = {'#', 'filter'}, + ['while'] = {'filter', '#'}, + ['in'] = {'keys', '#'}, + ['loop'] = {'loc', 'max', 'step', '#'}, + ['if'] = {'#'}, + ['ifblock'] = {'filter', '#'}, + ['elseifblock'] = {'filter', '#'}, + ['elseblock'] = {'#'}, + ['setfield'] = {'node', 'field', 'value'}, + ['setglobal'] = {'value'}, + ['local'] = {'attrs', 'value'}, + ['setlocal'] = {'value'}, + ['return'] = {'#'}, + ['do'] = {'#'}, + ['select'] = {'vararg'}, + ['table'] = {'#'}, + ['tableindex'] = {'index', 'value'}, + ['tablefield'] = {'field', 'value'}, + ['function'] = {'args', '#'}, + ['funcargs'] = {'#'}, + ['setmethod'] = {'node', 'method', 'value'}, + ['getmethod'] = {'node', 'method'}, + ['setindex'] = {'node', 'index', 'value'}, + ['getindex'] = {'node', 'index'}, + ['paren'] = {'exp'}, + ['call'] = {'node', 'args'}, + ['callargs'] = {'#'}, + ['getfield'] = {'node', 'field'}, + ['list'] = {'#'}, + ['binary'] = {1, 2}, + ['unary'] = {1}, + + ['doc'] = {'#'}, + ['doc.class'] = {'class', 'extends', 'comment'}, + ['doc.type'] = {'#types', '#enums', 'name', 'comment'}, + ['doc.alias'] = {'alias', 'extends', 'comment'}, + ['doc.param'] = {'param', 'extends', 'comment'}, + ['doc.return'] = {'#returns', 'comment'}, + ['doc.field'] = {'field', 'extends', 'comment'}, + ['doc.generic'] = {'#generics', 'comment'}, + ['doc.generic.object'] = {'generic', 'extends', 'comment'}, + ['doc.vararg'] = {'vararg', 'comment'}, + ['doc.type.table'] = {'key', 'value', 'comment'}, + ['doc.type.function'] = {'#args', '#returns', 'comment'}, + ['doc.overload'] = {'overload', 'comment'}, + ['doc.see'] = {'name', 'field'}, +} + +m.actionMap = { + ['main'] = {'#'}, + ['repeat'] = {'#'}, + ['while'] = {'#'}, + ['in'] = {'#'}, + ['loop'] = {'#'}, + ['if'] = {'#'}, + ['ifblock'] = {'#'}, + ['elseifblock'] = {'#'}, + ['elseblock'] = {'#'}, + ['do'] = {'#'}, + ['function'] = {'#'}, + ['funcargs'] = {'#'}, +} + +local TypeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['nil'] = 999, +} + +local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) + +--- 是否是字面量 +function m.isLiteral(obj) + local tp = obj.type + return tp == 'nil' + or tp == 'boolean' + or tp == 'string' + or tp == 'number' + or tp == 'table' +end + +--- 获取字面量 +function m.getLiteral(obj) + local tp = obj.type + if tp == 'boolean' then + return obj[1] + elseif tp == 'string' then + return obj[1] + elseif tp == 'number' then + return obj[1] + end + return nil +end + +--- 寻找父函数 +function m.getParentFunction(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + break + end + local tp = obj.type + if tp == 'function' or tp == 'main' then + return obj + end + end + return nil +end + +--- 寻找所在区块 +function m.getBlock(obj) + for _ = 1, 1000 do + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + obj = obj.parent + end + error('guide.getBlock overstack') +end + +--- 寻找所在父区块 +function m.getParentBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + end + error('guide.getParentBlock overstack') +end + +--- 寻找所在可break的父区块 +function m.getBreakBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if breakBlockTypes[tp] then + return obj + end + if tp == 'function' then + return nil + end + end + error('guide.getBreakBlock overstack') +end + +--- 寻找doc的主体 +function m.getDocState(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return obj + end + if parent.type == 'doc' then + return obj + end + obj = parent + end + error('guide.getDocState overstack') +end + +--- 寻找所在父类型 +function m.getParentType(obj, want) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + if want == obj.type then + return obj + end + end + error('guide.getParentType overstack') +end + +--- 寻找根区块 +function m.getRoot(obj) + for _ = 1, 1000 do + if obj.type == 'main' then + return obj + end + local parent = obj.parent + if not parent then + return nil + end + obj = parent + end + error('guide.getRoot overstack') +end + +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = m.getRoot(obj) + if root then + return root.uri + end + return '' +end + +function m.getENV(source, start) + if not start then + start = 1 + end + return m.getLocal(source, '_ENV', start) + or m.getLocal(source, '@fenv', start) +end + +--- 寻找函数的不定参数,返回不定参在第几个参数上,以及该参数对象。 +--- 如果函数是主函数,则返回`0, nil`。 +---@return table +---@return integer +function m.getFunctionVarArgs(func) + if func.type == 'main' then + return 0, nil + end + if func.type ~= 'function' then + return nil, nil + end + local args = func.args + if not args then + return nil, nil + end + for i = 1, #args do + local arg = args[i] + if arg.type == '...' then + return i, arg + end + end + return nil, nil +end + +--- 获取指定区块中可见的局部变量 +---@param block table +---@param name string {comment = '变量名'} +---@param pos integer {comment = '可见位置'} +function m.getLocal(block, name, pos) + block = m.getBlock(block) + for _ = 1, 1000 do + if not block then + return nil + end + local locals = block.locals + local res + if not locals then + goto CONTINUE + end + for i = 1, #locals do + local loc = locals[i] + if loc.effect > pos then + break + end + if loc[1] == name then + if not res or res.effect < loc.effect then + res = loc + end + end + end + if res then + return res, res + end + ::CONTINUE:: + block = m.getParentBlock(block) + end + error('guide.getLocal overstack') +end + +--- 获取指定区块中所有的可见局部变量名称 +function m.getVisibleLocals(block, pos) + local result = {} + m.eachSourceContain(m.getRoot(block), pos, function (source) + local locals = source.locals + if locals then + for i = 1, #locals do + local loc = locals[i] + local name = loc[1] + if loc.effect <= pos then + result[name] = loc + end + end + end + end) + return result +end + +--- 获取指定区块中可见的标签 +---@param block table +---@param name string {comment = '标签名'} +function m.getLabel(block, name) + block = m.getBlock(block) + for _ = 1, 1000 do + if not block then + return nil + end + local labels = block.labels + if labels then + local label = labels[name] + if label then + return label + end + end + if block.type == 'function' then + return nil + end + block = m.getParentBlock(block) + end + error('guide.getLocal overstack') +end + +function m.getStartFinish(source) + local start = source.start + local finish = source.finish + if not start then + local first = source[1] + if not first then + return nil, nil + end + local last = source[#source] + start = first.start + finish = last.finish + end + return start, finish +end + +function m.getRange(source) + local start = source.vstart or source.start + local finish = source.range or source.finish + if not start then + local first = source[1] + if not first then + return nil, nil + end + local last = source[#source] + start = first.vstart or first.start + finish = last.range or last.finish + end + return start, finish +end + +--- 判断source是否包含offset +function m.isContain(source, offset) + local start, finish = m.getStartFinish(source) + if not start then + return false + end + return start <= offset and finish >= offset - 1 +end + +--- 判断offset在source的影响范围内 +--- +--- 主要针对赋值等语句时,key包含value +function m.isInRange(source, offset) + local start, finish = m.getRange(source) + if not start then + return false + end + return start <= offset and finish >= offset - 1 +end + +function m.isBetween(source, tStart, tFinish) + local start, finish = m.getStartFinish(source) + if not start then + return false + end + return start <= tFinish and finish >= tStart - 1 +end + +function m.isBetweenRange(source, tStart, tFinish) + local start, finish = m.getRange(source) + if not start then + return false + end + return start <= tFinish and finish >= tStart - 1 +end + +--- 添加child +function m.addChilds(list, obj, map) + local keys = map[obj.type] + if keys then + for i = 1, #keys do + local key = keys[i] + if key == '#' then + for i = 1, #obj do + list[#list+1] = obj[i] + end + elseif obj[key] then + list[#list+1] = obj[key] + elseif type(key) == 'string' + and key:sub(1, 1) == '#' then + key = key:sub(2) + for i = 1, #obj[key] do + list[#list+1] = obj[key][i] + end + end + end + end +end + +--- 遍历所有包含offset的source +function m.eachSourceContain(ast, offset, callback) + local list = { ast } + local mark = {} + while true do + local len = #list + if len == 0 then + return + end + local obj = list[len] + list[len] = nil + if not mark[obj] then + mark[obj] = true + if m.isInRange(obj, offset) then + if m.isContain(obj, offset) then + local res = callback(obj) + if res ~= nil then + return res + end + end + m.addChilds(list, obj, m.childMap) + end + end + end +end + +--- 遍历所有在某个范围内的source +function m.eachSourceBetween(ast, start, finish, callback) + local list = { ast } + local mark = {} + while true do + local len = #list + if len == 0 then + return + end + local obj = list[len] + list[len] = nil + if not mark[obj] then + mark[obj] = true + if m.isBetweenRange(obj, start, finish) then + if m.isBetween(obj, start, finish) then + local res = callback(obj) + if res ~= nil then + return res + end + end + m.addChilds(list, obj, m.childMap) + end + end + end +end + +--- 遍历所有指定类型的source +function m.eachSourceType(ast, type, callback) + local cache = ast.typeCache + if not cache then + cache = {} + ast.typeCache = cache + m.eachSource(ast, function (source) + local tp = source.type + if not tp then + return + end + local myCache = cache[tp] + if not myCache then + myCache = {} + cache[tp] = myCache + end + myCache[#myCache+1] = source + end) + end + local myCache = cache[type] + if not myCache then + return + end + for i = 1, #myCache do + callback(myCache[i]) + end +end + +--- 遍历所有的source +function m.eachSource(ast, callback) + local list = { ast } + local mark = {} + local index = 1 + while true do + local obj = list[index] + if not obj then + return + end + list[index] = false + index = index + 1 + if not mark[obj] then + mark[obj] = true + callback(obj) + m.addChilds(list, obj, m.childMap) + end + end +end + +--- 获取指定的 special +function m.eachSpecialOf(ast, name, callback) + local root = m.getRoot(ast) + if not root.specials then + return + end + local specials = root.specials[name] + if not specials then + return + end + for i = 1, #specials do + callback(specials[i]) + end +end + +--- 获取偏移对应的坐标 +---@param lines table +---@return integer {name = 'row'} +---@return integer {name = 'col'} +function m.positionOf(lines, offset) + if offset < 1 then + return 0, 0 + end + local lastLine = lines[#lines] + if offset > lastLine.finish then + return #lines, lastLine.finish - lastLine.start + 1 + end + local min = 1 + local max = #lines + for _ = 1, 100 do + if max <= min then + local line = lines[min] + return min, offset - line.start + 1 + end + local row = (max - min) // 2 + min + local line = lines[row] + if offset < line.start then + max = row - 1 + elseif offset > line.finish then + min = row + 1 + else + return row, offset - line.start + 1 + end + end + error('Stack overflow!') +end + +--- 获取坐标对应的偏移 +---@param lines table +---@param row integer +---@param col integer +---@return integer {name = 'offset'} +function m.offsetOf(lines, row, col) + if row < 1 then + return 0 + end + if row > #lines then + local lastLine = lines[#lines] + return lastLine.finish + end + local line = lines[row] + local len = line.finish - line.start + 1 + if col < 0 then + return line.start + elseif col > len then + return line.finish + else + return line.start + col - 1 + end +end + +function m.lineContent(lines, text, row, ignoreNL) + local line = lines[row] + if not line then + return '' + end + if ignoreNL then + return text:sub(line.start, line.range) + else + return text:sub(line.start, line.finish) + end +end + +function m.lineRange(lines, row, ignoreNL) + local line = lines[row] + if not line then + return 0, 0 + end + if ignoreNL then + return line.start, line.range + else + return line.start, line.finish + end +end + +function m.getKeyTypeOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return 'string' + elseif tp == 'string' then + return 'string' + elseif tp == 'number' then + return 'number' + elseif tp == 'boolean' then + return 'boolean' + end +end + +function m.getKeyType(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return 'string' + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return 'local' + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + return 'string' + elseif tp == 'getmethod' + or tp == 'setmethod' then + return 'string' + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyTypeOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' + or tp == 'doc.see.field' then + return 'string' + elseif tp == 'doc.class' then + return 'string' + elseif tp == 'doc.alias' then + return 'string' + elseif tp == 'doc.field' then + return 'string' + end + return m.getKeyTypeOfLiteral(obj) +end + +function m.getKeyNameOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return obj[1] + elseif tp == 'string' then + local s = obj[1] + if s then + return s + end + elseif tp == 'number' then + local n = obj[1] + if n then + return ('%s'):format(util.viewLiteral(obj[1])) + end + elseif tp == 'boolean' then + local b = obj[1] + if b then + return tostring(b) + end + end +end + +function m.getKeyName(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return obj[1] + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return obj[1] + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + if obj.field then + return obj.field[1] + end + elseif tp == 'getmethod' + or tp == 'setmethod' then + if obj.method then + return obj.method[1] + end + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyNameOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' + or tp == 'doc.see.field' then + return obj[1] + elseif tp == 'doc.class' then + return obj.class[1] + elseif tp == 'doc.alias' then + return obj.alias[1] + elseif tp == 'doc.field' then + return obj.field[1] + end + return m.getKeyNameOfLiteral(obj) +end + +function m.getSimpleName(obj) + if obj.type == 'call' then + local key = obj.args and obj.args[2] + return m.getKeyName(key) + elseif obj.type == 'table' then + return ('%p'):format(obj) + elseif obj.type == 'select' then + return ('%p'):format(obj) + elseif obj.type == 'string' then + return ('%p'):format(obj) + elseif obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' + or obj.type == 'doc.see.name' then + return ('%s'):format(obj[1]) + elseif obj.type == 'doc.class' then + return ('%s'):format(obj.class[1]) + end + return m.getKeyName(obj) +end + +--- 测试 a 到 b 的路径(不经过函数,不考虑 goto), +--- 每个路径是一个 block 。 +--- +--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>` +--- +--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>` +--- +--- 否则返回 `false` +--- +--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。 +---@param a table +---@param b table +---@return string|boolean mode +---@return table pathA? +---@return table pathB? +function m.getPath(a, b, sameFunction) + --- 首先测试双方在同一个函数内 + if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then + return false + end + local mode + local objA + local objB + if a.finish < b.start then + mode = 'before' + objA = a + objB = b + elseif a.start > b.finish then + mode = 'after' + objA = b + objB = a + else + return 'equal', {}, {} + end + local pathA = {} + local pathB = {} + for _ = 1, 1000 do + objA = m.getParentBlock(objA) + pathA[#pathA+1] = objA + if (not sameFunction and objA.type == 'function') or objA.type == 'main' then + break + end + end + for _ = 1, 1000 do + objB = m.getParentBlock(objB) + pathB[#pathB+1] = objB + if (not sameFunction and objA.type == 'function') or objB.type == 'main' then + break + end + end + -- pathA: {1, 2, 3, 4, 5} + -- pathB: {5, 6, 2, 3} + local top = #pathB + local start + for i = #pathA, 1, -1 do + local currentBlock = pathA[i] + if currentBlock == pathB[top] then + start = i + break + end + end + if not start then + return nil + end + -- pathA: { 1, 2, 3} + -- pathB: {5, 6, 2, 3} + local extra = 0 + local align = top - start + for i = start, 1, -1 do + local currentA = pathA[i] + local currentB = pathB[i+align] + if currentA ~= currentB then + extra = i + break + end + end + -- pathA: {1} + local resultA = {} + for i = extra, 1, -1 do + resultA[#resultA+1] = pathA[i] + end + -- pathB: {5, 6} + local resultB = {} + for i = extra + align, 1, -1 do + resultB[#resultB+1] = pathB[i] + end + return mode, resultA, resultB +end + +-- 根据语法,单步搜索定义 +local function stepRefOfLocal(loc, mode) + local results = {} + if loc.start ~= 0 then + results[#results+1] = loc + end + local refs = loc.ref + if not refs then + return results + end + for i = 1, #refs do + local ref = refs[i] + if ref.start == 0 then + goto CONTINUE + end + if mode == 'def' then + if ref.type == 'local' + or ref.type == 'setlocal' then + results[#results+1] = ref + end + else + if ref.type == 'local' + or ref.type == 'setlocal' + or ref.type == 'getlocal' then + results[#results+1] = ref + end + end + ::CONTINUE:: + end + return results +end + +local function stepRefOfLabel(label, mode) + local results = { label } + if not label or mode == 'def' then + return results + end + local refs = label.ref + if not refs then + return results + end + for i = 1, #refs do + local ref = refs[i] + results[#results+1] = ref + end + return results +end + +local function stepRefOfDocType(status, obj, mode) + local results = {} + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' + or obj.type == 'doc.alias.name' + or obj.type == 'doc.extends.name' + or obj.type == 'doc.see.name' then + local name = obj[1] + if not name or not status.interface.docType then + return results + end + local docs = status.interface.docType(name) + for i = 1, #docs do + local doc = docs[i] + if mode == 'def' then + if doc.type == 'doc.class.name' + or doc.type == 'doc.alias.name' then + results[#results+1] = doc + end + else + results[#results+1] = doc + end + end + else + results[#results+1] = obj + end + return results +end + +function m.getStepRef(status, obj, mode) + if obj.type == 'getlocal' + or obj.type == 'setlocal' then + return stepRefOfLocal(obj.node, mode) + end + if obj.type == 'local' then + return stepRefOfLocal(obj, mode) + end + if obj.type == 'label' then + return stepRefOfLabel(obj, mode) + end + if obj.type == 'goto' then + return stepRefOfLabel(obj.node, mode) + end + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' + or obj.type == 'doc.extends.name' + or obj.type == 'doc.alias.name' then + return stepRefOfDocType(status, obj, mode) + end + return nil +end + +-- 根据语法,单步搜索field +local function stepFieldOfLocal(loc) + local results = {} + local refs = loc.ref + for i = 1, #refs do + local ref = refs[i] + if ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + elseif ref.type == 'getlocal' then + local nxt = ref.next + if nxt then + if nxt.type == 'setfield' + or nxt.type == 'getfield' + or nxt.type == 'setmethod' + or nxt.type == 'getmethod' + or nxt.type == 'setindex' + or nxt.type == 'getindex' then + results[#results+1] = nxt + end + end + end + end + return results +end +local function stepFieldOfTable(tbl) + local result = {} + for i = 1, #tbl do + result[i] = tbl[i] + end + return result +end +function m.getStepField(obj) + if obj.type == 'getlocal' + or obj.type == 'setlocal' then + return stepFieldOfLocal(obj.node) + end + if obj.type == 'local' then + return stepFieldOfLocal(obj) + end + if obj.type == 'table' then + return stepFieldOfTable(obj) + end +end + +local function convertSimpleList(list) + local simple = {} + for i = #list, 1, -1 do + local c = list[i] + if c.type == 'getglobal' + or c.type == 'setglobal' then + if c.special == '_G' then + simple.mode = 'global' + goto CONTINUE + end + local loc = c.node + if loc.special == '_G' then + simple.mode = 'global' + if not simple.node then + simple.node = c + end + else + simple.mode = 'local' + simple[#simple+1] = m.getSimpleName(loc) + if not simple.node then + simple.node = loc + end + end + elseif c.type == 'getlocal' + or c.type == 'setlocal' then + if c.special == '_G' then + simple.mode = 'global' + goto CONTINUE + end + simple.mode = 'local' + if not simple.node then + simple.node = c.node + end + elseif c.type == 'local' then + simple.mode = 'local' + if not simple.node then + simple.node = c + end + else + if not simple.node then + simple.node = c + end + end + simple[#simple+1] = m.getSimpleName(c) or m.ANY + ::CONTINUE:: + end + if simple.mode == 'global' and #simple == 0 then + simple[1] = '_G' + simple.node = list[#list] + end + return simple +end + +-- 搜索 `a.b.c` 的等价表达式 +local function buildSimpleList(obj, max) + local list = {} + local cur = obj + local limit = max and (max + 1) or 11 + for i = 1, max or limit do + if i == limit then + return nil + end + while cur.type == 'paren' do + cur = cur.exp + if not cur then + return nil + end + end + if cur.type == 'setfield' + or cur.type == 'getfield' + or cur.type == 'setmethod' + or cur.type == 'getmethod' + or cur.type == 'setindex' + or cur.type == 'getindex' then + list[i] = cur + cur = cur.node + elseif cur.type == 'tablefield' + or cur.type == 'tableindex' then + list[i] = cur + cur = cur.parent.parent + if cur.type == 'return' then + list[i+1] = list[i].parent + break + end + elseif cur.type == 'getlocal' + or cur.type == 'setlocal' + or cur.type == 'local' then + list[i] = cur + break + elseif cur.type == 'setglobal' + or cur.type == 'getglobal' then + list[i] = cur + break + elseif cur.type == 'select' + or cur.type == 'table' then + list[i] = cur + break + elseif cur.type == 'string' then + list[i] = cur + break + elseif cur.type == 'doc.class.name' + or cur.type == 'doc.type.name' + or cur.type == 'doc.class' + or cur.type == 'doc.see.name' then + list[i] = cur + break + elseif cur.type == 'doc.see.field' then + list[i] = cur + cur = cur.parent.name + elseif cur.type == 'function' + or cur.type == 'main' then + break + else + return nil + end + end + return convertSimpleList(list) +end + +function m.getSimple(obj, max) + local simpleList + if obj.type == 'getfield' + or obj.type == 'setfield' + or obj.type == 'getmethod' + or obj.type == 'setmethod' + or obj.type == 'getindex' + or obj.type == 'setindex' + or obj.type == 'local' + or obj.type == 'getlocal' + or obj.type == 'setlocal' + or obj.type == 'setglobal' + or obj.type == 'getglobal' + or obj.type == 'tablefield' + or obj.type == 'tableindex' + or obj.type == 'select' + or obj.type == 'table' + or obj.type == 'string' + or obj.type == 'doc.class.name' + or obj.type == 'doc.class' + or obj.type == 'doc.type.name' + or obj.type == 'doc.see.name' + or obj.type == 'doc.see.field' then + simpleList = buildSimpleList(obj, max) + elseif obj.type == 'field' + or obj.type == 'method' then + simpleList = buildSimpleList(obj.parent, max) + end + return simpleList +end + +function m.status(parentStatus, interface, deep) + local status = { + share = parentStatus and parentStatus.share or { + count = 0, + }, + depth = parentStatus and (parentStatus.depth + 1) or 0, + searchDeep= parentStatus and parentStatus.searchDeep or deep or -999, + interface = parentStatus and parentStatus.interface or {}, + deep = parentStatus and parentStatus.deep, + results = {}, + } + if interface then + for k, v in pairs(interface) do + status.interface[k] = v + end + end + status.deep = status.depth <= status.searchDeep + return status +end + +function m.copyStatusResults(a, b) + local ra = a.results + local rb = b.results + for i = 1, #rb do + ra[#ra+1] = rb[i] + end +end + +function m.isGlobal(source) + if source.type == 'setglobal' + or source.type == 'getglobal' then + if source.node and source.node.tag == '_ENV' then + return true + end + end + if source.type == 'field' then + source = source.parent + end + if source.type == 'getfield' + or source.type == 'setfield' then + local node = source.node + if node and node.special == '_G' then + return true + end + end + return false +end + +function m.isDoc(source) + return source.type:sub(1, 4) == 'doc.' +end + +--- 根据函数的调用参数,获取:调用,参数索引 +function m.getCallAndArgIndex(callarg) + local callargs = callarg.parent + if not callargs or callargs.type ~= 'callargs' then + return nil + end + local index + for i = 1, #callargs do + if callargs[i] == callarg then + index = i + break + end + end + local call = callargs.parent + return call, index +end + +--- 根据函数调用的返回值,获取:调用的函数,参数列表,自己是第几个返回值 +function m.getCallValue(source) + local value = m.getObjectValue(source) or source + if not value then + return + end + local call, index + if value.type == 'call' then + call = value + index = 1 + elseif value.type == 'select' then + call = value.vararg + index = value.index + if call.type ~= 'call' then + return + end + else + return + end + return call.node, call.args, index +end + +function m.getNextRef(ref) + local nextRef = ref.next + if nextRef then + if nextRef.type == 'setfield' + or nextRef.type == 'getfield' + or nextRef.type == 'setmethod' + or nextRef.type == 'getmethod' + or nextRef.type == 'setindex' + or nextRef.type == 'getindex' then + return nextRef + end + end + -- 穿透 rawget 与 rawset + local call, index = m.getCallAndArgIndex(ref) + if call then + if call.node.special == 'rawset' and index == 1 then + return call + end + if call.node.special == 'rawget' and index == 1 then + return call + end + end + + return nil +end + +function m.checkSameSimpleInValueOfTable(status, value, start, queue) + if value.type ~= 'table' then + return + end + for i = 1, #value do + local field = value[i] + queue[#queue+1] = { + obj = field, + start = start + 1, + } + end +end + +function m.searchFields(status, obj, key) + local simple = m.getSimple(obj) + if not simple then + return + end + simple[#simple+1] = key or m.ANY + m.searchSameFields(status, simple, 'field') + m.cleanResults(status.results) +end + +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args[3] + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +function m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + local newStatus = m.status(status) + m.searchFields(newStatus, mt, '__index') + local refsStatus = m.status(status) + for i = 1, #newStatus.results do + local indexValue = m.getObjectValue(newStatus.results[i]) + if indexValue then + m.searchRefs(refsStatus, indexValue, 'ref') + end + end + for i = 1, #refsStatus.results do + local obj = refsStatus.results[i] + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end +end +function m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue) + if not func or func.special ~= 'setmetatable' then + return + end + local call = func.parent + local args = call.args + local obj = args[1] + local mt = args[2] + if obj then + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + if mt then + m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + end +end + +function m.checkSameSimpleInValueOfCallMetaTable(status, call, start, queue) + if call.type == 'call' then + m.checkSameSimpleInValueOfSetMetaTable(status, call.node, start, queue) + end +end + +function m.checkSameSimpleInSpecialBranch(status, obj, start, queue) + if status.interface.index then + local results = status.interface.index(obj) + if not results then + return + end + for _, res in ipairs(results) do + queue[#queue+1] = { + obj = res, + start = start + 1, + } + end + end +end + +local function stepRefOfGeneric(status, typeUnit, args, mode) + if not args then + return nil + end + local results = {} + local myName = typeUnit[1] + for _, typeName in ipairs(typeUnit.typeGeneric[myName]) do + if typeName == typeUnit then + goto CONTINUE + end + local doc = m.getDocState(typeName) + if doc.type ~= 'doc.param' then + goto CONTINUE + end + if not doc.bindSources then + goto CONTINUE + end + local paramName = doc.param[1] + for _, source in ipairs(doc.bindSources) do + if source.type == 'local' + and source[1] == paramName + and source.parent.type == 'funcargs' then + for index, arg in ipairs(source.parent) do + if arg == source then + results[#results+1] = args[index] + end + end + end + end + ::CONTINUE:: + end + return results +end + +function m.checkSameSimpleByDocType(status, doc, args) + if status.share.searchingBindedDoc then + return + end + if doc.type ~= 'doc.type' then + return + end + local results = {} + for _, piece in ipairs(doc.types) do + if piece.typeGeneric then + local pieceResult = stepRefOfGeneric(status, piece, args, 'def') + for _, res in ipairs(pieceResult) do + results[#results+1] = res + end + else + local pieceResult = stepRefOfDocType(status, piece, 'def') + for _, res in ipairs(pieceResult) do + results[#results+1] = res + end + end + end + return results +end + +function m.checkSameSimpleByBindDocs(status, obj, start, queue, mode) + if not obj.bindDocs then + return + end + if status.share.searchingBindedDoc then + return + end + local skipInfer = false + local results = {} + for _, doc in ipairs(obj.bindDocs) do + if doc.type == 'doc.class' then + results[#results+1] = doc + elseif doc.type == 'doc.type' then + results[#results+1] = doc + elseif doc.type == 'doc.param' then + -- function (x) 的情况 + if obj.type == 'local' + and m.getKeyName(obj) == doc.param[1] then + if obj.parent.type == 'funcargs' + or obj.parent.type == 'in' + or obj.parent.type == 'loop' then + results[#results+1] = doc.extends + end + end + elseif doc.type == 'doc.field' then + results[#results+1] = doc + end + end + for _, res in ipairs(results) do + if res.type == 'doc.class' + or res.type == 'doc.type' then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + skipInfer = true + end + if res.type == 'doc.type.function' then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + elseif res.type == 'doc.field' then + queue[#queue+1] = { + obj = res, + start = start + 1, + } + end + end + return skipInfer +end + +function m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + if status.share.searchingBindedDoc then + return + end + if not obj.bindSources then + return + end + status.share.searchingBindedDoc = true + local mark = {} + local newStatus = m.status(status) + for _, ref in ipairs(obj.bindSources) do + if not mark[ref] then + mark[ref] = true + m.searchRefs(newStatus, ref, mode) + end + end + status.share.searchingBindedDoc = nil + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSameSimpleByDoc(status, obj, start, queue, mode) + if obj.type == 'doc.class.name' + or obj.type == 'doc.class' then + if obj.type == 'doc.class.name' then + obj = m.getDocState(obj) + end + local classStart + for _, doc in ipairs(obj.bindGroup) do + if doc == obj then + classStart = true + elseif doc.type == 'doc.class' then + classStart = false + end + if classStart and doc.type == 'doc.field' then + queue[#queue+1] = { + obj = doc, + start = start + 1, + } + end + end + m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + if mode == 'ref' then + local pieceResult = stepRefOfDocType(status, obj.class, 'ref') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + if obj.extends then + local pieceResult = stepRefOfDocType(status, obj.extends, 'def') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end + end + return true + elseif obj.type == 'doc.type' then + for _, piece in ipairs(obj.types) do + local pieceResult = stepRefOfDocType(status, piece, 'def') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end + if mode == 'ref' then + m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + end + return true + elseif obj.type == 'doc.type.name' + or obj.type == 'doc.see.name' then + local pieceResult = stepRefOfDocType(status, obj, 'def') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + local state = m.getDocState(obj) + if state.type == 'doc.type' and mode == 'ref' then + m.checkSameSimpleOfRefByDocSource(status, state, start, queue, mode) + end + return true + elseif obj.type == 'doc.field' then + if mode ~= 'field' then + return m.checkSameSimpleByDoc(status, obj.extends, start, queue, mode) + end + elseif obj.type == 'doc.type.array' then + queue[#queue+1] = { + obj = obj.node, + start = start + 1, + force = true, + } + return true + end +end + +function m.checkSameSimpleInArg1OfSetMetaTable(status, obj, start, queue) + local args = obj.parent + if not args or args.type ~= 'callargs' then + return + end + if args[1] ~= obj then + return + end + local mt = args[2] + if mt then + if m.checkValueMark(status, obj, mt) then + return + end + m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + end +end + +function m.searchSameMethodCrossSelf(ref, mark) + local selfNode + if ref.tag == 'self' then + selfNode = ref + else + if ref.type == 'getlocal' + or ref.type == 'setlocal' then + local node = ref.node + if node.tag == 'self' then + selfNode = node + end + end + end + if selfNode then + if mark[selfNode] then + return nil + end + mark[selfNode] = true + return selfNode.method.node + end +end + +function m.searchSameMethod(ref, mark) + if mark['method'] then + return nil + end + local nxt = ref.next + if not nxt then + return nil + end + if nxt.type == 'setmethod' then + mark['method'] = true + return ref + end + return nil +end + +function m.searchSameFieldsCrossMethod(status, ref, start, queue) + local mark = status.crossMethodMark + if not mark then + mark = {} + status.crossMethodMark = mark + end + local method = m.searchSameMethod(ref, mark) + or m.searchSameMethodCrossSelf(ref, mark) + if not method then + return + end + local methodStatus = m.status(status) + m.searchRefs(methodStatus, method, 'ref') + for _, md in ipairs(methodStatus.results) do + queue[#queue+1] = { + obj = md, + start = start, + force = true, + } + local nxt = md.next + if not nxt then + goto CONTINUE + end + if nxt.type == 'setmethod' then + local func = nxt.value + if not func then + goto CONTINUE + end + local selfNode = func.locals and func.locals[1] + if not selfNode or not selfNode.ref then + goto CONTINUE + end + if mark[selfNode] then + goto CONTINUE + end + mark[selfNode] = true + for _, selfRef in ipairs(selfNode.ref) do + queue[#queue+1] = { + obj = selfRef, + start = start, + force = true, + } + end + end + ::CONTINUE:: + end +end + +local function checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, source, index, args) + if not source or source.type ~= 'function' then + return + end + if not source.bindDocs then + return + end + local returns = {} + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returns[#returns+1] = rtn + end + end + end + local rtn = returns[index] + if not rtn then + return + end + local types = m.checkSameSimpleByDocType(status, rtn, args) + if not types then + return + end + for _, res in ipairs(types) do + results[#results+1] = res + end + return true +end + +local function checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, source, index) + if not source.bindDocs then + return + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' then + for _, typeUnit in ipairs(doc.types) do + if typeUnit.type == 'doc.type.function' then + local rtn = typeUnit.returns[index] + if rtn then + local types = m.checkSameSimpleByDocType(status, rtn) + if types then + for _, res in ipairs(types) do + results[#results+1] = res + end + return true + end + end + end + end + end + end +end + +function m.checkSameSimpleInCallInSameFile(status, func, args, index) + local results = {} + if func.special then + return results + end + local newStatus = m.status(status) + m.searchRefs(newStatus, func, 'def') + for _, def in ipairs(newStatus.results) do + local hasDocReturn = checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, def, index) + or checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, def, index, args) + if not hasDocReturn then + local value = m.getObjectValue(def) or def + if value.type == 'function' then + local returns = value.returns + if returns then + for _, ret in ipairs(returns) do + local exp = ret[index] + if exp then + results[#results+1] = exp + end + end + end + end + end + end + return results +end + +function m.checkSameSimpleInCall(status, ref, start, queue, mode) + local func, args, index = m.getCallValue(ref) + if not func then + return + end + if m.checkCallMark(status, func.parent, true) then + return + end + status.share.crossCallCount = status.share.crossCallCount or 0 + if status.share.crossCallCount >= 5 then + return + end + status.share.crossCallCount = status.share.crossCallCount + 1 + -- 检查赋值是 semetatable() 的情况 + m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue) + -- 检查赋值是 func() 的情况 + local objs = m.checkSameSimpleInCallInSameFile(status, func, args, index) + if status.interface.call then + local cobjs = status.interface.call(func, args, index) + if cobjs then + for _, obj in ipairs(cobjs) do + if not m.checkReturnMark(status, obj) then + objs[#objs+1] = obj + end + end + end + end + m.cleanResults(objs) + local newStatus = m.status(status) + for _, obj in ipairs(objs) do + m.searchRefs(newStatus, obj, mode) + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + status.share.crossCallCount = status.share.crossCallCount - 1 + for _, obj in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end +end + +local function searchRawset(ref, results) + if m.getKeyName(ref) ~= 'rawset' then + return + end + local call = ref.parent + if call.type ~= 'call' or call.node ~= ref then + return + end + if not call.args then + return + end + local arg1 = call.args[1] + if arg1.special ~= '_G' then + -- 不会吧不会吧,不会真的有人写成 `rawset(_G._G._G, 'xxx', value)` 吧 + return + end + results[#results+1] = call +end + +local function searchG(ref, results) + while ref and m.getKeyName(ref) == '_G' do + results[#results+1] = ref + ref = ref.next + end + if ref then + results[#results+1] = ref + searchRawset(ref, results) + end +end + +local function searchEnvRef(ref, results) + if ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + searchG(ref, results) + elseif ref.type == 'getlocal' then + results[#results+1] = ref.next + searchG(ref.next, results) + end +end + +function m.findGlobals(ast) + local root = m.getRoot(ast) + local results = {} + local env = m.getENV(root) + if env.ref then + for _, ref in ipairs(env.ref) do + searchEnvRef(ref, results) + end + end + return results +end + +function m.findGlobalsOfName(ast, name) + local root = m.getRoot(ast) + local results = {} + local globals = m.findGlobals(root) + for _, global in ipairs(globals) do + if m.getKeyName(global) == name then + results[#results+1] = global + end + end + return results +end + +function m.checkSameSimpleInGlobal(status, name, source, start, queue) + if not name then + return + end + local objs + if status.interface.global then + objs = status.interface.global(name) + else + objs = m.findGlobalsOfName(source, name) + end + if objs then + for _, obj in ipairs(objs) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + end +end + +function m.checkValueMark(status, a, b) + if not status.share.valueMark then + status.share.valueMark = {} + end + if status.share.valueMark[a] + or status.share.valueMark[b] then + return true + end + status.share.valueMark[a] = true + status.share.valueMark[b] = true + return false +end + +function m.checkCallMark(status, a, mark) + if not status.share.callMark then + status.share.callMark = {} + end + if mark then + status.share.callMark[a] = mark + else + return status.share.callMark[a] + end + return false +end + +function m.checkReturnMark(status, a, mark) + if not status.share.returnMark then + status.share.returnMark = {} + end + local result = status.share.returnMark[a] + if mark then + status.share.returnMark[a] = mark + end + return result +end + +function m.searchSameFieldsInValue(status, ref, start, queue, mode) + local value = m.getObjectValue(ref) + if not value then + return + end + if m.checkValueMark(status, ref, value) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, value, mode) + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + queue[#queue+1] = { + obj = value, + start = start, + force = true, + } + -- 检查形如 a = f() 的分支情况 + m.checkSameSimpleInCall(status, value, start, queue, mode) +end + +function m.checkSameSimpleAsTableField(status, ref, start, queue) + if not status.deep then + --return + end + local parent = ref.parent + if not parent or parent.type ~= 'tablefield' then + return + end + if m.checkValueMark(status, parent, ref) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, parent.field, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSearchLevel(status) + status.share.back = status.share.back or 0 + if status.share.back >= (status.interface.searchLevel or 0) then + -- TODO 限制向前搜索的次数 + --return true + end + status.share.back = status.share.back + 1 + return false +end + +function m.checkSameSimpleAsReturn(status, ref, start, queue) + if not status.deep then + return + end + if not ref.parent or ref.parent.type ~= 'return' then + return + end + if ref.parent.parent.type ~= 'main' then + return + end + if m.checkSearchLevel(status) then + return + end + local newStatus = m.status(status) + m.searchRefsAsFunctionReturn(newStatus, ref, 'ref') + for _, res in ipairs(newStatus.results) do + if not m.checkCallMark(status, res) then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end +end + +function m.checkSameSimpleAsSetValue(status, ref, start, queue) + if ref.type == 'select' then + return + end + local parent = ref.parent + if not parent then + return + end + if m.getObjectValue(parent) ~= ref then + return + end + if m.checkValueMark(status, ref, parent) then + return + end + if m.checkSearchLevel(status) then + return + end + local obj + if parent.type == 'local' + or parent.type == 'setglobal' + or parent.type == 'setlocal' then + obj = parent + elseif parent.type == 'setfield' then + obj = parent.field + elseif parent.type == 'setmethod' then + obj = parent.method + end + if not obj then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, obj, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +local function hasTypeName(doc, name) + if doc.type ~= 'doc.type' then + return false + end + for _, tunit in ipairs(doc.types) do + if tunit.type == 'doc.type.name' + and tunit[1] == name then + return true + end + end + return false +end + +function m.checkSameSimpleInString(status, ref, start, queue, mode) + -- 特殊处理 ('xxx').xxx 的形式 + if ref.type ~= 'string' + and not hasTypeName(ref, 'string') then + return + end + if not status.interface.docType then + return + end + if status.share.searchingBindedDoc then + return + end + local newStatus = m.status(status) + local docs = status.interface.docType('string*') + local mark = {} + for i = 1, #docs do + local doc = docs[i] + m.searchFields(newStatus, doc) + end + for _, res in ipairs(newStatus.results) do + if mark[res] then + goto CONTINUE + end + mark[res] = true + queue[#queue+1] = { + obj = res, + start = start + 1, + } + ::CONTINUE:: + end + return true +end + +function m.pushResult(status, mode, ref, simple) + local results = status.results + if mode == 'def' then + if ref.type == 'setglobal' + or ref.type == 'setlocal' + or ref.type == 'local' then + results[#results+1] = ref + elseif ref.type == 'setfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' then + results[#results+1] = ref + end + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + if m.getParentFunction(ref) ~= m.getParentFunction(simple.node) then + results[#results+1] = ref + end + end + if m.isLiteral(ref) + and ref.parent.type == 'callargs' + and ref ~= simple.node then + results[#results+1] = ref + end + elseif mode == 'ref' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' + or ref.type == 'local' + or ref.type == 'setlocal' + or ref.type == 'getlocal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + results[#results+1] = ref + end + if m.isLiteral(ref) + and ref.parent.type == 'callargs' + and ref ~= simple.node then + results[#results+1] = ref + end + elseif mode == 'field' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + end +end + +function m.checkSameSimpleName(ref, sm) + if sm == m.ANY then + return true + end + if m.getSimpleName(ref) == sm then + return true + end + if ref.type == 'doc.type' + and ref.array == true then + return true + end + return false +end + +function m.checkSameSimple(status, simple, data, mode, queue) + local ref = data.obj + local start = data.start + local force = data.force + if start > #simple then + return + end + for i = start, #simple do + local sm = simple[i] + if not force and not m.checkSameSimpleName(ref, sm) then + return + end + force = false + local cmode = mode + if i < #simple then + cmode = 'ref' + end + -- 检查 doc + local skipInfer = m.checkSameSimpleByBindDocs(status, ref, i, queue, cmode) + or m.checkSameSimpleByDoc(status, ref, i, queue, cmode) + -- 检查自己是字符串的分支情况 + m.checkSameSimpleInString(status, ref, i, queue, cmode) + if not skipInfer then + -- 穿透 self:func 与 mt:func + m.searchSameFieldsCrossMethod(status, ref, i, queue) + -- 穿透赋值 + m.searchSameFieldsInValue(status, ref, i, queue, cmode) + -- 检查自己是字面量表的情况 + m.checkSameSimpleInValueOfTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 第一个参数的情况 + m.checkSameSimpleInArg1OfSetMetaTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 调用的情况 + m.checkSameSimpleInValueOfCallMetaTable(status, ref, i, queue) + -- 检查自己是特殊变量的分支的情况 + m.checkSameSimpleInSpecialBranch(status, ref, i, queue) + if cmode == 'ref' then + -- 检查形如 { a = f } 的情况 + m.checkSameSimpleAsTableField(status, ref, i, queue) + -- 检查形如 return m 的情况 + m.checkSameSimpleAsReturn(status, ref, i, queue) + -- 检查形如 a = f 的情况 + m.checkSameSimpleAsSetValue(status, ref, i, queue) + end + end + if i == #simple then + break + end + ref = m.getNextRef(ref) + if not ref then + return + end + end + m.pushResult(status, mode, ref, simple) + local value = m.getObjectValue(ref) + if value then + m.pushResult(status, mode, value, simple) + end +end + +function m.searchSameFields(status, simple, mode) + local queue = {} + if simple.mode == 'global' then + -- 全局变量开头 + m.checkSameSimpleInGlobal(status, simple[1], simple.node, 1, queue) + elseif simple.mode == 'local' then + -- 局部变量开头 + queue[1] = { + obj = simple.node, + start = 1, + } + local refs = simple.node.ref + if refs then + for i = 1, #refs do + queue[#queue+1] = { + obj = refs[i], + start = 1, + } + end + end + else + queue[1] = { + obj = simple.node, + start = 1, + } + end + local max = 0 + local lock = {} + for i = 1, 1e6 do + local data = queue[i] + if not data then + return + end + if not lock[data.obj] then + lock[data.obj] = true + max = max + 1 + status.share.count = status.share.count + 1 + m.checkSameSimple(status, simple, data, mode, queue) + if max >= 10000 then + logWarn('Queue too large!') + break + end + end + end +end + +function m.getCallerInSameFile(status, func) + -- 搜索所有所在函数的调用者 + local funcRefs = m.status(status) + m.searchRefOfValue(funcRefs, func) + + local calls = {} + if #funcRefs.results == 0 then + return calls + end + for _, res in ipairs(funcRefs.results) do + local call = res.parent + if call.type == 'call' then + calls[#calls+1] = call + end + end + return calls +end + +function m.getCallerCrossFiles(status, main) + if status.interface.link then + return status.interface.link(main.uri) + end + return {} +end + +function m.searchRefsAsFunctionReturn(status, obj, mode) + if not status.deep then + return + end + if mode == 'def' then + return + end + if m.checkReturnMark(status, obj, true) then + return + end + status.results[#status.results+1] = obj + -- 搜索所在函数 + local currentFunc = m.getParentFunction(obj) + local rtn = obj.parent + if rtn.type ~= 'return' then + return + end + -- 看看他是第几个返回值 + local index + for i = 1, #rtn do + if obj == rtn[i] then + index = i + break + end + end + if not index then + return + end + local calls + if currentFunc.type == 'main' then + calls = m.getCallerCrossFiles(status, currentFunc) + else + calls = m.getCallerInSameFile(status, currentFunc) + end + -- 搜索调用者的返回值 + if #calls == 0 then + return + end + local selects = {} + for i = 1, #calls do + local parent = calls[i].parent + if parent.type == 'select' and parent.index == index then + selects[#selects+1] = parent.parent + end + local extParent = calls[i].extParent + if extParent then + for j = 1, #extParent do + local ext = extParent[j] + if ext.type == 'select' and ext.index == index then + selects[#selects+1] = ext.parent + end + end + end + end + -- 搜索调用者的引用 + for i = 1, #selects do + m.searchRefs(status, selects[i], 'ref') + end +end + +function m.searchRefsAsFunctionSet(status, obj, mode) + local parent = obj.parent + if not parent then + return + end + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' + or parent.type == 'setfield' + or parent.type == 'setmethod' + or parent.type == 'tablefield' then + m.searchRefs(status, parent, mode) + elseif parent.type == 'setindex' + or parent.type == 'tableindex' then + if parent.index == obj then + m.searchRefs(status, parent, mode) + end + end +end + +function m.searchRefsAsFunction(status, obj, mode) + if obj.type ~= 'function' + and obj.type ~= 'table' then + return + end + m.searchRefsAsFunctionSet(status, obj, mode) + -- 检查自己作为返回函数时的引用 + m.searchRefsAsFunctionReturn(status, obj, mode) +end + +function m.cleanResults(results) + local mark = {} + for i = #results, 1, -1 do + local res = results[i] + if res.tag == 'self' + or mark[res] then + results[i] = results[#results] + results[#results] = nil + else + mark[res] = true + end + end +end + +--function m.getRefCache(status, obj, mode) +-- local cache = status.interface.cache and status.interface.cache() +-- if not cache then +-- return +-- end +-- if m.isGlobal(obj) then +-- obj = m.getKeyName(obj) +-- end +-- if not cache[mode] then +-- cache[mode] = {} +-- end +-- local sourceCache = cache[mode][obj] +-- if sourceCache then +-- return sourceCache +-- end +-- sourceCache = {} +-- cache[mode][obj] = sourceCache +-- return nil, function (results) +-- for i = 1, #results do +-- sourceCache[i] = results[i] +-- end +-- end +--end + +function m.getRefCache(status, obj, mode) + local cache, globalCache + if status.depth == 1 + and status.deep then + globalCache = status.interface.cache and status.interface.cache() or {} + end + cache = status.share.refCache or {} + status.share.refCache = cache + if m.isGlobal(obj) then + obj = m.getKeyName(obj) + end + if not cache[mode] then + cache[mode] = {} + end + if globalCache and not globalCache[mode] then + globalCache[mode] = {} + end + local sourceCache = globalCache and globalCache[mode][obj] or cache[mode][obj] + if sourceCache then + return sourceCache + end + sourceCache = {} + cache[mode][obj] = sourceCache + if globalCache then + globalCache[mode][obj] = sourceCache + end + return nil, function (results) + for i = 1, #results do + sourceCache[i] = results[i] + end + end +end + +function m.searchRefs(status, obj, mode) + local cache, makeCache = m.getRefCache(status, obj, mode) + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + + -- 检查单步引用 + local res = m.getStepRef(status, obj, mode) + if res then + for i = 1, #res do + status.results[#status.results+1] = res[i] + end + end + -- 检查simple + if status.depth <= 100 then + local simple = m.getSimple(obj) + if simple then + m.searchSameFields(status, simple, mode) + end + else + if m.debugMode then + error('status.depth overflow') + elseif DEVELOP then + --log.warn(debug.traceback('status.depth overflow')) + logWarn('status.depth overflow') + end + end + + m.cleanResults(status.results) + + if makeCache then + makeCache(status.results) + end +end + +function m.searchRefOfValue(status, obj) + local var = obj.parent + if var.type == 'local' + or var.type == 'set' then + return m.searchRefs(status, var, 'ref') + end +end + +function m.allocInfer(o) + if type(o.type) == 'table' then + local infers = {} + for i = 1, #o.type do + infers[i] = { + type = o.type[i], + value = o.value, + source = o.source, + level = o.level + } + end + return infers + else + return { + [1] = o, + } + end +end + +function m.mergeTypes(types) + local results = {} + local mark = {} + local hasAny + -- 这里把 any 去掉 + for i = 1, #types do + local tp = types[i] + if tp == 'any' then + hasAny = true + end + if not mark[tp] and tp ~= 'any' then + mark[tp] = true + results[#results+1] = tp + end + end + if #results == 0 then + return 'any' + end + -- 只有显性的 nil 与 any 时,取 any + if #results == 1 then + if results[1] == 'nil' and hasAny then + return 'any' + else + return results[1] + end + end + -- 同时包含 number 与 integer 时,去掉 integer + if mark['number'] and mark['integer'] then + for i = 1, #results do + if results[i] == 'integer' then + tableRemove(results, i) + break + end + end + end + tableSort(results, function (a, b) + local sa = TypeSort[a] or 100 + local sb = TypeSort[b] or 100 + return sa < sb + end) + return tableConcat(results, '|') +end + +function m.viewInferType(infers) + if not infers then + return 'any' + end + local mark = {} + local types = {} + local hasDoc + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' then + if infer.type ~= 'any' then + hasDoc = true + break + end + end + end + if hasDoc then + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' + or src.type == 'doc.type.enum' + or src.type == 'doc.resume' then + local tp = infer.type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + end + end + else + for i = 1, #infers do + local infer = infers[i] + if infer.source.typeGeneric then + goto CONTINUE + end + local tp = infer.type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + ::CONTINUE:: + end + end + return m.mergeTypes(types) +end + +function m.checkTrue(status, source) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + -- 当前认为的结果 + local current + for _, infer in ipairs(newStatus.results) do + -- 新的结果 + local new + if infer.type == 'nil' then + new = false + elseif infer.type == 'boolean' then + if infer.value == true then + new = true + elseif infer.value == false then + new = false + end + end + if new ~= nil then + if current == nil then + current = new + else + -- 如果2个结果完全相反,则返回 nil 表示不确定 + if new ~= current then + return nil + end + end + end + end + return current +end + +--- 获取特定类型的字面量值 +function m.getInferLiteral(status, source, type) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + for _, infer in ipairs(newStatus.results) do + if infer.value ~= nil then + if type == nil or infer.type == type then + return infer.value + end + end + end + return nil +end + +--- 是否包含某种类型 +function m.hasType(status, source, type) + m.searchInfer(status, source) + for _, infer in ipairs(status.results) do + if infer.type == type then + return true + end + end + return false +end + +function m.isSameValue(status, a, b) + local statusA = m.status(status) + m.searchInfer(statusA, a) + local statusB = m.status(status) + m.searchInfer(statusB, b) + local infers = {} + for _, infer in ipairs(statusA.results) do + local literal = infer.value + if literal then + infers[literal] = false + end + end + for _, infer in ipairs(statusB.results) do + local literal = infer.value + if literal then + if infers[literal] == nil then + return false + end + infers[literal] = true + end + end + for k, v in pairs(infers) do + if v == false then + return false + end + end + return true +end + +function m.inferCheckLiteralTableWithDocVararg(status, source) + if #source ~= 1 then + return + end + local vararg = source[1] + if vararg.type ~= 'varargs' then + return + end + local results = m.getVarargDocType(status, source) + status.results[#status.results+1] = { + type = m.viewInferType(results) .. '[]', + source = source, + level = 100, + } + return true +end + +function m.inferCheckLiteral(status, source) + if source.type == 'string' then + status.results = m.allocInfer { + type = 'string', + value = source[1], + source = source, + level = 100, + } + return true + elseif source.type == 'nil' then + status.results = m.allocInfer { + type = 'nil', + value = NIL, + source = source, + level = 100, + } + return true + elseif source.type == 'boolean' then + status.results = m.allocInfer { + type = 'boolean', + value = source[1], + source = source, + level = 100, + } + return true + elseif source.type == 'number' then + if mathType(source[1]) == 'integer' then + status.results = m.allocInfer { + type = 'integer', + value = source[1], + source = source, + level = 100, + } + return true + else + status.results = m.allocInfer { + type = 'number', + value = source[1], + source = source, + level = 100, + } + return true + end + elseif source.type == 'integer' then + status.results = m.allocInfer { + type = 'integer', + source = source, + level = 100, + } + return true + elseif source.type == 'table' then + if m.inferCheckLiteralTableWithDocVararg(status, source) then + return true + end + status.results = m.allocInfer { + type = 'table', + source = source, + level = 100, + } + return true + elseif source.type == 'function' then + status.results = m.allocInfer { + type = 'function', + source = source, + level = 100, + } + return true + elseif source.type == '...' then + status.results = m.allocInfer { + type = '...', + source = source, + level = 100, + } + return true + end +end + +local function getDocAliasExtends(status, typeUnit) + if not status.interface.docType then + return nil + end + if typeUnit.type ~= 'doc.type.name' then + return nil + end + for _, doc in ipairs(status.interface.docType(typeUnit[1])) do + if doc.type == 'doc.alias.name' then + return doc.parent.extends + end + end + return nil +end + +local function getDocTypeUnitName(status, unit) + local typeName + if unit.type == 'doc.type.name' then + typeName = unit[1] + elseif unit.type == 'doc.type.function' then + typeName = 'function' + elseif unit.type == 'doc.type.array' then + typeName = getDocTypeUnitName(status, unit.node) .. '[]' + elseif unit.type == 'doc.type.generic' then + typeName = ('%s<%s, %s>'):format( + getDocTypeUnitName(status, unit.node), + m.viewInferType(m.getDocTypeNames(status, unit.key)), + m.viewInferType(m.getDocTypeNames(status, unit.value)) + ) + end + if unit.typeGeneric then + typeName = ('<%s>'):format(typeName) + end + return typeName +end + +function m.getDocTypeNames(status, doc) + local results = {} + if not doc then + return results + end + for _, unit in ipairs(doc.types) do + local alias = getDocAliasExtends(status, unit) + if alias then + local aliasResults = m.getDocTypeNames(status, alias) + for _, res in ipairs(aliasResults) do + results[#results+1] = res + end + else + local typeName = getDocTypeUnitName(status, unit) + results[#results+1] = { + type = typeName, + source = unit, + level = 100, + } + end + end + for _, enum in ipairs(doc.enums) do + results[#results+1] = { + type = enum[1], + source = enum, + level = 100, + } + end + for _, resume in ipairs(doc.resumes) do + if not resume.additional then + results[#results+1] = { + type = resume[1], + source = resume, + level = 100, + } + end + end + return results +end + +function m.inferCheckDoc(status, source) + if source.type == 'doc.class.name' then + status.results[#status.results+1] = { + type = source[1], + source = source, + level = 100, + } + return true + end + if source.type == 'doc.class' then + status.results[#status.results+1] = { + type = source.class[1], + source = source, + level = 100, + } + return true + end + if source.type == 'doc.type' then + local results = m.getDocTypeNames(status, source) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + if source.type == 'doc.field' then + local results = m.getDocTypeNames(status, source.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end +end + +function m.getVarargDocType(status, source) + local func = m.getParentFunction(source) + if not func then + return + end + if not func.args then + return + end + for _, arg in ipairs(func.args) do + if arg.type == '...' then + if arg.bindDocs then + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.vararg' then + return m.getDocTypeNames(status, doc.vararg) + end + end + end + end + end +end + +function m.inferCheckUpDocOfVararg(status, source) + if not source.vararg then + return + end + local results = m.getVarargDocType(status, source) + if not results then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +function m.inferCheckUpDoc(status, source) + if m.inferCheckUpDocOfVararg(status, source) then + return true + end + local parent = source.parent + if parent then + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' then + source = parent + end + if parent.type == 'setfield' + or parent.type == 'tablefield' then + if parent.field == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setmethod' then + if parent.method == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setindex' + or parent.type == 'tableindex' then + if parent.index == source + or parent.value == source then + source = parent + end + end + end + local binds = source.bindDocs + if not binds then + return + end + status.results = {} + for _, doc in ipairs(binds) do + if doc.type == 'doc.class' then + status.results[#status.results+1] = { + type = doc.class[1], + source = doc, + level = 100, + } + -- ---@class Class + -- local x = { field = 1 } + -- 这种情况下,将字面量表接受为Class的定义 + if source.value and source.value.type == 'table' then + status.results[#status.results+1] = { + type = source.value.type, + source = source.value, + level = 100, + } + end + return true + elseif doc.type == 'doc.type' then + local results = m.getDocTypeNames(status, doc) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + elseif doc.type == 'doc.param' then + -- function (x) 的情况 + if source.type == 'local' + and m.getKeyName(source) == doc.param[1] then + if source.parent.type == 'funcargs' + or source.parent.type == 'in' + or source.parent.type == 'loop' then + local results = m.getDocTypeNames(status, doc.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + end + end + end +end + +function m.inferCheckUnary(status, source) + if source.type ~= 'unary' then + return + end + local op = source.op + if op.type == 'not' then + local checkTrue = m.checkTrue(status, source[1]) + local value = nil + if checkTrue == true then + value = false + elseif checkTrue == false then + value = true + end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + level = 100, + } + return true + elseif op.type == '#' then + status.results = m.allocInfer { + type = 'integer', + source = source, + level = 100, + } + return true + elseif op.type == '~' then + local l = m.getInferLiteral(status, source[1], 'integer') + status.results = m.allocInfer { + type = 'integer', + value = l and ~l or nil, + source = source, + level = 100, + } + return true + elseif op.type == '-' then + local v = m.getInferLiteral(status, source[1], 'integer') + if v then + status.results = m.allocInfer { + type = 'integer', + value = - v, + source = source, + level = 100, + } + return true + end + v = m.getInferLiteral(status, source[1], 'number') + status.results = m.allocInfer { + type = 'number', + value = v and -v or nil, + source = source, + level = 100, + } + return true + end +end + +local function mathCheck(status, a, b) + local v1 = m.getInferLiteral(status, a, 'integer') + or m.getInferLiteral(status, a, 'number') + local v2 = m.getInferLiteral(status, b, 'integer') + or m.getInferLiteral(status, a, 'number') + local int = m.hasType(status, a, 'integer') + and m.hasType(status, b, 'integer') + and not m.hasType(status, a, 'number') + and not m.hasType(status, b, 'number') + return int and 'integer' or 'number', v1, v2 +end + +function m.inferCheckBinary(status, source) + if source.type ~= 'binary' then + return + end + local op = source.op + if op.type == 'and' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[2]) + return true + elseif isTrue == false then + m.searchInfer(status, source[1]) + return true + else + m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == 'or' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[1]) + return true + elseif isTrue == false then + m.searchInfer(status, source[2]) + return true + else + m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == '==' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + level = 100, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = true + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + level = 100, + } + return true + elseif op.type == '~=' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = not value, + source = source, + level = 100, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = false + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + level = 100, + } + return true + elseif op.type == '<=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 <= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '>=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 >= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 < v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '|' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 | v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '~' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 ~ v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '&' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 & v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '<<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 << v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '>>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 >> v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '..' then + local v1 = m.getInferLiteral(status, source[1], 'string') + local v2 = m.getInferLiteral(status, source[2], 'string') + local v + if v1 and v2 then + v = v1 .. v2 + end + status.results = m.allocInfer { + type = 'string', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '^' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 ^ v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + level = 100, + } + return true + elseif op.type == '/' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + level = 100, + } + return true + -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数 + elseif op.type == '+' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 + v2) or nil, + source = source, + level = 100, + } + return true + elseif op.type == '-' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 - v2) or nil, + source = source, + level = 100, + } + return true + elseif op.type == '*' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 * v2) or nil, + source = source, + level = 100, + } + return true + elseif op.type == '%' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 % v2) or nil, + source = source, + level = 100, + } + return true + elseif op.type == '//' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 // v2) or nil, + source = source, + level = 100, + } + return true + end +end + +function m.inferByDef(status, obj) + if not status.share.inferedDef then + status.share.inferedDef = {} + end + if status.share.inferedDef[obj] then + return + end + status.share.inferedDef[obj] = true + local mark = {} + local newStatus = m.status(status, status.interface) + m.searchRefs(newStatus, obj, 'def') + for _, src in ipairs(newStatus.results) do + local inferStatus = m.status(newStatus) + m.searchInfer(inferStatus, src) + for _, infer in ipairs(inferStatus.results) do + if not mark[infer.source] then + mark[infer.source] = true + status.results[#status.results+1] = infer + end + end + end +end + +local function inferBySetOfLocal(status, source) + if status.share[source] then + return + end + status.share[source] = true + local newStatus = m.status(status) + if source.value then + m.searchInfer(newStatus, source.value) + end + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + break + end + m.searchInfer(newStatus, ref) + end + for _, infer in ipairs(newStatus.results) do + status.results[#status.results+1] = infer + end + end +end + +function m.inferBySet(status, source) + if #status.results ~= 0 then + return + end + if source.type == 'local' then + inferBySetOfLocal(status, source) + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + inferBySetOfLocal(status, source.node) + end +end + +function m.inferByCall(status, source) + if not source.parent then + return + end + if source.parent.type ~= 'call' then + return + end + if source.parent.node == source then + status.results[#status.results+1] = { + type = 'function', + source = source, + level = 10, + } + return + end +end + +function m.inferByGetTable(status, source) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local next = source.next + if not next then + return + end + if next.type == 'getfield' + or next.type == 'getindex' + or next.type == 'setmethod' + or next.type == 'setfield' + or next.type == 'setindex' then + status.results[#status.results+1] = { + type = 'table', + source = source, + level = 10, + } + elseif next.type == 'getmethod' then + status.results[#status.results+1] = { + type = 'table', + source = source, + level = 5, + } + status.results[#status.results+1] = { + type = 'string', + source = source, + level = 5, + } + end +end + +function m.inferByUnary(status, source) + local parent = source.parent + if not parent or parent.type ~= 'unary' then + return + end + local op = parent.op + if op.type == '#' then + status.results[#status.results+1] = { + type = 'string', + source = source, + level = 5, + } + status.results[#status.results+1] = { + type = 'table', + source = source, + level = 5, + } + elseif op.type == '~' then + status.results[#status.results+1] = { + type = 'integer', + source = source, + level = 10, + } + elseif op.type == '-' then + status.results[#status.results+1] = { + type = 'number', + source = source, + level = 10, + } + end +end + +function m.inferByBinary(status, source) + local parent = source.parent + if not parent or parent.type ~= 'binary' then + return + end + local op = parent.op + if op.type == '<=' + or op.type == '>=' + or op.type == '<' + or op.type == '>' + or op.type == '^' + or op.type == '/' + or op.type == '+' + or op.type == '-' + or op.type == '*' + or op.type == '%' then + status.results[#status.results+1] = { + type = 'number', + source = source, + level = 10, + } + elseif op.type == '|' + or op.type == '~' + or op.type == '&' + or op.type == '<<' + or op.type == '>>' + -- 整数的可能性比较高 + or op.type == '//' then + status.results[#status.results+1] = { + type = 'integer', + source = source, + level = 10, + } + elseif op.type == '..' then + status.results[#status.results+1] = { + type = 'string', + source = source, + level = 10, + } + end +end + +local function mergeFunctionReturnsByDoc(status, source, index, call) + if not source or source.type ~= 'function' then + return + end + if not source.bindDocs then + return + end + local returns = {} + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returns[#returns+1] = rtn + end + end + end + local rtn = returns[index] + if not rtn then + return + end + local results = m.getDocTypeNames(status, rtn) + if #results == 0 then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +local function mergeDocTypeFunctionReturns(status, source, index) + if not source.bindDocs then + return + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' then + for _, typeUnit in ipairs(doc.types) do + if typeUnit.type == 'doc.type.function' then + local rtn = typeUnit.returns[index] + if rtn then + local results = m.getDocTypeNames(status, rtn) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + end + end + end + end + end +end + +local function mergeFunctionReturns(status, source, index, call) + local returns = source.returns + if not returns then + return + end + for i = 1, #returns do + local rtn = returns[i] + if rtn[index] then + if rtn[index].type == 'call' then + if not m.checkReturnMark(status, rtn[index], true) then + m.inferByCallReturnAndIndex(status, rtn[index], index) + end + else + local newStatus = m.status(status) + m.searchInfer(newStatus, rtn[index]) + if #newStatus.results == 0 then + status.results[#status.results+1] = { + type = 'any', + source = rtn[index], + level = 0, + } + else + for _, infer in ipairs(newStatus.results) do + status.results[#status.results+1] = infer + end + end + end + end + end +end + +function m.inferByCallReturnAndIndex(status, call, index) + local node = call.node + local newStatus = m.status(status, status.interface) + m.searchRefs(newStatus, node, 'def') + local hasDocReturn + for _, src in ipairs(newStatus.results) do + if mergeDocTypeFunctionReturns(status, src, index) then + hasDocReturn = true + elseif mergeFunctionReturnsByDoc(status, src.value, index, call) then + hasDocReturn = true + end + end + if not hasDocReturn then + for _, src in ipairs(newStatus.results) do + if src.value and src.value.type == 'function' then + if not m.checkReturnMark(status, src.value, true) then + mergeFunctionReturns(status, src.value, index, call) + end + end + end + end +end + +function m.inferByCallReturn(status, source) + if source.type == 'call' then + m.inferByCallReturnAndIndex(status, source, 1) + return + end + if source.type ~= 'select' then + if source.value and source.value.type == 'select' then + source = source.value + else + return + end + end + if not source.vararg or source.vararg.type ~= 'call' then + return + end + m.inferByCallReturnAndIndex(status, source.vararg, source.index) +end + +function m.inferByPCallReturn(status, source) + if source.type ~= 'select' then + if source.value and source.value.type == 'select' then + source = source.value + else + return + end + end + local call = source.vararg + if not call or call.type ~= 'call' then + return + end + local node = call.node + local specialName = node.special + local func, index + if specialName == 'pcall' then + func = call.args[1] + index = source.index - 1 + elseif specialName == 'xpcall' then + func = call.args[1] + index = source.index - 2 + else + return + end + local newStatus = m.status(status, status.interface) + m.searchRefs(newStatus, func, 'def') + for _, src in ipairs(newStatus.results) do + if src.value and src.value.type == 'function' then + mergeFunctionReturns(status, src.value, index) + end + end +end + +function m.cleanInfers(infers, obj) + -- kick lower level infers + local level = 0 + if obj.type ~= 'select' then + for i = 1, #infers do + local infer = infers[i] + if infer.level > level then + level = infer.level + end + end + end + -- merge infers + local mark = {} + for i = #infers, 1, -1 do + local infer = infers[i] + if infer.level < level then + infers[i] = infers[#infers] + infers[#infers] = nil + goto CONTINUE + end + local key = ('%p'):format(infer.type, infer.source) + if mark[key] then + infers[i] = infers[#infers] + infers[#infers] = nil + else + mark[key] = true + end + ::CONTINUE:: + end + -- kick doc.generic + if #infers > 1 then + for i = #infers, 1, -1 do + local infer = infers[i] + if infer.source.typeGeneric then + infers[i] = infers[#infers] + infers[#infers] = nil + end + end + end +end + +function m.searchInfer(status, obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return + end + end + while true do + local value = m.getObjectValue(obj) + if not value or value == obj then + break + end + obj = value + end + + local cache, makeCache = m.getRefCache(status, obj, 'infer') + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + + if DEVELOP then + status.share.clock = status.share.clock or osClock() + end + + if not status.share.lockInfer then + status.share.lockInfer = {} + end + if status.share.lockInfer[obj] then + return + end + status.share.lockInfer[obj] = true + + local checked = m.inferCheckDoc(status, obj) + or m.inferCheckUpDoc(status, obj) + or m.inferCheckLiteral(status, obj) + or m.inferCheckUnary(status, obj) + or m.inferCheckBinary(status, obj) + if checked then + m.cleanInfers(status.results, obj) + if makeCache then + makeCache(status.results) + end + return + end + + if status.deep then + m.inferByDef(status, obj) + end + m.inferBySet(status, obj) + m.inferByCall(status, obj) + m.inferByGetTable(status, obj) + m.inferByUnary(status, obj) + m.inferByBinary(status, obj) + m.inferByCallReturn(status, obj) + m.inferByPCallReturn(status, obj) + m.cleanInfers(status.results, obj) + if makeCache then + makeCache(status.results) + end +end + +--- 请求对象的引用,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 不穿透 `setmetatable` ,考虑由 +--- 业务层进行反向 def 搜索。 +function m.requestReference(obj, interface, deep) + local status = m.status(nil, interface, deep) + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'ref') + + m.searchRefsAsFunction(status, obj, 'ref') + + if m.debugMode then + print('count:', status.share.count) + end + + return status.results, status.share.count +end + +--- 请求对象的定义,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 穿透 `setmetatable` 。 +function m.requestDefinition(obj, interface, deep) + local status = m.status(nil, interface, deep) + -- 根据 field 搜索定义 + m.searchRefs(status, obj, 'def') + + return status.results, status.share.count +end + +--- 请求对象的域 +function m.requestFields(obj, interface, deep) + local status = m.status(nil, interface, deep) + + m.searchFields(status, obj) + + return status.results, status.share.count +end + +--- 请求对象的类型推测 +function m.requestInfer(obj, interface, deep) + local status = m.status(nil, interface, deep) + m.searchInfer(status, obj) + + return status.results, status.share.count +end + +return m diff --git a/test/full/example.lua b/test/full/example.lua index e409b4b2..b19f0485 100644 --- a/test/full/example.lua +++ b/test/full/example.lua @@ -14,15 +14,26 @@ local function testIfExit(path) local clock = os.clock() local max = 100 local need + local parseClock = 0 + local compileClock = 0 + local total for i = 1, max do vm = TEST(buf) local passed = os.clock() - clock + parseClock = parseClock + vm.parseClock + compileClock = compileClock + vm.compileClock if passed >= 1.0 or i == max then need = passed / i + total = i break end end - print(('基准编译测试[%s]单次耗时:%.10f'):format(path:filename():string(), need)) + print(('基准编译测试[%s]单次耗时:%.10f(解析:%.10f, 编译:%.10f)'):format( + path:filename():string(), + need, + parseClock / total, + compileClock / total + )) local clock = os.clock() local max = 100 @@ -43,3 +54,4 @@ local function testIfExit(path) end testIfExit(ROOT / 'test' / 'example' / 'vm.txt') testIfExit(ROOT / 'test' / 'example' / 'largeGlobal.txt') +testIfExit(ROOT / 'test' / 'example' / 'guide.txt') |