diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2020-11-20 21:57:09 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2020-11-20 21:57:09 +0800 |
commit | 4ca61ec457822dd14966afa0752340ae8ce180a1 (patch) | |
tree | ae8adb1ad82c717868e551e699fd3cf3bb290089 /script/parser/guide.lua | |
parent | c63b2e404d8d2bb984afe3678a5ba2b2836380cc (diff) | |
download | lua-language-server-4ca61ec457822dd14966afa0752340ae8ce180a1.zip |
no longer beta
Diffstat (limited to 'script/parser/guide.lua')
-rw-r--r-- | script/parser/guide.lua | 3884 |
1 files changed, 3884 insertions, 0 deletions
diff --git a/script/parser/guide.lua b/script/parser/guide.lua new file mode 100644 index 00000000..6ef239f1 --- /dev/null +++ b/script/parser/guide.lua @@ -0,0 +1,3884 @@ +local util = require 'utility' +local error = error +local type = type +local next = next +local tostring = tostring +local print = print +local ipairs = ipairs +local tableInsert = table.insert +local tableUnpack = table.unpack +local tableRemove = table.remove +local tableMove = table.move +local tableSort = table.sort +local tableConcat = table.concat +local mathType = math.type +local pairs = pairs +local setmetatable = setmetatable +local assert = assert +local select = select +local osClock = os.clock +local DEVELOP = _G.DEVELOP +local log = log +local _G = _G + +local function logWarn(...) + log.warn(...) +end + +_ENV = nil + +local m = {} + +local blockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, + ['do'] = true, + ['function'] = true, + ['ifblock'] = true, + ['elseblock'] = true, + ['elseifblock'] = true, + ['main'] = true, +} + +local breakBlockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, +} + +m.childMap = { + ['main'] = {'#', 'docs'}, + ['repeat'] = {'#', 'filter'}, + ['while'] = {'filter', '#'}, + ['in'] = {'keys', '#'}, + ['loop'] = {'loc', 'max', 'step', '#'}, + ['if'] = {'#'}, + ['ifblock'] = {'filter', '#'}, + ['elseifblock'] = {'filter', '#'}, + ['elseblock'] = {'#'}, + ['setfield'] = {'node', 'field', 'value'}, + ['setglobal'] = {'value'}, + ['local'] = {'attrs', 'value'}, + ['setlocal'] = {'value'}, + ['return'] = {'#'}, + ['do'] = {'#'}, + ['select'] = {'vararg'}, + ['table'] = {'#'}, + ['tableindex'] = {'index', 'value'}, + ['tablefield'] = {'field', 'value'}, + ['function'] = {'args', '#'}, + ['funcargs'] = {'#'}, + ['setmethod'] = {'node', 'method', 'value'}, + ['getmethod'] = {'node', 'method'}, + ['setindex'] = {'node', 'index', 'value'}, + ['getindex'] = {'node', 'index'}, + ['paren'] = {'exp'}, + ['call'] = {'node', 'args'}, + ['callargs'] = {'#'}, + ['getfield'] = {'node', 'field'}, + ['list'] = {'#'}, + ['binary'] = {1, 2}, + ['unary'] = {1}, + + ['doc'] = {'#'}, + ['doc.class'] = {'class', 'extends'}, + ['doc.type'] = {'#types', '#enums', 'name'}, + ['doc.alias'] = {'alias', 'extends'}, + ['doc.param'] = {'param', 'extends'}, + ['doc.return'] = {'#returns'}, + ['doc.field'] = {'field', 'extends'}, + ['doc.generic'] = {'#generics'}, + ['doc.generic.object'] = {'generic', 'extends'}, + ['doc.vararg'] = {'vararg'}, + ['doc.type.table'] = {'key', 'value'}, + ['doc.type.function'] = {'#args', '#returns'}, + ['doc.overload'] = {'overload'}, +} + +m.actionMap = { + ['main'] = {'#'}, + ['repeat'] = {'#'}, + ['while'] = {'#'}, + ['in'] = {'#'}, + ['loop'] = {'#'}, + ['if'] = {'#'}, + ['ifblock'] = {'#'}, + ['elseifblock'] = {'#'}, + ['elseblock'] = {'#'}, + ['do'] = {'#'}, + ['function'] = {'#'}, + ['funcargs'] = {'#'}, +} + +local TypeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['nil'] = 999, +} + +local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) + +--- 是否是字面量 +function m.isLiteral(obj) + local tp = obj.type + return tp == 'nil' + or tp == 'boolean' + or tp == 'string' + or tp == 'number' + or tp == 'table' +end + +--- 获取字面量 +function m.getLiteral(obj) + local tp = obj.type + if tp == 'boolean' then + return obj[1] + elseif tp == 'string' then + return obj[1] + elseif tp == 'number' then + return obj[1] + end + return nil +end + +--- 寻找父函数 +function m.getParentFunction(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + break + end + local tp = obj.type + if tp == 'function' or tp == 'main' then + return obj + end + end + return nil +end + +--- 寻找所在区块 +function m.getBlock(obj) + for _ = 1, 1000 do + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + obj = obj.parent + end + error('guide.getBlock overstack') +end + +--- 寻找所在父区块 +function m.getParentBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + end + error('guide.getParentBlock overstack') +end + +--- 寻找所在可break的父区块 +function m.getBreakBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if breakBlockTypes[tp] then + return obj + end + if tp == 'function' then + return nil + end + end + error('guide.getBreakBlock overstack') +end + +--- 寻找doc的主体 +function m.getDocState(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return obj + end + if parent.type == 'doc' then + return obj + end + obj = parent + end + error('guide.getDocState overstack') +end + +--- 寻找所在父类型 +function m.getParentType(obj, want) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + if want == obj.type then + return obj + end + end + error('guide.getParentType overstack') +end + +--- 寻找根区块 +function m.getRoot(obj) + for _ = 1, 1000 do + if obj.type == 'main' then + return obj + end + local parent = obj.parent + if not parent then + return nil + end + obj = parent + end + error('guide.getRoot overstack') +end + +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = m.getRoot(obj) + if root then + return root.uri + end + return '' +end + +function m.getENV(source, start) + if not start then + start = 1 + end + return m.getLocal(source, '_ENV', start) + or m.getLocal(source, '@fenv', start) +end + +--- 寻找函数的不定参数,返回不定参在第几个参数上,以及该参数对象。 +--- 如果函数是主函数,则返回`0, nil`。 +---@return table +---@return integer +function m.getFunctionVarArgs(func) + if func.type == 'main' then + return 0, nil + end + if func.type ~= 'function' then + return nil, nil + end + local args = func.args + if not args then + return nil, nil + end + for i = 1, #args do + local arg = args[i] + if arg.type == '...' then + return i, arg + end + end + return nil, nil +end + +--- 获取指定区块中可见的局部变量 +---@param block table +---@param name string {comment = '变量名'} +---@param pos integer {comment = '可见位置'} +function m.getLocal(block, name, pos) + block = m.getBlock(block) + for _ = 1, 1000 do + if not block then + return nil + end + local locals = block.locals + local res + if not locals then + goto CONTINUE + end + for i = 1, #locals do + local loc = locals[i] + if loc.effect > pos then + break + end + if loc[1] == name then + if not res or res.effect < loc.effect then + res = loc + end + end + end + if res then + return res, res + end + ::CONTINUE:: + block = m.getParentBlock(block) + end + error('guide.getLocal overstack') +end + +--- 获取指定区块中所有的可见局部变量名称 +function m.getVisibleLocals(block, pos) + local result = {} + m.eachSourceContain(m.getRoot(block), pos, function (source) + local locals = source.locals + if locals then + for i = 1, #locals do + local loc = locals[i] + local name = loc[1] + if loc.effect <= pos then + result[name] = loc + end + end + end + end) + return result +end + +--- 获取指定区块中可见的标签 +---@param block table +---@param name string {comment = '标签名'} +function m.getLabel(block, name) + block = m.getBlock(block) + for _ = 1, 1000 do + if not block then + return nil + end + local labels = block.labels + if labels then + local label = labels[name] + if label then + return label + end + end + if block.type == 'function' then + return nil + end + block = m.getParentBlock(block) + end + error('guide.getLocal overstack') +end + +function m.getStartFinish(source) + local start = source.start + local finish = source.finish + if not start then + local first = source[1] + if not first then + return nil, nil + end + local last = source[#source] + start = first.start + finish = last.finish + end + return start, finish +end + +function m.getRange(source) + local start = source.vstart or source.start + local finish = source.range or source.finish + if not start then + local first = source[1] + if not first then + return nil, nil + end + local last = source[#source] + start = first.vstart or first.start + finish = last.range or last.finish + end + return start, finish +end + +--- 判断source是否包含offset +function m.isContain(source, offset) + local start, finish = m.getStartFinish(source) + if not start then + return false + end + return start <= offset and finish >= offset - 1 +end + +--- 判断offset在source的影响范围内 +--- +--- 主要针对赋值等语句时,key包含value +function m.isInRange(source, offset) + local start, finish = m.getRange(source) + if not start then + return false + end + return start <= offset and finish >= offset - 1 +end + +function m.isBetween(source, tStart, tFinish) + local start, finish = m.getStartFinish(source) + if not start then + return false + end + return start <= tFinish and finish >= tStart - 1 +end + +function m.isBetweenRange(source, tStart, tFinish) + local start, finish = m.getRange(source) + if not start then + return false + end + return start <= tFinish and finish >= tStart - 1 +end + +--- 添加child +function m.addChilds(list, obj, map) + local keys = map[obj.type] + if keys then + for i = 1, #keys do + local key = keys[i] + if key == '#' then + for i = 1, #obj do + list[#list+1] = obj[i] + end + elseif obj[key] then + list[#list+1] = obj[key] + elseif type(key) == 'string' + and key:sub(1, 1) == '#' then + key = key:sub(2) + for i = 1, #obj[key] do + list[#list+1] = obj[key][i] + end + end + end + end +end + +--- 遍历所有包含offset的source +function m.eachSourceContain(ast, offset, callback) + local list = { ast } + local mark = {} + while true do + local len = #list + if len == 0 then + return + end + local obj = list[len] + list[len] = nil + if not mark[obj] then + mark[obj] = true + if m.isInRange(obj, offset) then + if m.isContain(obj, offset) then + local res = callback(obj) + if res ~= nil then + return res + end + end + m.addChilds(list, obj, m.childMap) + end + end + end +end + +--- 遍历所有在某个范围内的source +function m.eachSourceBetween(ast, start, finish, callback) + local list = { ast } + local mark = {} + while true do + local len = #list + if len == 0 then + return + end + local obj = list[len] + list[len] = nil + if not mark[obj] then + mark[obj] = true + if m.isBetweenRange(obj, start, finish) then + if m.isBetween(obj, start, finish) then + local res = callback(obj) + if res ~= nil then + return res + end + end + m.addChilds(list, obj, m.childMap) + end + end + end +end + +--- 遍历所有指定类型的source +function m.eachSourceType(ast, type, callback) + local cache = ast.typeCache + if not cache then + cache = {} + ast.typeCache = cache + m.eachSource(ast, function (source) + local tp = source.type + if not tp then + return + end + local myCache = cache[tp] + if not myCache then + myCache = {} + cache[tp] = myCache + end + myCache[#myCache+1] = source + end) + end + local myCache = cache[type] + if not myCache then + return + end + for i = 1, #myCache do + callback(myCache[i]) + end +end + +--- 遍历所有的source +function m.eachSource(ast, callback) + local list = { ast } + local mark = {} + local index = 1 + while true do + local obj = list[index] + if not obj then + return + end + list[index] = false + index = index + 1 + if not mark[obj] then + mark[obj] = true + callback(obj) + m.addChilds(list, obj, m.childMap) + end + end +end + +--- 获取指定的 special +function m.eachSpecialOf(ast, name, callback) + local root = m.getRoot(ast) + if not root.specials then + return + end + local specials = root.specials[name] + if not specials then + return + end + for i = 1, #specials do + callback(specials[i]) + end +end + +--- 获取偏移对应的坐标 +---@param lines table +---@return integer {name = 'row'} +---@return integer {name = 'col'} +function m.positionOf(lines, offset) + if offset < 1 then + return 0, 0 + end + local lastLine = lines[#lines] + if offset > lastLine.finish then + return #lines, lastLine.finish - lastLine.start + 1 + end + local min = 1 + local max = #lines + for _ = 1, 100 do + if max <= min then + local line = lines[min] + return min, offset - line.start + 1 + end + local row = (max - min) // 2 + min + local line = lines[row] + if offset < line.start then + max = row - 1 + elseif offset > line.finish then + min = row + 1 + else + return row, offset - line.start + 1 + end + end + error('Stack overflow!') +end + +--- 获取坐标对应的偏移 +---@param lines table +---@param row integer +---@param col integer +---@return integer {name = 'offset'} +function m.offsetOf(lines, row, col) + if row < 1 then + return 0 + end + if row > #lines then + local lastLine = lines[#lines] + return lastLine.finish + end + local line = lines[row] + local len = line.finish - line.start + 1 + if col < 0 then + return line.start + elseif col > len then + return line.finish + else + return line.start + col - 1 + end +end + +function m.lineContent(lines, text, row, ignoreNL) + local line = lines[row] + if not line then + return '' + end + if ignoreNL then + return text:sub(line.start, line.range) + else + return text:sub(line.start, line.finish) + end +end + +function m.lineRange(lines, row, ignoreNL) + local line = lines[row] + if not line then + return 0, 0 + end + if ignoreNL then + return line.start, line.range + else + return line.start, line.finish + end +end + +function m.getNameOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'string' then + return obj[1] + end + return nil +end + +function m.getName(obj) + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return obj[1] + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return obj[1] + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + return obj.field and obj.field[1] + elseif tp == 'getmethod' + or tp == 'setmethod' then + return obj.method and obj.method[1] + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getNameOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' then + return obj[1] + elseif tp == 'doc.class' then + return obj.class[1] + elseif tp == 'doc.alias' then + return obj.alias[1] + elseif tp == 'doc.field' then + return obj.field[1] + end + return m.getNameOfLiteral(obj) +end + +function m.getKeyNameOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return 's|' .. obj[1] + elseif tp == 'string' then + local s = obj[1] + if s then + return 's|' .. s + else + return 's' + end + elseif tp == 'number' then + local n = obj[1] + if n then + return ('n|%s'):format(util.viewLiteral(obj[1])) + else + return 'n' + end + elseif tp == 'boolean' then + local b = obj[1] + if b then + return 'b|' .. tostring(b) + else + return 'b' + end + end + return nil +end + +function m.getKeyName(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return 's|' .. obj[1] + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return 'l|' .. obj[1] + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + if obj.field then + return 's|' .. obj.field[1] + end + elseif tp == 'getmethod' + or tp == 'setmethod' then + if obj.method then + return 's|' .. obj.method[1] + end + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyNameOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' then + return 's|' .. obj[1] + elseif tp == 'doc.class' then + return 's|' .. obj.class[1] + elseif tp == 'doc.alias' then + return 's|' .. obj.alias[1] + elseif tp == 'doc.field' then + return 's|' .. obj.field[1] + end + return m.getKeyNameOfLiteral(obj) +end + +function m.getSimpleName(obj) + if obj.type == 'call' then + local key = obj.args and obj.args[2] + return m.getKeyName(key) + elseif obj.type == 'table' then + return ('t|%p'):format(obj) + elseif obj.type == 'select' then + return ('v|%p'):format(obj) + elseif obj.type == 'string' then + return ('z|%p'):format(obj) + elseif obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' then + return ('c|%s'):format(obj[1]) + elseif obj.type == 'doc.class' then + return ('c|%s'):format(obj.class[1]) + end + return m.getKeyName(obj) +end + +--- 测试 a 到 b 的路径(不经过函数,不考虑 goto), +--- 每个路径是一个 block 。 +--- +--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>` +--- +--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>` +--- +--- 否则返回 `false` +--- +--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。 +---@param a table +---@param b table +---@return string|boolean mode +---@return table pathA? +---@return table pathB? +function m.getPath(a, b, sameFunction) + --- 首先测试双方在同一个函数内 + if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then + return false + end + local mode + local objA + local objB + if a.finish < b.start then + mode = 'before' + objA = a + objB = b + elseif a.start > b.finish then + mode = 'after' + objA = b + objB = a + else + return 'equal', {}, {} + end + local pathA = {} + local pathB = {} + for _ = 1, 1000 do + objA = m.getParentBlock(objA) + pathA[#pathA+1] = objA + if (not sameFunction and objA.type == 'function') or objA.type == 'main' then + break + end + end + for _ = 1, 1000 do + objB = m.getParentBlock(objB) + pathB[#pathB+1] = objB + if (not sameFunction and objA.type == 'function') or objB.type == 'main' then + break + end + end + -- pathA: {1, 2, 3, 4, 5} + -- pathB: {5, 6, 2, 3} + local top = #pathB + local start + for i = #pathA, 1, -1 do + local currentBlock = pathA[i] + if currentBlock == pathB[top] then + start = i + break + end + end + if not start then + return nil + end + -- pathA: { 1, 2, 3} + -- pathB: {5, 6, 2, 3} + local extra = 0 + local align = top - start + for i = start, 1, -1 do + local currentA = pathA[i] + local currentB = pathB[i+align] + if currentA ~= currentB then + extra = i + break + end + end + -- pathA: {1} + local resultA = {} + for i = extra, 1, -1 do + resultA[#resultA+1] = pathA[i] + end + -- pathB: {5, 6} + local resultB = {} + for i = extra + align, 1, -1 do + resultB[#resultB+1] = pathB[i] + end + return mode, resultA, resultB +end + +-- 根据语法,单步搜索定义 +local function stepRefOfLocal(loc, mode) + local results = {} + if loc.start ~= 0 then + results[#results+1] = loc + end + local refs = loc.ref + if not refs then + return results + end + for i = 1, #refs do + local ref = refs[i] + if ref.start == 0 then + goto CONTINUE + end + if mode == 'def' then + if ref.type == 'local' + or ref.type == 'setlocal' then + results[#results+1] = ref + end + else + if ref.type == 'local' + or ref.type == 'setlocal' + or ref.type == 'getlocal' then + results[#results+1] = ref + end + end + ::CONTINUE:: + end + return results +end + +local function stepRefOfLabel(label, mode) + local results = { label } + if not label or mode == 'def' then + return results + end + local refs = label.ref + for i = 1, #refs do + local ref = refs[i] + results[#results+1] = ref + end + return results +end + +local function stepRefOfDocType(status, obj, mode) + local results = {} + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' + or obj.type == 'doc.alias.name' + or obj.type == 'doc.extends.name' then + local name = obj[1] + if not name or not status.interface.docType then + return results + end + local docs = status.interface.docType(name) + for i = 1, #docs do + local doc = docs[i] + if mode == 'def' then + if doc.type == 'doc.class.name' + or doc.type == 'doc.alias.name' then + results[#results+1] = doc + end + else + results[#results+1] = doc + end + end + else + results[#results+1] = obj + end + return results +end + +function m.getStepRef(status, obj, mode) + if obj.type == 'getlocal' + or obj.type == 'setlocal' then + return stepRefOfLocal(obj.node, mode) + end + if obj.type == 'local' then + return stepRefOfLocal(obj, mode) + end + if obj.type == 'label' then + return stepRefOfLabel(obj, mode) + end + if obj.type == 'goto' then + return stepRefOfLabel(obj.node, mode) + end + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' + or obj.type == 'doc.extends.name' + or obj.type == 'doc.alias.name' then + return stepRefOfDocType(status, obj, mode) + end + return nil +end + +-- 根据语法,单步搜索field +local function stepFieldOfLocal(loc) + local results = {} + local refs = loc.ref + for i = 1, #refs do + local ref = refs[i] + if ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + elseif ref.type == 'getlocal' then + local nxt = ref.next + if nxt then + if nxt.type == 'setfield' + or nxt.type == 'getfield' + or nxt.type == 'setmethod' + or nxt.type == 'getmethod' + or nxt.type == 'setindex' + or nxt.type == 'getindex' then + results[#results+1] = nxt + end + end + end + end + return results +end +local function stepFieldOfTable(tbl) + local result = {} + for i = 1, #tbl do + result[i] = tbl[i] + end + return result +end +function m.getStepField(obj) + if obj.type == 'getlocal' + or obj.type == 'setlocal' then + return stepFieldOfLocal(obj.node) + end + if obj.type == 'local' then + return stepFieldOfLocal(obj) + end + if obj.type == 'table' then + return stepFieldOfTable(obj) + end +end + +local function convertSimpleList(list) + local simple = {} + for i = #list, 1, -1 do + local c = list[i] + if c.type == 'getglobal' + or c.type == 'setglobal' then + if c.special == '_G' then + simple.mode = 'global' + goto CONTINUE + end + local loc = c.node + if loc.special == '_G' then + simple.mode = 'global' + if not simple.node then + simple.node = c + end + else + simple.mode = 'local' + simple[#simple+1] = m.getSimpleName(loc) + if not simple.node then + simple.node = loc + end + end + elseif c.type == 'getlocal' + or c.type == 'setlocal' then + if c.special == '_G' then + simple.mode = 'global' + goto CONTINUE + end + simple.mode = 'local' + if not simple.node then + simple.node = c.node + end + elseif c.type == 'local' then + simple.mode = 'local' + if not simple.node then + simple.node = c + end + else + if not simple.node then + simple.node = c + end + end + simple[#simple+1] = m.getSimpleName(c) + ::CONTINUE:: + end + if simple.mode == 'global' and #simple == 0 then + simple[1] = 's|_G' + simple.node = list[#list] + end + return simple +end + +-- 搜索 `a.b.c` 的等价表达式 +local function buildSimpleList(obj, max) + local list = {} + local cur = obj + local limit = max and (max + 1) or 11 + for i = 1, max or limit do + if i == limit then + return nil + end + while cur.type == 'paren' do + cur = cur.exp + if not cur then + return nil + end + end + if cur.type == 'setfield' + or cur.type == 'getfield' + or cur.type == 'setmethod' + or cur.type == 'getmethod' + or cur.type == 'setindex' + or cur.type == 'getindex' then + list[i] = cur + cur = cur.node + elseif cur.type == 'tablefield' + or cur.type == 'tableindex' then + list[i] = cur + cur = cur.parent.parent + if cur.type == 'return' then + list[i+1] = list[i].parent + break + end + elseif cur.type == 'getlocal' + or cur.type == 'setlocal' + or cur.type == 'local' then + list[i] = cur + break + elseif cur.type == 'setglobal' + or cur.type == 'getglobal' then + list[i] = cur + break + elseif cur.type == 'select' + or cur.type == 'table' then + list[i] = cur + break + elseif cur.type == 'string' then + list[i] = cur + break + elseif cur.type == 'doc.class.name' + or cur.type == 'doc.type.name' + or cur.type == 'doc.class' then + list[i] = cur + break + elseif cur.type == 'function' + or cur.type == 'main' then + break + else + return nil + end + end + return convertSimpleList(list) +end + +function m.getSimple(obj, max) + local simpleList + if obj.type == 'getfield' + or obj.type == 'setfield' + or obj.type == 'getmethod' + or obj.type == 'setmethod' + or obj.type == 'getindex' + or obj.type == 'setindex' + or obj.type == 'local' + or obj.type == 'getlocal' + or obj.type == 'setlocal' + or obj.type == 'setglobal' + or obj.type == 'getglobal' + or obj.type == 'tablefield' + or obj.type == 'tableindex' + or obj.type == 'select' + or obj.type == 'table' + or obj.type == 'string' + or obj.type == 'doc.class.name' + or obj.type == 'doc.class' + or obj.type == 'doc.type.name' then + simpleList = buildSimpleList(obj, max) + elseif obj.type == 'field' + or obj.type == 'method' then + simpleList = buildSimpleList(obj.parent, max) + end + return simpleList +end + +function m.status(parentStatus, interface) + local status = { + cache = parentStatus and parentStatus.cache or { + count = 0, + }, + depth = parentStatus and (parentStatus.depth + 1) or 1, + interface = parentStatus and parentStatus.interface or {}, + locks = parentStatus and parentStatus.locks or {}, + deep = parentStatus and parentStatus.deep, + results = {}, + } + status.lock = status.locks[status.depth] or {} + status.locks[status.depth] = status.lock + if interface then + for k, v in pairs(interface) do + status.interface[k] = v + end + end + local searchDepth = status.interface.getSearchDepth and status.interface.getSearchDepth() or 0 + if status.depth >= searchDepth then + status.deep = false + end + return status +end + +function m.copyStatusResults(a, b) + local ra = a.results + local rb = b.results + for i = 1, #rb do + ra[#ra+1] = rb[i] + end +end + +function m.isGlobal(source) + if source.type == 'setglobal' + or source.type == 'getglobal' then + if source.node and source.node.tag == '_ENV' then + return true + end + end + if source.type == 'field' then + source = source.parent + end + if source.type == 'getfield' + or source.type == 'setfield' then + local node = source.node + if node and node.special == '_G' then + return true + end + end + return false +end + +function m.isDoc(source) + return source.type:sub(1, 4) == 'doc.' +end + +--- 根据函数的调用参数,获取:调用,参数索引 +function m.getCallAndArgIndex(callarg) + local callargs = callarg.parent + if not callargs or callargs.type ~= 'callargs' then + return nil + end + local index + for i = 1, #callargs do + if callargs[i] == callarg then + index = i + break + end + end + local call = callargs.parent + return call, index +end + +--- 根据函数调用的返回值,获取:调用的函数,参数列表,自己是第几个返回值 +function m.getCallValue(source) + local value = m.getObjectValue(source) or source + if not value then + return + end + local call, index + if value.type == 'call' then + call = value + index = 1 + elseif value.type == 'select' then + call = value.vararg + index = value.index + if call.type ~= 'call' then + return + end + else + return + end + return call.node, call.args, index +end + +function m.getNextRef(ref) + local nextRef = ref.next + if nextRef then + if nextRef.type == 'setfield' + or nextRef.type == 'getfield' + or nextRef.type == 'setmethod' + or nextRef.type == 'getmethod' + or nextRef.type == 'setindex' + or nextRef.type == 'getindex' then + return nextRef + end + end + -- 穿透 rawget 与 rawset + local call, index = m.getCallAndArgIndex(ref) + if call then + if call.node.special == 'rawset' and index == 1 then + return call + end + if call.node.special == 'rawget' and index == 1 then + return call + end + end + -- doc.type.array + if ref.type == 'doc.type' then + local arrays = {} + for _, typeUnit in ipairs(ref.types) do + if typeUnit.type == 'doc.type.array' then + arrays[#arrays+1] = typeUnit.node + end + end + -- 返回一个 dummy + -- TODO 用弱表维护唯一性? + return { + type = 'doc.type', + start = ref.start, + finish = ref.finish, + types = arrays, + parent = ref.parent, + array = true, + enums = {}, + resumes = {}, + } + end + + return nil +end + +function m.checkSameSimpleInValueOfTable(status, value, start, queue) + if value.type ~= 'table' then + return + end + for i = 1, #value do + local field = value[i] + queue[#queue+1] = { + obj = field, + start = start + 1, + } + end +end + +function m.searchFields(status, obj, key) + local simple = m.getSimple(obj) + if not simple then + return + end + simple[#simple+1] = key and ('s|' .. key) or '*' + m.searchSameFields(status, simple, 'field') + m.cleanResults(status.results) +end + +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args[3] + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +function m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + local newStatus = m.status(status) + m.searchFields(newStatus, mt, '__index') + local refsStatus = m.status(status) + for i = 1, #newStatus.results do + local indexValue = m.getObjectValue(newStatus.results[i]) + if indexValue then + m.searchRefs(refsStatus, indexValue, 'ref') + end + end + for i = 1, #refsStatus.results do + local obj = refsStatus.results[i] + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end +end +function m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue) + if not func or func.special ~= 'setmetatable' then + return + end + local call = func.parent + local args = call.args + local obj = args[1] + local mt = args[2] + if obj then + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + if mt then + m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + end +end + +function m.checkSameSimpleInValueOfCallMetaTable(status, call, start, queue) + if call.type == 'call' then + m.checkSameSimpleInValueOfSetMetaTable(status, call.node, start, queue) + end +end + +function m.checkSameSimpleInSpecialBranch(status, obj, start, queue) + if not status.interface.index then + return + end + local results = status.interface.index(obj) + if not results then + return + end + for _, res in ipairs(results) do + queue[#queue+1] = { + obj = res, + start = start + 1, + } + end +end + +function m.checkSameSimpleByDocType(status, doc) + if status.cache.searchingBindedDoc then + return + end + if doc.type ~= 'doc.type' then + return + end + local results = {} + for _, piece in ipairs(doc.types) do + local pieceResult = stepRefOfDocType(status, piece, 'def') + for _, res in ipairs(pieceResult) do + results[#results+1] = res + end + end + return results +end + +function m.checkSameSimpleByBindDocs(status, obj, start, queue, mode) + if not obj.bindDocs then + return + end + if status.cache.searchingBindedDoc then + return + end + local skipInfer = false + local results = {} + for _, doc in ipairs(obj.bindDocs) do + if doc.type == 'doc.class' then + results[#results+1] = doc + elseif doc.type == 'doc.type' then + results[#results+1] = doc + elseif doc.type == 'doc.param' then + -- function (x) 的情况 + if obj.type == 'local' + and m.getName(obj) == doc.param[1] then + if obj.parent.type == 'funcargs' + or obj.parent.type == 'in' + or obj.parent.type == 'loop' then + results[#results+1] = doc.extends + end + end + elseif doc.type == 'doc.field' then + results[#results+1] = doc + end + end + for _, res in ipairs(results) do + if res.type == 'doc.class' + or res.type == 'doc.type' then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + skipInfer = true + end + if res.type == 'doc.type.function' then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + elseif res.type == 'doc.field' then + queue[#queue+1] = { + obj = res, + start = start + 1, + } + end + end + return skipInfer +end + +function m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + if status.cache.searchingBindedDoc then + return + end + if not obj.bindSources then + return + end + status.cache.searchingBindedDoc = true + local mark = {} + local newStatus = m.status(status) + for _, ref in ipairs(obj.bindSources) do + if not mark[ref] then + mark[ref] = true + m.searchRefs(newStatus, ref, mode) + end + end + status.cache.searchingBindedDoc = nil + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSameSimpleByDoc(status, obj, start, queue, mode) + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' then + obj = m.getDocState(obj) + end + if obj.type == 'doc.class' then + local classStart + for _, doc in ipairs(obj.bindGroup) do + if doc == obj then + classStart = true + elseif doc.type == 'doc.class' then + classStart = false + end + if classStart and doc.type == 'doc.field' then + queue[#queue+1] = { + obj = doc, + start = start + 1, + } + end + end + m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + if mode == 'ref' then + local pieceResult = stepRefOfDocType(status, obj.class, 'ref') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end + return true + elseif obj.type == 'doc.type' then + for _, piece in ipairs(obj.types) do + local pieceResult = stepRefOfDocType(status, piece, 'def') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end + if mode == 'ref' then + m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + end + return true + elseif obj.type == 'doc.field' then + if mode ~= 'field' then + return m.checkSameSimpleByDoc(status, obj.extends, start, queue, mode) + end + end +end + +function m.checkSameSimpleInArg1OfSetMetaTable(status, obj, start, queue) + local args = obj.parent + if not args or args.type ~= 'callargs' then + return + end + if args[1] ~= obj then + return + end + local mt = args[2] + if mt then + if m.checkValueMark(status, obj, mt) then + return + end + m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + end +end + +function m.searchSameMethodCrossSelf(ref, mark) + local selfNode + if ref.tag == 'self' then + selfNode = ref + else + if ref.type == 'getlocal' + or ref.type == 'setlocal' then + local node = ref.node + if node.tag == 'self' then + selfNode = node + end + end + end + if selfNode then + if mark[selfNode] then + return nil + end + mark[selfNode] = true + return selfNode.method.node + end +end + +function m.searchSameMethod(ref, mark) + if mark['method'] then + return nil + end + local nxt = ref.next + if not nxt then + return nil + end + if nxt.type == 'setmethod' then + mark['method'] = true + return ref + end + return nil +end + +function m.searchSameFieldsCrossMethod(status, ref, start, queue) + local mark = status.cache.crossMethodMark + if not mark then + mark = {} + status.cache.crossMethodMark = mark + end + local method = m.searchSameMethod(ref, mark) + or m.searchSameMethodCrossSelf(ref, mark) + if not method then + return + end + local methodStatus = m.status(status) + m.searchRefs(methodStatus, method, 'ref') + for _, md in ipairs(methodStatus.results) do + queue[#queue+1] = { + obj = md, + start = start, + force = true, + } + local nxt = md.next + if not nxt then + goto CONTINUE + end + if nxt.type == 'setmethod' then + local func = nxt.value + if not func then + goto CONTINUE + end + local selfNode = func.locals and func.locals[1] + if not selfNode or not selfNode.ref then + goto CONTINUE + end + if mark[selfNode] then + goto CONTINUE + end + mark[selfNode] = true + for _, selfRef in ipairs(selfNode.ref) do + queue[#queue+1] = { + obj = selfRef, + start = start, + force = true, + } + end + end + ::CONTINUE:: + end +end + +local function checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, source, index, call) + if not source or source.type ~= 'function' then + return + end + if not source.bindDocs then + return + end + local returns = {} + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returns[#returns+1] = rtn + end + end + end + local rtn = returns[index] + if not rtn then + return + end + local types = m.checkSameSimpleByDocType(status, rtn) + if not types then + return + end + for _, res in ipairs(types) do + results[#results+1] = res + end + return true +end + +local function checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, source, index) + if not source.bindDocs then + return + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' then + for _, typeUnit in ipairs(doc.types) do + if typeUnit.type == 'doc.type.function' then + local rtn = typeUnit.returns[index] + if rtn then + local types = m.checkSameSimpleByDocType(status, rtn) + if types then + for _, res in ipairs(types) do + results[#results+1] = res + end + return true + end + end + end + end + end + end +end + +function m.checkSameSimpleInCallInSameFile(status, func, args, index) + local newStatus = m.status(status) + m.searchRefs(newStatus, func, 'def') + local results = {} + for _, def in ipairs(newStatus.results) do + local hasDocReturn = checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, def, index) + or checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, def, index) + if not hasDocReturn then + local value = m.getObjectValue(def) or def + if value.type == 'function' then + local returns = value.returns + if returns then + for _, ret in ipairs(returns) do + local exp = ret[index] + if exp then + results[#results+1] = exp + end + end + end + end + end + end + return results +end + +function m.checkSameSimpleInCall(status, ref, start, queue, mode) + local func, args, index = m.getCallValue(ref) + if not func then + return + end + if m.checkCallMark(status, func.parent, true) then + return + end + status.cache.crossCallCount = status.cache.crossCallCount or 0 + if status.cache.crossCallCount >= 5 then + return + end + status.cache.crossCallCount = status.cache.crossCallCount + 1 + -- 检查赋值是 semetatable() 的情况 + m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue) + -- 检查赋值是 func() 的情况 + local objs = m.checkSameSimpleInCallInSameFile(status, func, args, index) + if status.interface.call then + local cobjs = status.interface.call(func, args, index) + if cobjs then + for _, obj in ipairs(cobjs) do + if not m.checkReturnMark(status, obj) then + objs[#objs+1] = obj + end + end + end + end + m.cleanResults(objs) + local newStatus = m.status(status) + for _, obj in ipairs(objs) do + m.searchRefs(newStatus, obj, mode) + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + status.cache.crossCallCount = status.cache.crossCallCount - 1 + for _, obj in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end +end + +local function searchRawset(ref, results) + if m.getKeyName(ref) ~= 's|rawset' then + return + end + local call = ref.parent + if call.type ~= 'call' or call.node ~= ref then + return + end + if not call.args then + return + end + local arg1 = call.args[1] + if arg1.special ~= '_G' then + -- 不会吧不会吧,不会真的有人写成 `rawset(_G._G._G, 'xxx', value)` 吧 + return + end + results[#results+1] = call +end + +local function searchG(ref, results) + while ref and m.getKeyName(ref) == 's|_G' do + results[#results+1] = ref + ref = ref.next + end + if ref then + results[#results+1] = ref + searchRawset(ref, results) + end +end + +local function searchEnvRef(ref, results) + if ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + searchG(ref, results) + elseif ref.type == 'getlocal' then + results[#results+1] = ref.next + searchG(ref.next, results) + end +end + +function m.findGlobals(ast) + local root = m.getRoot(ast) + local results = {} + local env = m.getENV(root) + if env.ref then + for _, ref in ipairs(env.ref) do + searchEnvRef(ref, results) + end + end + return results +end + +function m.findGlobalsOfName(ast, name) + local root = m.getRoot(ast) + local results = {} + local globals = m.findGlobals(root) + for _, global in ipairs(globals) do + if m.getKeyName(global) == name then + results[#results+1] = global + end + end + return results +end + +function m.checkSameSimpleInGlobal(status, name, source, start, queue) + if not name then + return + end + local objs + if status.interface.global then + objs = status.interface.global(name) + else + objs = m.findGlobalsOfName(source, name) + end + if objs then + for _, obj in ipairs(objs) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + end +end + +function m.checkValueMark(status, a, b) + if not status.cache.valueMark then + status.cache.valueMark = {} + end + if status.cache.valueMark[a] + or status.cache.valueMark[b] then + return true + end + status.cache.valueMark[a] = true + status.cache.valueMark[b] = true + return false +end + +function m.checkCallMark(status, a, mark) + if not status.cache.callMark then + status.cache.callMark = {} + end + if mark then + status.cache.callMark[a] = mark + else + return status.cache.callMark[a] + end + return false +end + +function m.checkReturnMark(status, a, mark) + if not status.cache.returnMark then + status.cache.returnMark = {} + end + if mark then + status.cache.returnMark[a] = mark + else + return status.cache.returnMark[a] + end + return false +end + +function m.searchSameFieldsInValue(status, ref, start, queue, mode) + local value = m.getObjectValue(ref) + if not value then + return + end + if m.checkValueMark(status, ref, value) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, value, mode) + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + queue[#queue+1] = { + obj = value, + start = start, + force = true, + } + -- 检查形如 a = f() 的分支情况 + m.checkSameSimpleInCall(status, value, start, queue, mode) +end + +function m.checkSameSimpleAsTableField(status, ref, start, queue) + if not status.deep then + --return + end + local parent = ref.parent + if not parent or parent.type ~= 'tablefield' then + return + end + if m.checkValueMark(status, parent, ref) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, parent.field, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSearchLevel(status) + status.cache.back = status.cache.back or 0 + if status.cache.back >= (status.interface.searchLevel or 0) then + -- TODO 限制向前搜索的次数 + --return true + end + status.cache.back = status.cache.back + 1 + return false +end + +function m.checkSameSimpleAsReturn(status, ref, start, queue) + if not status.deep then + return + end + if not ref.parent or ref.parent.type ~= 'return' then + return + end + if ref.parent.parent.type ~= 'main' then + return + end + if m.checkSearchLevel(status) then + return + end + local newStatus = m.status(status) + m.searchRefsAsFunctionReturn(newStatus, ref, 'ref') + for _, res in ipairs(newStatus.results) do + if not m.checkCallMark(status, res) then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end +end + +function m.checkSameSimpleAsSetValue(status, ref, start, queue) + if ref.type == 'select' then + return + end + local parent = ref.parent + if not parent then + return + end + if m.getObjectValue(parent) ~= ref then + return + end + if m.checkValueMark(status, ref, parent) then + return + end + if m.checkSearchLevel(status) then + return + end + local obj + if parent.type == 'local' + or parent.type == 'setglobal' + or parent.type == 'setlocal' then + obj = parent + elseif parent.type == 'setfield' then + obj = parent.field + elseif parent.type == 'setmethod' then + obj = parent.method + end + if not obj then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, obj, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSameSimpleInString(status, ref, start, queue, mode) + -- 特殊处理 ('xxx').xxx 的形式 + if ref.type ~= 'string' then + return + end + if not status.interface.docType then + return + end + if status.cache.searchingBindedDoc then + return + end + local newStatus = m.status(status) + local docs = status.interface.docType('string*') + local mark = {} + for i = 1, #docs do + local doc = docs[i] + m.searchFields(newStatus, doc) + end + for _, res in ipairs(newStatus.results) do + if mark[res] then + goto CONTINUE + end + mark[res] = true + queue[#queue+1] = { + obj = res, + start = start + 1, + } + ::CONTINUE:: + end + return true +end + +function m.pushResult(status, mode, ref, simple) + local results = status.results + if mode == 'def' then + if ref.type == 'setglobal' + or ref.type == 'setlocal' + or ref.type == 'local' then + results[#results+1] = ref + elseif ref.type == 'setfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' then + results[#results+1] = ref + end + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + if m.getParentFunction(ref) ~= m.getParentFunction(simple.node) then + results[#results+1] = ref + end + end + elseif mode == 'ref' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' + or ref.type == 'local' + or ref.type == 'setlocal' + or ref.type == 'getlocal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + results[#results+1] = ref + end + elseif mode == 'field' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + end +end + +function m.checkSameSimpleName(ref, sm) + if sm == '*' then + return true + end + if m.getSimpleName(ref) == sm then + return true + end + if ref.type == 'doc.type' + and ref.array == true then + return true + end + return false +end + +function m.checkSameSimple(status, simple, data, mode, queue) + local ref = data.obj + local start = data.start + local force = data.force + if start > #simple then + return + end + for i = start, #simple do + local sm = simple[i] + if not force and not m.checkSameSimpleName(ref, sm) then + return + end + force = false + local cmode = mode + if i < #simple then + cmode = 'ref' + end + -- 检查 doc + local skipInfer = m.checkSameSimpleByBindDocs(status, ref, i, queue, cmode) + or m.checkSameSimpleByDoc(status, ref, i, queue, cmode) + if not skipInfer then + -- 穿透 self:func 与 mt:func + m.searchSameFieldsCrossMethod(status, ref, i, queue) + -- 穿透赋值 + m.searchSameFieldsInValue(status, ref, i, queue, cmode) + -- 检查自己是字面量表的情况 + m.checkSameSimpleInValueOfTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 第一个参数的情况 + m.checkSameSimpleInArg1OfSetMetaTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 调用的情况 + m.checkSameSimpleInValueOfCallMetaTable(status, ref, i, queue) + -- 检查自己是特殊变量的分支的情况 + m.checkSameSimpleInSpecialBranch(status, ref, i, queue) + -- 检查自己是字面量字符串的分支情况 + m.checkSameSimpleInString(status, ref, i, queue, cmode) + if cmode == 'ref' then + -- 检查形如 { a = f } 的情况 + m.checkSameSimpleAsTableField(status, ref, i, queue) + -- 检查形如 return m 的情况 + m.checkSameSimpleAsReturn(status, ref, i, queue) + -- 检查形如 a = f 的情况 + m.checkSameSimpleAsSetValue(status, ref, i, queue) + end + end + if i == #simple then + break + end + ref = m.getNextRef(ref) + if not ref then + return + end + end + m.pushResult(status, mode, ref, simple) + local value = m.getObjectValue(ref) + if value then + m.pushResult(status, mode, value, simple) + end +end + +function m.searchSameFields(status, simple, mode) + local queue = {} + if simple.mode == 'global' then + -- 全局变量开头 + m.checkSameSimpleInGlobal(status, simple[1], simple.node, 1, queue) + elseif simple.mode == 'local' then + -- 局部变量开头 + queue[1] = { + obj = simple.node, + start = 1, + } + local refs = simple.node.ref + if refs then + for i = 1, #refs do + queue[#queue+1] = { + obj = refs[i], + start = 1, + } + end + end + else + queue[1] = { + obj = simple.node, + start = 1, + } + end + local max = 0 + for i = 1, 1e6 do + local data = queue[i] + if not data then + return + end + if not status.lock[data.obj] then + status.lock[data.obj] = true + max = max + 1 + status.cache.count = status.cache.count + 1 + m.checkSameSimple(status, simple, data, mode, queue) + if max >= 10000 then + logWarn('Queue too large!') + break + end + end + end +end + +function m.getCallerInSameFile(status, func) + -- 搜索所有所在函数的调用者 + local funcRefs = m.status(status) + m.searchRefOfValue(funcRefs, func) + + local calls = {} + if #funcRefs.results == 0 then + return calls + end + for _, res in ipairs(funcRefs.results) do + local call = res.parent + if call.type == 'call' then + calls[#calls+1] = call + end + end + return calls +end + +function m.getCallerCrossFiles(status, main) + if status.interface.link then + return status.interface.link(main.uri) + end + return {} +end + +function m.searchRefsAsFunctionReturn(status, obj, mode) + if mode == 'def' then + return + end + if m.checkReturnMark(status, obj, true) then + return + end + status.results[#status.results+1] = obj + -- 搜索所在函数 + local currentFunc = m.getParentFunction(obj) + local rtn = obj.parent + if rtn.type ~= 'return' then + return + end + -- 看看他是第几个返回值 + local index + for i = 1, #rtn do + if obj == rtn[i] then + index = i + break + end + end + if not index then + return + end + local calls + if currentFunc.type == 'main' then + calls = m.getCallerCrossFiles(status, currentFunc) + else + calls = m.getCallerInSameFile(status, currentFunc) + end + -- 搜索调用者的返回值 + if #calls == 0 then + return + end + local selects = {} + for i = 1, #calls do + local parent = calls[i].parent + if parent.type == 'select' and parent.index == index then + selects[#selects+1] = parent.parent + end + local extParent = calls[i].extParent + if extParent then + for j = 1, #extParent do + local ext = extParent[j] + if ext.type == 'select' and ext.index == index then + selects[#selects+1] = ext.parent + end + end + end + end + -- 搜索调用者的引用 + for i = 1, #selects do + m.searchRefs(status, selects[i], 'ref') + end +end + +function m.searchRefsAsFunctionSet(status, obj, mode) + local parent = obj.parent + if not parent then + return + end + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' + or parent.type == 'setfield' + or parent.type == 'setmethod' + or parent.type == 'tablefield' then + m.searchRefs(status, parent, mode) + elseif parent.type == 'setindex' + or parent.type == 'tableindex' then + if parent.index == obj then + m.searchRefs(status, parent, mode) + end + end +end + +function m.searchRefsAsFunction(status, obj, mode) + if obj.type ~= 'function' + and obj.type ~= 'table' then + return + end + m.searchRefsAsFunctionSet(status, obj, mode) + -- 检查自己作为返回函数时的引用 + m.searchRefsAsFunctionReturn(status, obj, mode) +end + +function m.cleanResults(results) + local mark = {} + for i = #results, 1, -1 do + local res = results[i] + if res.tag == 'self' + or mark[res] then + results[i] = results[#results] + results[#results] = nil + else + mark[res] = true + end + end +end + +--function m.getRefCache(status, obj, mode) +-- local cache = status.interface.cache and status.interface.cache() +-- if not cache then +-- return +-- end +-- if m.isGlobal(obj) then +-- obj = m.getKeyName(obj) +-- end +-- if not cache[mode] then +-- cache[mode] = {} +-- end +-- local sourceCache = cache[mode][obj] +-- if sourceCache then +-- return sourceCache +-- end +-- sourceCache = {} +-- cache[mode][obj] = sourceCache +-- return nil, function (results) +-- for i = 1, #results do +-- sourceCache[i] = results[i] +-- end +-- end +--end + +function m.getRefCache(status, obj, mode) + local cache, globalCache + if status.depth == 1 + and status.deep then + globalCache = status.interface.cache and status.interface.cache() or {} + end + cache = status.cache.refCache or {} + status.cache.refCache = cache + if m.isGlobal(obj) then + obj = m.getKeyName(obj) + end + if not cache[mode] then + cache[mode] = {} + end + if globalCache and not globalCache[mode] then + globalCache[mode] = {} + end + local sourceCache = globalCache and globalCache[mode][obj] or cache[mode][obj] + if sourceCache then + return sourceCache + end + sourceCache = {} + cache[mode][obj] = sourceCache + if globalCache then + globalCache[mode][obj] = sourceCache + end + return nil, function (results) + for i = 1, #results do + sourceCache[i] = results[i] + end + end +end + +function m.searchRefs(status, obj, mode) + local cache, makeCache = m.getRefCache(status, obj, mode) + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + + -- 检查单步引用 + local res = m.getStepRef(status, obj, mode) + if res then + for i = 1, #res do + status.results[#status.results+1] = res[i] + end + end + -- 检查simple + if status.depth <= 100 then + local simple = m.getSimple(obj) + if simple then + m.searchSameFields(status, simple, mode) + end + else + if m.debugMode then + error('status.depth overflow') + elseif DEVELOP then + --log.warn(debug.traceback('status.depth overflow')) + logWarn('status.depth overflow') + end + end + + m.cleanResults(status.results) + + if makeCache then + makeCache(status.results) + end +end + +function m.searchRefOfValue(status, obj) + local var = obj.parent + if var.type == 'local' + or var.type == 'set' then + return m.searchRefs(status, var, 'ref') + end +end + +function m.allocInfer(o) + if type(o.type) == 'table' then + local infers = {} + for i = 1, #o.type do + infers[i] = { + type = o.type[i], + value = o.value, + source = o.source, + } + end + return infers + else + return { + [1] = o, + } + end +end + +function m.mergeTypes(types) + local results = {} + local mark = {} + local hasAny + -- 这里把 any 去掉 + for i = 1, #types do + local tp = types[i] + if tp == 'any' then + hasAny = true + end + if not mark[tp] and tp ~= 'any' then + mark[tp] = true + results[#results+1] = tp + end + end + if #results == 0 then + return 'any' + end + -- 只有显性的 nil 与 any 时,取 any + if #results == 1 then + if results[1] == 'nil' and hasAny then + return 'any' + else + return results[1] + end + end + -- 同时包含 number 与 integer 时,去掉 integer + if mark['number'] and mark['integer'] then + for i = 1, #results do + if results[i] == 'integer' then + tableRemove(results, i) + break + end + end + end + tableSort(results, function (a, b) + local sa = TypeSort[a] or 100 + local sb = TypeSort[b] or 100 + return sa < sb + end) + return tableConcat(results, '|') +end + +function m.viewInferType(infers) + if not infers then + return 'any' + end + local mark = {} + local types = {} + local hasDoc + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' then + if infer.type ~= 'any' then + hasDoc = true + break + end + end + end + if hasDoc then + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' + or src.type == 'doc.type.enum' + or src.type == 'doc.resume' then + local tp = infer.type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + end + end + else + for i = 1, #infers do + local tp = infers[i].type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + end + end + return m.mergeTypes(types) +end + +function m.checkTrue(status, source) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + -- 当前认为的结果 + local current + for _, infer in ipairs(newStatus.results) do + -- 新的结果 + local new + if infer.type == 'nil' then + new = false + elseif infer.type == 'boolean' then + if infer.value == true then + new = true + elseif infer.value == false then + new = false + end + end + if new ~= nil then + if current == nil then + current = new + else + -- 如果2个结果完全相反,则返回 nil 表示不确定 + if new ~= current then + return nil + end + end + end + end + return current +end + +--- 获取特定类型的字面量值 +function m.getInferLiteral(status, source, type) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + for _, infer in ipairs(newStatus.results) do + if infer.value ~= nil then + if type == nil or infer.type == type then + return infer.value + end + end + end + return nil +end + +--- 是否包含某种类型 +function m.hasType(status, source, type) + m.searchInfer(status, source) + for _, infer in ipairs(status.results) do + if infer.type == type then + return true + end + end + return false +end + +function m.isSameValue(status, a, b) + local statusA = m.status(status) + m.searchInfer(statusA, a) + local statusB = m.status(status) + m.searchInfer(statusB, b) + local infers = {} + for _, infer in ipairs(statusA.results) do + local literal = infer.value + if literal then + infers[literal] = false + end + end + for _, infer in ipairs(statusB.results) do + local literal = infer.value + if literal then + if infers[literal] == nil then + return false + end + infers[literal] = true + end + end + for k, v in pairs(infers) do + if v == false then + return false + end + end + return true +end + +function m.inferCheckLiteralTableWithDocVararg(status, source) + if #source ~= 1 then + return + end + local vararg = source[1] + if vararg.type ~= 'varargs' then + return + end + local results = m.getVarargDocType(status, source) + status.results[#status.results+1] = { + type = m.viewInferType(results) .. '[]', + source = source, + } + return true +end + +function m.inferCheckLiteral(status, source) + if source.type == 'string' then + status.results = m.allocInfer { + type = 'string', + value = source[1], + source = source, + } + return true + elseif source.type == 'nil' then + status.results = m.allocInfer { + type = 'nil', + value = NIL, + source = source, + } + return true + elseif source.type == 'boolean' then + status.results = m.allocInfer { + type = 'boolean', + value = source[1], + source = source, + } + return true + elseif source.type == 'number' then + if mathType(source[1]) == 'integer' then + status.results = m.allocInfer { + type = 'integer', + value = source[1], + source = source, + } + return true + else + status.results = m.allocInfer { + type = 'number', + value = source[1], + source = source, + } + return true + end + elseif source.type == 'integer' then + status.results = m.allocInfer { + type = 'integer', + source = source, + } + return true + elseif source.type == 'table' then + if m.inferCheckLiteralTableWithDocVararg(status, source) then + return true + end + status.results = m.allocInfer { + type = 'table', + source = source, + } + return true + elseif source.type == 'function' then + status.results = m.allocInfer { + type = 'function', + source = source, + } + return true + elseif source.type == '...' then + status.results = m.allocInfer { + type = '...', + source = source, + } + return true + end +end + +local function getDocAliasExtends(status, name) + if not status.interface.docType then + return nil + end + for _, doc in ipairs(status.interface.docType(name)) do + if doc.type == 'doc.alias.name' then + return m.viewInferType(m.getDocTypeNames(status, doc.parent.extends)) + end + end + return nil +end + +local function getDocTypeUnitName(status, unit, genericCallback) + local typeName + if unit.type == 'doc.type.name' then + typeName = getDocAliasExtends(status, unit[1]) or unit[1] + elseif unit.type == 'doc.type.function' then + typeName = 'function' + elseif unit.type == 'doc.type.array' then + typeName = getDocTypeUnitName(status, unit.node, genericCallback) .. '[]' + elseif unit.type == 'doc.type.generic' then + typeName = ('%s<%s, %s>'):format( + getDocTypeUnitName(status, unit.node, genericCallback), + m.viewInferType(m.getDocTypeNames(status, unit.key, genericCallback)), + m.viewInferType(m.getDocTypeNames(status, unit.value, genericCallback)) + ) + end + if unit.typeGeneric then + if genericCallback then + typeName = genericCallback(typeName, unit) + or ('<%s>'):format(typeName) + else + typeName = ('<%s>'):format(typeName) + end + end + return typeName +end + +function m.getDocTypeNames(status, doc, genericCallback) + local results = {} + if not doc then + return results + end + for _, unit in ipairs(doc.types) do + local typeName = getDocTypeUnitName(status, unit, genericCallback) + results[#results+1] = { + type = typeName, + source = unit, + } + end + for _, enum in ipairs(doc.enums) do + results[#results+1] = { + type = enum[1], + source = enum, + } + end + for _, resume in ipairs(doc.resumes) do + if not resume.additional then + results[#results+1] = { + type = resume[1], + source = resume, + } + end + end + return results +end + +function m.inferCheckDoc(status, source) + if source.type == 'doc.class.name' then + status.results[#status.results+1] = { + type = source[1], + source = source, + } + return true + end + if source.type == 'doc.class' then + status.results[#status.results+1] = { + type = source.class[1], + source = source, + } + return true + end + if source.type == 'doc.type' then + local results = m.getDocTypeNames(status, source) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + if source.type == 'doc.field' then + local results = m.getDocTypeNames(status, source.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end +end + +function m.getVarargDocType(status, source) + local func = m.getParentFunction(source) + if not func then + return + end + if not func.args then + return + end + for _, arg in ipairs(func.args) do + if arg.type == '...' then + if arg.bindDocs then + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.vararg' then + return m.getDocTypeNames(status, doc.vararg) + end + end + end + end + end +end + +function m.inferCheckUpDocOfVararg(status, source) + if not source.vararg then + return + end + local results = m.getVarargDocType(status, source) + if not results then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +function m.inferCheckUpDoc(status, source) + if m.inferCheckUpDocOfVararg(status, source) then + return true + end + local parent = source.parent + if parent then + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' then + source = parent + end + if parent.type == 'setfield' + or parent.type == 'tablefield' then + if parent.field == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setmethod' then + if parent.method == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setindex' + or parent.type == 'tableindex' then + if parent.index == source + or parent.value == source then + source = parent + end + end + end + local binds = source.bindDocs + if not binds then + return + end + status.results = {} + for _, doc in ipairs(binds) do + if doc.type == 'doc.class' then + status.results[#status.results+1] = { + type = doc.class[1], + source = doc, + } + -- ---@class Class + -- local x = { field = 1 } + -- 这种情况下,将字面量表接受为Class的定义 + if source.value and source.value.type == 'table' then + status.results[#status.results+1] = { + type = source.value.type, + source = source.value, + } + end + return true + elseif doc.type == 'doc.type' then + local results = m.getDocTypeNames(status, doc) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + elseif doc.type == 'doc.param' then + -- function (x) 的情况 + if source.type == 'local' + and m.getName(source) == doc.param[1] then + if source.parent.type == 'funcargs' + or source.parent.type == 'in' + or source.parent.type == 'loop' then + local results = m.getDocTypeNames(status, doc.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + end + end + end +end + +function m.inferCheckFieldDoc(status, source) + -- 检查 string[] 的情况 + if source.type == 'getindex' then + local node = source.node + if not node then + return + end + local newStatus = m.status(status) + m.searchInfer(newStatus, node) + local ok + for _, infer in ipairs(newStatus.results) do + local src = infer.source + if src.type == 'doc.type.array' then + ok = true + status.results[#status.results+1] = { + type = infer.type:gsub('%[%]$', ''), + source = src.node, + } + end + end + return ok + end +end + +function m.inferCheckUnary(status, source) + if source.type ~= 'unary' then + return + end + local op = source.op + if op.type == 'not' then + local checkTrue = m.checkTrue(status, source[1]) + local value = nil + if checkTrue == true then + value = false + elseif checkTrue == false then + value = true + end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '#' then + status.results = m.allocInfer { + type = 'integer', + source = source, + } + return true + elseif op.type == '~' then + local l = m.getInferLiteral(status, source[1], 'integer') + status.results = m.allocInfer { + type = 'integer', + value = l and ~l or nil, + source = source, + } + return true + elseif op.type == '-' then + local v = m.getInferLiteral(status, source[1], 'integer') + if v then + status.results = m.allocInfer { + type = 'integer', + value = - v, + source = source, + } + return true + end + v = m.getInferLiteral(status, source[1], 'number') + status.results = m.allocInfer { + type = 'number', + value = v and -v or nil, + source = source, + } + return true + end +end + +local function mathCheck(status, a, b) + local v1 = m.getInferLiteral(status, a, 'integer') + or m.getInferLiteral(status, a, 'number') + local v2 = m.getInferLiteral(status, b, 'integer') + or m.getInferLiteral(status, a, 'number') + local int = m.hasType(status, a, 'integer') + and m.hasType(status, b, 'integer') + and not m.hasType(status, a, 'number') + and not m.hasType(status, b, 'number') + return int and 'integer' or 'number', v1, v2 +end + +function m.inferCheckBinary(status, source) + if source.type ~= 'binary' then + return + end + local op = source.op + if op.type == 'and' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[2]) + return true + elseif isTrue == false then + m.searchInfer(status, source[1]) + return true + else + m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == 'or' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[1]) + return true + elseif isTrue == false then + m.searchInfer(status, source[2]) + return true + else + m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == '==' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = true + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '~=' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = not value, + source = source, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = false + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '<=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 <= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '>=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 >= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 < v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '|' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 | v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '~' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 ~ v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '&' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 & v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '<<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 << v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '>>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 >> v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '..' then + local v1 = m.getInferLiteral(status, source[1], 'string') + local v2 = m.getInferLiteral(status, source[2], 'string') + local v + if v1 and v2 then + v = v1 .. v2 + end + status.results = m.allocInfer { + type = 'string', + value = v, + source = source, + } + return true + elseif op.type == '^' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 ^ v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + } + return true + elseif op.type == '/' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + } + return true + -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数 + elseif op.type == '+' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 + v2) or nil, + source = source, + } + return true + elseif op.type == '-' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 - v2) or nil, + source = source, + } + return true + elseif op.type == '*' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 * v2) or nil, + source = source, + } + return true + elseif op.type == '%' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 % v2) or nil, + source = source, + } + return true + elseif op.type == '//' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 // v2) or nil, + source = source, + } + return true + end +end + +function m.inferByDef(status, obj) + if not status.cache.inferedDef then + status.cache.inferedDef = {} + end + if status.cache.inferedDef[obj] then + return + end + status.cache.inferedDef[obj] = true + local mark = {} + local newStatus = m.status(status, status.interface) + m.searchRefs(newStatus, obj, 'def') + for _, src in ipairs(newStatus.results) do + local inferStatus = m.status(newStatus) + m.searchInfer(inferStatus, src) + for _, infer in ipairs(inferStatus.results) do + if not mark[infer.source] then + mark[infer.source] = true + status.results[#status.results+1] = infer + end + end + end +end + +local function inferBySetOfLocal(status, source) + if status.cache[source] then + return + end + status.cache[source] = true + local newStatus = m.status(status) + if source.value then + m.searchInfer(newStatus, source.value) + end + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + break + end + m.searchInfer(newStatus, ref) + end + for _, infer in ipairs(newStatus.results) do + status.results[#status.results+1] = infer + end + end +end + +function m.inferBySet(status, source) + if #status.results ~= 0 then + return + end + if source.type == 'local' then + inferBySetOfLocal(status, source) + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + inferBySetOfLocal(status, source.node) + end +end + +function m.inferByCall(status, source) + if #status.results ~= 0 then + return + end + if not source.parent then + return + end + if source.parent.type ~= 'call' then + return + end + if source.parent.node == source then + status.results[#status.results+1] = { + type = 'function', + source = source, + } + return + end +end + +function m.inferByGetTable(status, source) + if #status.results ~= 0 then + return + end + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local next = source.next + if not next then + return + end + if next.type == 'getfield' + or next.type == 'getindex' + or next.type == 'getmethod' + or next.type == 'setfield' + or next.type == 'setindex' + or next.type == 'setmethod' then + status.results[#status.results+1] = { + type = 'table', + source = source, + } + end +end + +function m.inferByUnary(status, source) + if #status.results ~= 0 then + return + end + local parent = source.parent + if not parent or parent.type ~= 'unary' then + return + end + local op = parent.op + if op.type == '#' then + status.results[#status.results+1] = { + type = 'string', + source = source + } + status.results[#status.results+1] = { + type = 'table', + source = source + } + elseif op.type == '~' then + status.results[#status.results+1] = { + type = 'integer', + source = source + } + elseif op.type == '-' then + status.results[#status.results+1] = { + type = 'number', + source = source + } + end +end + +function m.inferByBinary(status, source) + if #status.results ~= 0 then + return + end + local parent = source.parent + if not parent or parent.type ~= 'binary' then + return + end + local op = parent.op + if op.type == '<=' + or op.type == '>=' + or op.type == '<' + or op.type == '>' + or op.type == '^' + or op.type == '/' + or op.type == '+' + or op.type == '-' + or op.type == '*' + or op.type == '%' then + status.results[#status.results+1] = { + type = 'number', + source = source, + } + elseif op.type == '|' + or op.type == '~' + or op.type == '&' + or op.type == '<<' + or op.type == '>>' + -- 整数的可能性比较高 + or op.type == '//' then + status.results[#status.results+1] = { + type = 'integer', + source = source, + } + elseif op.type == '..' then + status.results[#status.results+1] = { + type = 'string', + source = source, + } + end +end + +local function mergeFunctionReturnsByDoc(status, source, index, call) + if not source or source.type ~= 'function' then + return + end + if not source.bindDocs then + return + end + local returns = {} + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returns[#returns+1] = rtn + end + end + end + local rtn = returns[index] + if not rtn then + return + end + local results = m.getDocTypeNames(status, rtn, function (typeName, typeUnit) + if not source.args or not call.args then + return + end + local name = typeUnit[1] + local generics = typeUnit.typeGeneric[name] + if not generics then + return + end + local first = generics[1] + if not first or first == typeUnit then + return + end + local docParam = m.getParentType(first, 'doc.param') + local paramName = docParam.param[1] + for i, arg in ipairs(source.args) do + if arg[1] == paramName then + local callArg = call.args[i] + if not callArg then + return + end + return m.viewInferType(m.searchInfer(status, callArg)) + end + end + end) + if #results == 0 then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +local function mergeDocTypeFunctionReturns(status, source, index) + if not source.bindDocs then + return + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' then + for _, typeUnit in ipairs(doc.types) do + if typeUnit.type == 'doc.type.function' then + local rtn = typeUnit.returns[index] + if rtn then + local results = m.getDocTypeNames(status, rtn) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + end + end + end + end + end +end + +local function mergeFunctionReturns(status, source, index, call) + local returns = source.returns + if not returns then + return + end + for i = 1, #returns do + local rtn = returns[i] + if rtn[index] then + if rtn[index].type == 'call' then + if not m.checkReturnMark(status, rtn[index]) then + m.checkReturnMark(status, rtn[index], true) + m.inferByCallReturnAndIndex(status, rtn[index], index) + end + else + local newStatus = m.status(status) + m.searchInfer(newStatus, rtn[index]) + if #newStatus.results == 0 then + status.results[#status.results+1] = { + type = 'any', + source = rtn[index], + } + else + for _, infer in ipairs(newStatus.results) do + status.results[#status.results+1] = infer + end + end + end + end + end +end + +function m.inferByCallReturnAndIndex(status, call, index) + local node = call.node + local newStatus = m.status(nil, status.interface) + m.searchRefs(newStatus, node, 'def') + local hasDocReturn + for _, src in ipairs(newStatus.results) do + if mergeDocTypeFunctionReturns(status, src, index) then + hasDocReturn = true + elseif mergeFunctionReturnsByDoc(status, src.value, index, call) then + hasDocReturn = true + end + end + if not hasDocReturn then + for _, src in ipairs(newStatus.results) do + if src.value and src.value.type == 'function' then + if not m.checkReturnMark(status, src.value, true) then + mergeFunctionReturns(status, src.value, index, call) + end + end + end + end +end + +function m.inferByCallReturn(status, source) + if source.type == 'call' then + m.inferByCallReturnAndIndex(status, source, 1) + return + end + if source.type ~= 'select' then + if source.value and source.value.type == 'select' then + source = source.value + else + return + end + end + if not source.vararg or source.vararg.type ~= 'call' then + return + end + m.inferByCallReturnAndIndex(status, source.vararg, source.index) +end + +function m.inferByPCallReturn(status, source) + if source.type ~= 'select' then + if source.value and source.value.type == 'select' then + source = source.value + else + return + end + end + local call = source.vararg + if not call or call.type ~= 'call' then + return + end + local node = call.node + local specialName = node.special + local func, index + if specialName == 'pcall' then + func = call.args[1] + index = source.index - 1 + elseif specialName == 'xpcall' then + func = call.args[1] + index = source.index - 2 + else + return + end + local newStatus = m.status(nil, status.interface) + m.searchRefs(newStatus, func, 'def') + for _, src in ipairs(newStatus.results) do + if src.value and src.value.type == 'function' then + mergeFunctionReturns(status, src.value, index) + end + end +end + +function m.cleanInfers(infers) + local mark = {} + for i = #infers, 1, -1 do + local infer = infers[i] + local key = ('%s|%p'):format(infer.type, infer.source) + if mark[key] then + infers[i] = infers[#infers] + infers[#infers] = nil + else + mark[key] = true + end + end +end + +function m.searchInfer(status, obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return + end + end + while true do + local value = m.getObjectValue(obj) + if not value or value == obj then + break + end + obj = value + end + + local cache, makeCache = m.getRefCache(status, obj, 'infer') + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + + if DEVELOP then + status.cache.clock = status.cache.clock or osClock() + end + + if not status.cache.lockInfer then + status.cache.lockInfer = {} + end + if status.cache.lockInfer[obj] then + return + end + status.cache.lockInfer[obj] = true + + local checked = m.inferCheckDoc(status, obj) + or m.inferCheckUpDoc(status, obj) + or m.inferCheckFieldDoc(status, obj) + or m.inferCheckLiteral(status, obj) + or m.inferCheckUnary(status, obj) + or m.inferCheckBinary(status, obj) + if checked then + m.cleanInfers(status.results) + if makeCache then + makeCache(status.results) + end + return + end + + if status.deep then + m.inferByDef(status, obj) + end + m.inferBySet(status, obj) + m.inferByCall(status, obj) + m.inferByGetTable(status, obj) + m.inferByUnary(status, obj) + m.inferByBinary(status, obj) + m.inferByCallReturn(status, obj) + m.inferByPCallReturn(status, obj) + m.cleanInfers(status.results) + if makeCache then + makeCache(status.results) + end +end + +--- 请求对象的引用,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 不穿透 `setmetatable` ,考虑由 +--- 业务层进行反向 def 搜索。 +function m.requestReference(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'ref') + + m.searchRefsAsFunction(status, obj, 'ref') + + if m.debugMode then + print('count:', status.cache.count) + end + + return status.results, status.cache.count +end + +--- 请求对象的定义,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 穿透 `setmetatable` 。 +function m.requestDefinition(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + -- 根据 field 搜索定义 + m.searchRefs(status, obj, 'def') + + return status.results, status.cache.count +end + +--- 请求对象的域 +function m.requestFields(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + + m.searchFields(status, obj) + + return status.results, status.cache.count +end + +--- 请求对象的类型推测 +function m.requestInfer(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + m.searchInfer(status, obj) + + return status.results, status.cache.count +end + +return m |