summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/infer.lua283
-rw-r--r--test/type_inference/init.lua8
2 files changed, 290 insertions, 1 deletions
diff --git a/script/core/infer.lua b/script/core/infer.lua
new file mode 100644
index 00000000..c70662a0
--- /dev/null
+++ b/script/core/infer.lua
@@ -0,0 +1,283 @@
+local searcher = require 'core.searcher'
+local config = require 'config'
+
+local m = {}
+
+local function mergeTable(a, b)
+ if not b then
+ return
+ end
+ for v in pairs(b) do
+ a[v] = true
+ end
+end
+
+local function searchInferOfValue(value, infers)
+ if value.type == 'string' then
+ infers['string'] = true
+ return
+ end
+ if value.type == 'boolean' then
+ infers['boolean'] = true
+ return
+ end
+ if value.type == 'table' then
+ infers['table'] = true
+ return
+ end
+ if value.type == 'number' then
+ if math.type(value[1]) == 'integer' then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+ if value.type == 'function' then
+ infers['function'] = true
+ return
+ end
+ if value.type == 'unary' then
+ local op = value.op.type
+ if op == 'not' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '#' then
+ infers['integer'] = true
+ return
+ end
+ if op == '-' then
+ if m.hasType(value[1], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+ if op == '~' then
+ infers['integer'] = true
+ return
+ end
+ return
+ end
+ if value.type == 'binary' then
+ local op = value.op.type
+ if op == 'and' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[2]))
+ else
+ mergeTable(infers, m.searchInfers(value[1]))
+ end
+ return
+ end
+ if op == 'or' then
+ if m.isTrue(value[1]) then
+ mergeTable(infers, m.searchInfers(value[1]))
+ else
+ mergeTable(infers, m.searchInfers(value[2]))
+ end
+ return
+ end
+ if op == '=='
+ or op == '~='
+ or op == '<'
+ or op == '>'
+ or op == '<='
+ or op == '>=' then
+ infers['boolean'] = true
+ return
+ end
+ if op == '<<'
+ or op == '>>'
+ or op == '~'
+ or op == '&'
+ or op == '|' then
+ infers['integer'] = true
+ return
+ end
+ if op == '..' then
+ infers['string'] = true
+ return
+ end
+ if op == '^'
+ or op == '/' then
+ infers['number'] = true
+ return
+ end
+ if op == '+'
+ or op == '-'
+ or op == '*'
+ or op == '%'
+ or op == '//' then
+ if m.hasType(value[1], 'integer')
+ and m.hasType(value[2], 'integer') then
+ infers['integer'] = true
+ else
+ infers['number'] = true
+ end
+ return
+ end
+ end
+end
+
+local function searchLiteralOfValue(value, literals)
+ if value.type == 'string'
+ or value.type == 'boolean'
+ or value.tyoe == 'number'
+ or value.type == 'integer' then
+ local v = value[1]
+ if v ~= nil then
+ literals[v] = true
+ end
+ return
+ end
+ if value.type == 'unary' then
+ local op = value.op.type
+ if op == '-' then
+ end
+ if op == '~' then
+ end
+ end
+ return
+end
+
+local function bindClassOrType(source)
+ if not source.bindDocs then
+ return false
+ end
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.class'
+ or doc.type == 'doc.type' then
+ return true
+ end
+ end
+ return false
+end
+
+local function cleanInfers(infers)
+ local version = config.config.runtime.version
+ local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4'
+ if infers['number'] then
+ enableInteger = false
+ end
+ if not enableInteger and infers['integer'] then
+ infers['integer'] = nil
+ infers['number'] = true
+ end
+end
+
+---合并对象的推断类型
+---@param infers string[]
+---@return string
+function m.viewInfers(infers)
+ local count = 0
+ for infer in pairs(infers) do
+ count = count + 1
+ infers[count] = infer
+ end
+ for i = count + 1, #infers do
+ infers[i] = nil
+ end
+ if #infers == 0 then
+ return 'any'
+ end
+ table.sort(infers)
+ return table.concat(infers, '|', 1, count)
+end
+
+---显示对象的推断类型
+---@param source parser.guide.object
+---@return string
+local function searchInfer(source, infers)
+ if bindClassOrType(source) then
+ return
+ end
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchInferOfValue(value, infers)
+ return
+ end
+ -- X.a
+ if source.next and source.next.node == source then
+ if source.next.type == 'setfield'
+ or source.next.type == 'setindex'
+ or source.next.type == 'setmethod' then
+ infers['table'] = true
+ end
+ return
+ end
+end
+
+local function searchLiteral(source, literals)
+ local value = searcher.getObjectValue(source)
+ if value then
+ searchLiteralOfValue(value, literals)
+ return
+ end
+end
+
+---搜索对象的推断类型
+---@param source parser.guide.object
+---@return string[]
+function m.searchInfers(source)
+ if not source then
+ return nil
+ end
+ local defs = searcher.requestDefinition(source)
+ local infers = {}
+ searchInfer(source, infers)
+ for _, def in ipairs(defs) do
+ searchInfer(def, infers)
+ end
+ cleanInfers(infers)
+ return infers
+end
+
+---搜索对象的字面量值
+---@param source parser.guide.object
+---@return table
+function m.searchLiterals(source)
+ local defs = searcher.requestDefinition(source)
+ local literals = {}
+ searchLiteral(source, literals)
+ for _, def in ipairs(defs) do
+ searchLiteral(def, literals)
+ end
+ return literals
+end
+
+---判断对象的推断值是否是 true
+---@param source parser.guide.object
+function m.isTrue(source)
+ if not source then
+ return false
+ end
+ local literals = m.searchLiterals(source)
+ for literal in pairs(literals) do
+ if literal ~= false then
+ return true
+ end
+ end
+ return false
+end
+
+---判断对象的推断类型是否包含某个类型
+function m.hasType(source, tp)
+ local infers = m.searchInfers(source)
+ return infers[tp]
+end
+
+---搜索并显示推断类型
+---@param source parser.guide.object
+---@return string
+function m.searchAndViewInfers(source)
+ if not source then
+ return 'any'
+ end
+ local infers = m.searchInfers(source)
+ local view = m.viewInfers(infers)
+ return view
+end
+
+return m
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index dc77e0f8..df8f4fe8 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1,6 +1,7 @@
local files = require 'files'
local vm = require 'vm'
local guide = require 'parser.guide'
+local infer = require 'core.infer'
rawset(_G, 'TEST', true)
@@ -29,7 +30,7 @@ function TEST(wanted)
files.setText('', newScript)
local source = getSource(pos)
assert(source)
- local result = vm.getInferType(source, 0)
+ local result = infer.searchAndViewInfers(source, 0)
assert(wanted == result)
end
end
@@ -139,10 +140,15 @@ TEST 'number' [[
]]
TEST 'tablelib' [[
+---@class tablelib
+table = {}
+
<?table?>()
]]
TEST 'string' [[
+_VERSION = 'Lua 5.4'
+
<?x?> = _VERSION
]]