summaryrefslogtreecommitdiff
path: root/script-beta/vm/getInfer.lua
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2020-08-16 16:20:12 +0800
committer最萌小汐 <sumneko@hotmail.com>2020-08-16 16:20:12 +0800
commit68f6df0181c8c5e499b571c0447af9203d3fd6d9 (patch)
tree6d44c1dfd98a57e01ca1482deb4cc8206ddb9c33 /script-beta/vm/getInfer.lua
parent8fdeec14b4a90e1f15ed639ebc7b7b423e100424 (diff)
downloadlua-language-server-68f6df0181c8c5e499b571c0447af9203d3fd6d9.zip
整理代码
Diffstat (limited to 'script-beta/vm/getInfer.lua')
-rw-r--r--script-beta/vm/getInfer.lua949
1 files changed, 949 insertions, 0 deletions
diff --git a/script-beta/vm/getInfer.lua b/script-beta/vm/getInfer.lua
new file mode 100644
index 00000000..e98ac1a0
--- /dev/null
+++ b/script-beta/vm/getInfer.lua
@@ -0,0 +1,949 @@
+local vm = require 'vm.vm'
+local util = require 'utility'
+local guide = require 'parser.guide'
+local library = require 'library'
+local select = select
+
+local typeSort = {
+ ['boolean'] = 1,
+ ['string'] = 2,
+ ['integer'] = 3,
+ ['number'] = 4,
+ ['table'] = 5,
+ ['function'] = 6,
+ ['nil'] = math.maxinteger,
+}
+
+NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end })
+
+local function merge(t, b)
+ if not t then
+ t = {}
+ end
+ if not b then
+ return t
+ end
+ for i = 1, #b do
+ local o = b[i]
+ if not t[o] then
+ t[o] = true
+ t[#t+1] = o
+ end
+ end
+ return t
+end
+
+local function alloc(o)
+ -- TODO
+ assert(o.type)
+ if type(o.type) == 'table' then
+ local values = {}
+ for i = 1, #o.type do
+ local sub = {
+ type = o.type[i],
+ value = o.value,
+ source = o.source,
+ }
+ values[i] = sub
+ values[sub] = true
+ end
+ return values
+ else
+ return {
+ [1] = o,
+ [o] = true,
+ }
+ end
+end
+
+local function insert(t, o)
+ if not o then
+ return
+ end
+ if not t[o] then
+ t[o] = true
+ t[#t+1] = o
+ end
+ return t
+end
+
+local function checkLiteral(source)
+ if source.type == 'string' then
+ return alloc {
+ type = 'string',
+ value = source[1],
+ source = source,
+ }
+ elseif source.type == 'nil' then
+ return alloc {
+ type = 'nil',
+ value = NIL,
+ source = source,
+ }
+ elseif source.type == 'boolean' then
+ return alloc {
+ type = 'boolean',
+ value = source[1],
+ source = source,
+ }
+ elseif source.type == 'number' then
+ if math.type(source[1]) == 'integer' then
+ return alloc {
+ type = 'integer',
+ value = source[1],
+ source = source,
+ }
+ else
+ return alloc {
+ type = 'number',
+ value = source[1],
+ source = source,
+ }
+ end
+ elseif source.type == 'integer' then
+ return alloc {
+ type = 'integer',
+ source = source,
+ }
+ elseif source.type == 'table' then
+ return alloc {
+ type = 'table',
+ source = source,
+ }
+ elseif source.type == 'function' then
+ return alloc {
+ type = 'function',
+ source = source,
+ }
+ elseif source.type == '...' then
+ return alloc {
+ type = '...',
+ source = source,
+ }
+ end
+end
+
+local function checkUnary(source)
+ if source.type ~= 'unary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'not' then
+ local checkTrue = vm.checkTrue(source[1])
+ local value = nil
+ if checkTrue == true then
+ value = false
+ elseif checkTrue == false then
+ value = true
+ end
+ return alloc {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ elseif op.type == '#' then
+ return alloc {
+ type = 'integer',
+ source = source,
+ }
+ elseif op.type == '~' then
+ local l = vm.getLiteral(source[1], 'integer')
+ return alloc {
+ type = 'integer',
+ value = l and ~l or nil,
+ source = source,
+ }
+ elseif op.type == '-' then
+ local v = vm.getLiteral(source[1], 'integer')
+ if v then
+ return alloc {
+ type = 'integer',
+ value = - v,
+ source = source,
+ }
+ end
+ v = vm.getLiteral(source[1], 'number')
+ return alloc {
+ type = 'number',
+ value = v and -v or nil,
+ source = source,
+ }
+ end
+end
+
+local function mathCheck(a, b)
+ local v1 = vm.getLiteral(a, 'integer') or vm.getLiteral(a, 'number')
+ local v2 = vm.getLiteral(b, 'integer') or vm.getLiteral(a, 'number')
+ local int = vm.hasType(a, 'integer')
+ and vm.hasType(b, 'integer')
+ and not vm.hasType(a, 'number')
+ and not vm.hasType(b, 'number')
+ return int and 'integer' or 'number', v1, v2
+end
+
+local function checkBinary(source)
+ if source.type ~= 'binary' then
+ return
+ end
+ local op = source.op
+ if op.type == 'and' then
+ local isTrue = vm.checkTrue(source[1])
+ if isTrue == true then
+ return vm.getInfers(source[2])
+ elseif isTrue == false then
+ return vm.getInfers(source[1])
+ else
+ return merge(
+ vm.getInfers(source[1]),
+ vm.getInfers(source[2])
+ )
+ end
+ elseif op.type == 'or' then
+ local isTrue = vm.checkTrue(source[1])
+ if isTrue == true then
+ return vm.getInfers(source[1])
+ elseif isTrue == false then
+ return vm.getInfers(source[2])
+ else
+ return merge(
+ vm.getInfers(source[1]),
+ vm.getInfers(source[2])
+ )
+ end
+ elseif op.type == '==' then
+ local value = vm.isSameValue(source[1], source[2])
+ if value ~= nil then
+ return alloc {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ end
+ local isSame = vm.isSameRef(source[1], source[2])
+ if isSame == true then
+ value = true
+ else
+ value = nil
+ end
+ return alloc {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ elseif op.type == '~=' then
+ local value = vm.isSameValue(source[1], source[2])
+ if value ~= nil then
+ return alloc {
+ type = 'boolean',
+ value = not value,
+ source = source,
+ }
+ end
+ local isSame = vm.isSameRef(source[1], source[2])
+ if isSame == true then
+ value = false
+ else
+ value = nil
+ end
+ return alloc {
+ type = 'boolean',
+ value = value,
+ source = source,
+ }
+ elseif op.type == '<=' then
+ local v1 = vm.getLiteral(source[1], 'integer') or vm.getLiteral(source[1], 'number')
+ local v2 = vm.getLiteral(source[2], 'integer') or vm.getLiteral(source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 <= v2
+ end
+ return alloc {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '>=' then
+ local v1 = vm.getLiteral(source[1], 'integer') or vm.getLiteral(source[1], 'number')
+ local v2 = vm.getLiteral(source[2], 'integer') or vm.getLiteral(source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 >= v2
+ end
+ return alloc {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '<' then
+ local v1 = vm.getLiteral(source[1], 'integer') or vm.getLiteral(source[1], 'number')
+ local v2 = vm.getLiteral(source[2], 'integer') or vm.getLiteral(source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 < v2
+ end
+ return alloc {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '>' then
+ local v1 = vm.getLiteral(source[1], 'integer') or vm.getLiteral(source[1], 'number')
+ local v2 = vm.getLiteral(source[2], 'integer') or vm.getLiteral(source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 > v2
+ end
+ return alloc {
+ type = 'boolean',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '|' then
+ local v1 = vm.getLiteral(source[1], 'integer')
+ local v2 = vm.getLiteral(source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 | v2
+ end
+ return alloc {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '~' then
+ local v1 = vm.getLiteral(source[1], 'integer')
+ local v2 = vm.getLiteral(source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 ~ v2
+ end
+ return alloc {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '&' then
+ local v1 = vm.getLiteral(source[1], 'integer')
+ local v2 = vm.getLiteral(source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 & v2
+ end
+ return alloc {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '<<' then
+ local v1 = vm.getLiteral(source[1], 'integer')
+ local v2 = vm.getLiteral(source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 << v2
+ end
+ return alloc {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '>>' then
+ local v1 = vm.getLiteral(source[1], 'integer')
+ local v2 = vm.getLiteral(source[2], 'integer')
+ local v
+ if v1 and v2 then
+ v = v1 >> v2
+ end
+ return alloc {
+ type = 'integer',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '..' then
+ local v1 = vm.getLiteral(source[1], 'string')
+ local v2 = vm.getLiteral(source[2], 'string')
+ local v
+ if v1 and v2 then
+ v = v1 .. v2
+ end
+ return alloc {
+ type = 'string',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '^' then
+ local v1 = vm.getLiteral(source[1], 'integer') or vm.getLiteral(source[1], 'number')
+ local v2 = vm.getLiteral(source[2], 'integer') or vm.getLiteral(source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 ^ v2
+ end
+ return alloc {
+ type = 'number',
+ value = v,
+ source = source,
+ }
+ elseif op.type == '/' then
+ local v1 = vm.getLiteral(source[1], 'integer') or vm.getLiteral(source[1], 'number')
+ local v2 = vm.getLiteral(source[2], 'integer') or vm.getLiteral(source[2], 'number')
+ local v
+ if v1 and v2 then
+ v = v1 > v2
+ end
+ return alloc {
+ type = 'number',
+ value = v,
+ source = source,
+ }
+ -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数
+ elseif op.type == '+' then
+ local int, v1, v2 = mathCheck(source[1], source[2])
+ return alloc {
+ type = int,
+ value = (v1 and v2) and (v1 + v2) or nil,
+ source = source,
+ }
+ elseif op.type == '-' then
+ local int, v1, v2 = mathCheck(source[1], source[2])
+ return alloc {
+ type = int,
+ value = (v1 and v2) and (v1 - v2) or nil,
+ source = source,
+ }
+ elseif op.type == '*' then
+ local int, v1, v2 = mathCheck(source[1], source[2])
+ return alloc {
+ type = int,
+ value = (v1 and v2) and (v1 * v2) or nil,
+ source = source,
+ }
+ elseif op.type == '%' then
+ local int, v1, v2 = mathCheck(source[1], source[2])
+ return alloc {
+ type = int,
+ value = (v1 and v2) and (v1 % v2) or nil,
+ source = source,
+ }
+ elseif op.type == '//' then
+ local int, v1, v2 = mathCheck(source[1], source[2])
+ return alloc {
+ type = int,
+ value = (v1 and v2) and (v1 // v2) or nil,
+ source = source,
+ }
+ end
+end
+
+local function inferByCall(results, source)
+ if #results ~= 0 then
+ return
+ end
+ if not source.parent then
+ return
+ end
+ if source.parent.type ~= 'call' then
+ return
+ end
+ if source.parent.node == source then
+ insert(results, {
+ type = 'function',
+ source = source,
+ })
+ return
+ end
+end
+
+local function inferByGetTable(results, source)
+ if #results ~= 0 then
+ return
+ end
+ local next = source.next
+ if not next then
+ return
+ end
+ if next.type == 'getfield'
+ or next.type == 'getindex'
+ or next.type == 'getmethod'
+ or next.type == 'setfield'
+ or next.type == 'setindex'
+ or next.type == 'setmethod' then
+ insert(results, {
+ type = 'table',
+ source = source,
+ })
+ end
+end
+
+local function inferByDef(results, source)
+ local defs = vm.getDefs(source)
+ for _, src in ipairs(defs) do
+ local tp = vm.inferValue(src, false)
+ if tp then
+ merge(results, tp)
+ end
+ end
+end
+
+local function checkLibraryTypes(source)
+ if type(source.type) ~= 'table' then
+ return nil
+ end
+ local results = {}
+ for i = 1, #source.type do
+ insert(results, {
+ type = source.type[i],
+ source = source,
+ })
+ end
+ return results
+end
+
+local function checkLibrary(source)
+ local lib = vm.getLibrary(source)
+ if not lib then
+ return nil
+ end
+ return alloc {
+ type = lib.type,
+ value = lib.value,
+ source = lib,
+ }
+end
+
+local function checkSpecialReturn(source)
+ if source.type ~= 'select' then
+ return nil
+ end
+ local index = source.index
+ local call = source.vararg
+ if call.type ~= 'call' then
+ return nil
+ end
+ local func = call.node
+ local lib = vm.getLibrary(func)
+ if not lib then
+ return nil
+ end
+ if lib.special == 'require' then
+ local modName = call.args[1]
+ if modName and modName.type == 'string' then
+ lib = library.library[modName[1]]
+ if lib then
+ return alloc {
+ type = lib.type,
+ value = lib.value,
+ source = lib,
+ }
+ end
+ end
+ end
+ return nil
+end
+
+local function checkLibraryReturn(source)
+ if source.type ~= 'select' then
+ return nil
+ end
+ local index = source.index
+ local call = source.vararg
+ if call.type ~= 'call' then
+ return nil
+ end
+ local func = call.node
+ local lib = vm.getLibrary(func)
+ if not lib then
+ return nil
+ end
+ if lib.type ~= 'function' then
+ return nil
+ end
+ if not lib.returns then
+ return nil
+ end
+ local rtn = lib.returns[index]
+ if not rtn then
+ return nil
+ end
+ if not rtn.type then
+ return nil
+ end
+ if rtn.type == '...' or rtn.type == 'any' then
+ return
+ end
+ return alloc {
+ type = rtn.type,
+ value = rtn.value,
+ source = rtn,
+ }
+end
+
+local function inferByLibraryArg(results, source)
+ local args = source.parent
+ if not args then
+ return
+ end
+ if args.type ~= 'callargs' then
+ return
+ end
+ local call = args.parent
+ if not call then
+ return
+ end
+ local func = call.node
+ local index
+ for i = 1, #args do
+ if args[i] == source then
+ index = i
+ break
+ end
+ end
+ if not index then
+ return
+ end
+ local lib = vm.getLibrary(func)
+ local arg = lib and lib.args and lib.args[index]
+ if not arg then
+ return
+ end
+ if not arg.type then
+ return
+ end
+ if arg.type == '...' or arg.type == 'any' then
+ return
+ end
+ return insert(results, {
+ type = arg.type,
+ value = arg.value,
+ source = arg,
+ })
+end
+
+local function hasTypeInResults(results, type)
+ for i = 1, #results do
+ if results[i].type == 'type' then
+ return true
+ end
+ end
+ return false
+end
+
+local function inferByUnary(results, source)
+ if #results ~= 0 then
+ return
+ end
+ local parent = source.parent
+ if not parent or parent.type ~= 'unary' then
+ return
+ end
+ local op = parent.op
+ if op.type == '#' then
+ -- 会受顺序影响,不检查了
+ --if hasTypeInResults(results, 'string')
+ --or hasTypeInResults(results, 'integer') then
+ -- return
+ --end
+ insert(results, {
+ type = 'string',
+ source = source
+ })
+ insert(results, {
+ type = 'table',
+ source = source
+ })
+ elseif op.type == '~' then
+ insert(results, {
+ type = 'integer',
+ source = source
+ })
+ elseif op.type == '-' then
+ insert(results, {
+ type = 'number',
+ source = source
+ })
+ end
+end
+
+local function inferByBinary(results, source)
+ if #results ~= 0 then
+ return
+ end
+ local parent = source.parent
+ if not parent or parent.type ~= 'binary' then
+ return
+ end
+ local op = parent.op
+ if op.type == '<='
+ or op.type == '>='
+ or op.type == '<'
+ or op.type == '>'
+ or op.type == '^'
+ or op.type == '/'
+ or op.type == '+'
+ or op.type == '-'
+ or op.type == '*'
+ or op.type == '%' then
+ insert(results, {
+ type = 'number',
+ source = source
+ })
+ elseif op.type == '|'
+ or op.type == '~'
+ or op.type == '&'
+ or op.type == '<<'
+ or op.type == '>>'
+ -- 整数的可能性比较高
+ or op.type == '//' then
+ insert(results, {
+ type = 'integer',
+ source = source
+ })
+ elseif op.type == '..' then
+ insert(results, {
+ type = 'string',
+ source = source
+ })
+ end
+end
+
+local function inferBySetOfLocal(results, source)
+ if source.ref then
+ for i = 1, math.min(#source.ref, 100) do
+ local ref = source.ref[i]
+ if ref.type == 'setlocal' then
+ break
+ end
+ merge(results, vm.getInfers(ref))
+ end
+ end
+end
+
+local function inferBySet(results, source)
+ if #results ~= 0 then
+ return
+ end
+ if source.type == 'local' then
+ inferBySetOfLocal(results, source)
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ merge(results, vm.getInfers(source.node))
+ end
+end
+
+local function mergeFunctionReturns(results, source, index)
+ local returns = source.returns
+ if not returns then
+ return
+ end
+ for i = 1, #returns do
+ local rtn = returns[i]
+ if rtn[index] then
+ merge(results, vm.getInfers(rtn[index]))
+ end
+ end
+end
+
+local function inferByCallReturn(results, source)
+ if source.type ~= 'select' then
+ return
+ end
+ if not source.vararg or source.vararg.type ~= 'call' then
+ return
+ end
+ local node = source.vararg.node
+ local nodeValues = vm.getInfers(node)
+ if not nodeValues then
+ return
+ end
+ local index = source.index
+ for i = 1, #nodeValues do
+ local value = nodeValues[i]
+ local src = value.source
+ if src.type == 'function' then
+ mergeFunctionReturns(results, src, index)
+ end
+ end
+end
+
+local function inferByPCallReturn(results, source)
+ if source.type ~= 'select' then
+ return
+ end
+ local call = source.vararg
+ if not call or call.type ~= 'call' then
+ return
+ end
+ local node = call.node
+ local lib = vm.getLibrary(node)
+ if not lib then
+ return
+ end
+ local func, index
+ if lib.name == 'pcall' then
+ func = call.args[1]
+ index = source.index - 1
+ elseif lib.name == 'xpcall' then
+ func = call.args[1]
+ index = source.index - 2
+ else
+ return
+ end
+ local funcValues = vm.getInfers(func)
+ if not funcValues then
+ return
+ end
+ for i = 1, #funcValues do
+ local value = funcValues[i]
+ local src = value.source
+ if src.type == 'function' then
+ mergeFunctionReturns(results, src, index)
+ end
+ end
+end
+
+function vm.inferValue(source, infer)
+ source = guide.getObjectValue(source) or source
+ local results = checkLiteral(source)
+ or checkUnary(source)
+ or checkBinary(source)
+ or checkLibraryTypes(source)
+ or checkLibrary(source)
+ or checkSpecialReturn(source)
+ or checkLibraryReturn(source)
+ if results then
+ return results
+ end
+ if not infer then
+ return
+ end
+
+ results = {}
+ inferByLibraryArg(results, source)
+ inferByDef(results, source)
+ inferBySet(results, source)
+ inferByCall(results, source)
+ inferByGetTable(results, source)
+ inferByUnary(results, source)
+ inferByBinary(results, source)
+ inferByCallReturn(results, source)
+ inferByPCallReturn(results, source)
+
+ if #results == 0 then
+ return nil
+ end
+
+ return results
+end
+
+function vm.checkTrue(source)
+ local values = vm.getInfers(source)
+ if not values then
+ return
+ end
+ -- 当前认为的结果
+ local current
+ for i = 1, #values do
+ -- 新的结果
+ local new
+ local v = values[i]
+ if v.type == 'nil' then
+ new = false
+ elseif v.type == 'boolean' then
+ if v.value == true then
+ new = true
+ elseif v.value == false then
+ new = false
+ end
+ end
+ if new ~= nil then
+ if current == nil then
+ current = new
+ else
+ -- 如果2个结果完全相反,则返回 nil 表示不确定
+ if new ~= current then
+ return nil
+ end
+ end
+ end
+ end
+ return current
+end
+
+--- 获取特定类型的字面量值
+function vm.getLiteral(source, type)
+ local values = vm.getInfers(source)
+ if not values then
+ return nil
+ end
+ for i = 1, #values do
+ local v = values[i]
+ if v.value ~= nil then
+ if type == nil or v.type == type then
+ return v.value
+ end
+ end
+ end
+ return nil
+end
+
+function vm.isSameValue(a, b)
+ local valuesA = vm.getInfers(a)
+ local valuesB = vm.getInfers(b)
+ if not valuesA or not valuesB then
+ return false
+ end
+ if valuesA == valuesB then
+ return true
+ end
+ local values = {}
+ for i = 1, #valuesA do
+ local value = valuesA[i]
+ local literal = value.value
+ if literal then
+ values[literal] = false
+ end
+ end
+ for i = 1, #valuesB do
+ local value = valuesA[i]
+ local literal = value.value
+ if literal then
+ if values[literal] == nil then
+ return false
+ end
+ values[literal] = true
+ end
+ end
+ for k, v in pairs(values) do
+ if v == false then
+ return false
+ end
+ end
+ return true
+end
+
+--- 是否包含某种类型
+function vm.hasType(source, type)
+ local infers = vm.getInfers(source)
+ if not infers then
+ return false
+ end
+ for i = 1, #infers do
+ local infer = infers[i]
+ if infer.type == type then
+ return true
+ end
+ end
+ return false
+end
+
+function vm.getType(source)
+ local infers = vm.getInfers(source)
+ return guide.viewInfer(infers)
+end
+
+--- 获取对象的值
+--- 会尝试穿透函数调用
+function vm.getInfers(source)
+ if not source then
+ return
+ end
+ return guide.requestInfer(source, vm.interface)
+end