diff options
-rw-r--r-- | script/core/definition.lua | 2 | ||||
-rw-r--r-- | script/core/infer.lua | 637 | ||||
-rw-r--r-- | script/core/type-definition.lua | 2 | ||||
-rw-r--r-- | script/vm/def.lua (renamed from script/vm/getDef.lua) | 6 | ||||
-rw-r--r-- | script/vm/infer.lua | 52 | ||||
-rw-r--r-- | script/vm/init.lua | 2 | ||||
-rw-r--r-- | script/vm/node.lua | 5 | ||||
-rw-r--r-- | test.lua | 2 | ||||
-rw-r--r-- | test/type_inference/init.lua | 22 |
9 files changed, 73 insertions, 657 deletions
diff --git a/script/core/definition.lua b/script/core/definition.lua index a01a6a25..78ebce8f 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -130,7 +130,7 @@ return function (uri, offset) end end - local defs = vm.getAllDefs(source) + local defs = vm.getDefs(source) for _, src in ipairs(defs) do if src.dummy then diff --git a/script/core/infer.lua b/script/core/infer.lua deleted file mode 100644 index 35a054ca..00000000 --- a/script/core/infer.lua +++ /dev/null @@ -1,637 +0,0 @@ -local config = require 'config' -local util = require 'utility' -local vm = require "vm.vm" -local guide = require "parser.guide" - -local CLASS = { 'CLASS' } -local TABLE = { 'TABLE' } -local CACHE = { 'CACHE' } - -local typeSort = { - ['boolean'] = 1, - ['string'] = 2, - ['integer'] = 3, - ['number'] = 4, - ['table'] = 5, - ['function'] = 6, - ['true'] = 101, - ['false'] = 102, -} - -local m = {} - -local function mergeTable(a, b) - if not b then - return - end - for v in pairs(b) do - a[v] = true - end - a[CACHE] = nil -end - -local function isBaseType(source, mark) - return m.hasType(source, 'number', mark) - or m.hasType(source, 'integer', mark) - or m.hasType(source, 'string', mark) -end - -local function searchInferOfUnary(value, infers, mark) - local op = value.op.type - if op == 'not' then - infers['boolean'] = true - return - end - if op == '#' then - if m.hasType(value[1], 'table', mark) - or m.hasType(value[1], 'string', mark) then - infers['integer'] = true - end - return - end - if op == '-' then - if m.hasType(value[1], 'integer', mark) then - infers['integer'] = true - elseif isBaseType(value[1], mark) then - infers['number'] = true - end - return - end - if op == '~' then - if isBaseType(value[1], mark) then - infers['integer'] = true - end - return - end -end - -local function searchInferOfBinary(value, infers, mark) - local op = value.op.type - if op == 'and' then - if m.isTrue(value[1], mark) then - mergeTable(infers, m.searchInfers(value[2], nil, mark)) - else - mergeTable(infers, m.searchInfers(value[1], nil, mark)) - end - return - end - if op == 'or' then - if m.isTrue(value[1], mark) then - mergeTable(infers, m.searchInfers(value[1], nil, mark)) - else - mergeTable(infers, m.searchInfers(value[2], nil, mark)) - end - return - end - -- must return boolean - if op == '==' - or op == '~=' - or op == '<' - or op == '>' - or op == '<=' - or op == '>=' then - infers['boolean'] = true - return - end - -- check number - if op == '<<' - or op == '>>' - or op == '~' - or op == '&' - or op == '|' then - if isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['integer'] = true - end - return - end - if op == '..' then - if isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['string'] = true - end - return - end - if op == '^' - or op == '/' then - if isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['number'] = true - end - return - end - if op == '+' - or op == '-' - or op == '*' - or op == '%' - or op == '//' then - if m.hasType(value[1], 'integer', mark) - and m.hasType(value[2], 'integer', mark) then - infers['integer'] = true - elseif isBaseType(value[1], mark) - and isBaseType(value[2], mark) then - infers['number'] = true - end - return - end -end - -local function searchInferOfValue(value, infers, mark) - if value.type == 'string' then - infers['string'] = true - return true - end - if value.type == 'boolean' then - infers['boolean'] = true - return true - end - if value.type == 'table' then - if value.array then - local node = m.searchAndViewInfers(value.array, nil, mark) - if node ~= 'any' then - local infer = node .. '[]' - infers[infer] = true - end - else - infers['table'] = true - end - return true - end - if value.type == 'integer' then - infers['integer'] = true - return true - end - if value.type == 'number' then - infers['number'] = true - return true - end - if value.type == 'function' then - infers['function'] = true - return true - end - if value.type == 'unary' then - searchInferOfUnary(value, infers, mark) - return true - end - if value.type == 'binary' then - searchInferOfBinary(value, infers, mark) - return true - end - return false -end - -local function searchLiteralOfValue(value, literals, mark) - if value.type == 'string' - or value.type == 'boolean' - or value.type == 'number' - or value.type == 'integer' then - local v = value[1] - if v ~= nil then - literals[v] = true - end - return - end - if value.type == 'unary' then - local op = value.op.type - if op == '-' then - local subLiterals = m.searchLiterals(value[1], nil, mark) - if subLiterals then - for subLiteral in pairs(subLiterals) do - local num = tonumber(subLiteral) - if num then - literals[-num] = true - end - end - end - end - if op == '~' then - local subLiterals = m.searchLiterals(value[1], nil, mark) - if subLiterals then - for subLiteral in pairs(subLiterals) do - local num = math.tointeger(subLiteral) - if num then - literals[~num] = true - end - end - end - end - end -end - -local function bindClassOrType(source) - if not source.bindDocs then - return false - end - for _, doc in ipairs(source.bindDocs) do - if doc.type == 'doc.class' - or doc.type == 'doc.type' then - return true - end - end - return false -end - -local function cleanInfers(uri, infers) - local version = config.get(uri, 'Lua.runtime.version') - local enableInteger = version == 'Lua 5.3' or version == 'Lua 5.4' - infers['unknown'] = nil - if infers['number'] then - enableInteger = false - end - if not enableInteger and infers['integer'] then - infers['integer'] = nil - infers['number'] = true - end - -- stringlib 就是 string - if infers['stringlib'] and infers['string'] then - infers['stringlib'] = nil - end - -- 如果有doc标记,则先移除table类型 - if infers[CLASS] then - infers[CLASS] = nil - infers['table'] = nil - end - -- 用doc标记的table,加入table类型 - if infers[TABLE] then - infers[TABLE] = nil - infers['table'] = true - end - if infers['function'] then - for k in pairs(infers) do - if k:sub(1, 4) == 'fun(' then - infers[k] = nil - end - end - end -end - ----合并对象的推断类型 ----@param infers string[] ----@return string -function m.viewInfers(uri, infers) - if infers[CACHE] then - return infers[CACHE] - end - -- 如果有显性的 any ,则直接显示为 any - if infers['any'] then - infers[CACHE] = 'any' - return 'any' - end - local result = {} - local count = 0 - for infer in pairs(infers) do - count = count + 1 - result[count] = infer - end - -- 如果没有任何显性类型,则推测为 unkonwn ,显示为 any - if count == 0 then - infers[CACHE] = 'any' - return 'any' - end - table.sort(result, function (a, b) - local sa = typeSort[a] or 100 - local sb = typeSort[b] or 100 - if sa == sb then - return a < b - else - return sa < sb - end - end) - local limit = config.get(uri, 'Lua.hover.enumsLimit') - if limit < 0 then - limit = 0 - end - infers[CACHE] = table.concat(result, '|', 1, math.min(count, limit)) - if count > limit then - infers[CACHE] = ('%s...(+%d)'):format(infers[CACHE], count - limit) - end - return infers[CACHE] -end - ----合并对象的值 ----@param literals string[] ----@return string -function m.viewLiterals(literals) - local result = {} - local count = 0 - for infer in pairs(literals) do - count = count + 1 - result[count] = util.viewLiteral(infer) - end - if count == 0 then - return nil - end - table.sort(result) - local view = table.concat(result, '|') - return view -end - -function m.viewDocName(doc) - if not doc then - return nil - end - if doc.type == 'doc.type' then - local list = {} - for _, tp in ipairs(doc.types) do - list[#list+1] = m.getDocName(tp) - end - return table.concat(list, '|') - end - return m.getDocName(doc) -end - -function m.getDocName(doc) - if not doc then - return nil - end - if doc.type == 'doc.class.name' - or doc.type == 'doc.type.name' then - local name = doc[1] or '?' - if doc.typeGeneric then - return '<' .. name .. '>' - else - return tostring(name) - end - end - if doc.type == 'doc.type.array' then - local nodeName = m.viewDocName(doc.node) or '?' - return nodeName .. '[]' - end - if doc.type == 'doc.type.table' then - local node = m.viewDocName(doc.node) or '?' - local key = m.viewDocName(doc.tkey) or '?' - local value = m.viewDocName(doc.tvalue) or '?' - return ('%s<%s, %s>'):format(node, key, value) - end - if doc.type == 'doc.type.function' then - return m.viewDocFunction(doc) - end - if doc.type == 'doc.type.enum' - or doc.type == 'doc.resume' then - local value = doc[1] or '?' - return tostring(value) - end - if doc.type == 'doc.type.ltable' then - return 'table' - end -end - -function m.viewDocFunction(doc) - if doc.type ~= 'doc.type.function' then - return '' - end - local args = {} - for i, arg in ipairs(doc.args) do - args[i] = ('%s: %s'):format(arg.name[1], arg.extends and m.viewDocName(arg.extends) or 'any') - end - local label = ('fun(%s)'):format(table.concat(args, ', ')) - if #doc.returns > 0 then - local returns = {} - for i, rtn in ipairs(doc.returns) do - returns[i] = m.viewDocName(rtn) - end - label = ('%s:%s'):format(label, table.concat(returns, ', ')) - end - return label -end - ----显示对象的推断类型 ----@param source parser.object ----@param mark table ----@return string -local function searchInfer(source, infers, mark) - if mark[source] then - return - end - mark[source] = true - if bindClassOrType(source) then - return - end - if searchInferOfValue(source, infers, mark) then - return - end - local value = searcher.getObjectValue(source) - if value then - if value.type ~= 'function' - and value.type ~= 'table' - and value.type ~= 'nil' then - searchInferOfValue(value, infers, mark) - end - return - end - -- check LuaDoc - local docName = m.getDocName(source) - if docName and docName ~= 'nil' and docName ~= 'unknown' then - infers[docName] = true - if not vm.isBuiltinType(docName) then - infers[CLASS] = true - end - if docName == 'table' then - infers[TABLE] = true - end - end -end - -local function searchLiteral(source, literals, mark) - if mark[source] then - return - end - mark[source] = true - searchLiteralOfValue(source, literals, mark) - local value = searcher.getObjectValue(source) - if value then - if value.type ~= 'function' - and value.type ~= 'table' then - searchLiteralOfValue(value, literals, mark) - end - return - end -end - -local function getCachedInfers(source, field) - local inferCache = vm.getCache 'infers' - local sourceCache = inferCache[source] - if not sourceCache then - sourceCache = {} - inferCache[source] = sourceCache - end - if not field then - field = '' - end - if sourceCache[field] then - return true, sourceCache[field] - end - local infers = {} - sourceCache[field] = infers - return false, infers -end - ----搜索对象的推断类型 ----@param source parser.object ----@param field? string ----@param mark? table ----@return string[] -function m.searchInfers(source, field, mark) - if not source then - return nil - end - if source.type == 'setlocal' - or source.type == 'getlocal' then - source = source.node - end - local suc, infers = getCachedInfers(source, field) - if suc then - return infers - end - local isParam = source.parent.type == 'funcargs' - local defs = vm.getDefs(source, field) - mark = mark or {} - if not field then - searchInfer(source, infers, mark) - end - for _, def in ipairs(defs) do - if def.typeGeneric and not isParam then - goto CONTINUE - end - if def.type == 'setlocal' then - goto CONTINUE - end - searchInfer(def, infers, mark) - ::CONTINUE:: - end - if source.type == 'doc.type' then - for _, def in ipairs(source.types) do - if def.typeGeneric then - searchInfer(def, infers, mark) - end - end - end - cleanInfers(guide.getUri(source), infers) - return infers -end - ----搜索对象的字面量值 ----@param source parser.object ----@param field? string ----@param mark? table ----@return table -function m.searchLiterals(source, field, mark) - if not source then - return nil - end - local defs = vm.getDefs(source, field) - local literals = {} - mark = mark or {} - if not field then - searchLiteral(source, literals, mark) - end - for _, def in ipairs(defs) do - searchLiteral(def, literals, mark) - end - return literals -end - ----搜索并显示推断值 ----@param source parser.object ----@param field? string ----@return string -function m.searchAndViewLiterals(source, field, mark) - if not source then - return nil - end - local literals = m.searchLiterals(source, field, mark) - if not literals then - return nil - end - local view = m.viewLiterals(literals) - return view -end - ----判断对象的推断值是否是 true ----@param source parser.object ----@param mark? table -function m.isTrue(source, mark) - if not source then - return false - end - mark = mark or {} - if not mark.isTrue then - mark.isTrue = {} - end - if mark.isTrue[source] == nil then - mark.isTrue[source] = false - local literals = m.searchLiterals(source, nil, mark) - if literals then - for literal in pairs(literals) do - if literal ~= false then - mark.isTrue[source] = true - break - end - end - end - end - return mark.isTrue[source] -end - ----判断对象的推断类型是否包含某个类型 -function m.hasType(source, tp, mark) - mark = mark or {} - local infers = m.searchInfers(source, nil, mark) - if not infers then - return false - end - if infers[tp] then - return true - end - if tp == 'function' then - for infer in pairs(infers) do - if infer ~= CACHE and infer:sub(1, 4) == 'fun(' then - return true - end - end - end - return false -end - ----搜索并显示推断类型 ----@param source parser.object ----@param field? string ----@return string -function m.searchAndViewInfers(source, field, mark) - if not source then - return 'any' - end - local infers = m.searchInfers(source, field, mark) - local view = m.viewInfers(guide.getUri(source), infers) - if type(view) == 'boolean' then - log.error('Why view is boolean?', util.dump(infers)) - return 'any' - end - return view -end - ----搜索并显示推断的class ----@param source parser.object ----@return string? -function m.getClass(source) - if not source then - return nil - end - local infers = {} - local defs = vm.getDefs(source) - for _, def in ipairs(defs) do - if def.type == 'doc.class.name' then - if not vm.isBuiltinType(def[1]) then - infers[def[1]] = true - end - end - end - cleanInfers(guide.getUri(source), infers) - local view = m.viewInfers(guide.getUri(source), infers) - if view == 'any' then - return nil - end - return view -end - -return m diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua index 1f021fb3..c9053fab 100644 --- a/script/core/type-definition.lua +++ b/script/core/type-definition.lua @@ -132,7 +132,7 @@ return function (uri, offset) end end - local defs = vm.getAllDefs(source) + local defs = vm.getDefs(source) local values = {} for _, src in ipairs(defs) do local value = searcher.getObjectValue(src) diff --git a/script/vm/getDef.lua b/script/vm/def.lua index 5edc6c67..49a4f7ef 100644 --- a/script/vm/getDef.lua +++ b/script/vm/def.lua @@ -248,9 +248,3 @@ function vm.getDefs(source) return results end - ----@param source parser.object ----@return parser.object[] -function vm.getAllDefs(source) - return vm.getDefs(source) -end diff --git a/script/vm/infer.lua b/script/vm/infer.lua new file mode 100644 index 00000000..8b2cf5ee --- /dev/null +++ b/script/vm/infer.lua @@ -0,0 +1,52 @@ +local util = require 'utility' +local nodeMgr = require 'vm.node' + +---@class vm.infer-manager +local m = {} + +local viewNodeMap = util.switch() + : case 'boolean' + : case 'string' + : case 'table' + : case 'function' + : case 'number' + : case 'integer' + : call(function (source) + return source.type + end) + : getMap() + +---@param node vm.node +---@return string? +local function viewNode(node) + if viewNodeMap[node.type] then + return viewNodeMap[node.type](node) + end +end + +---@param source parser.object +function m.viewType(source) + local compiler = require 'vm.compiler' + local node = compiler.compileNode(source) + local views = {} + for n in nodeMgr.eachNode(node) do + local view = viewNode(n) + if view then + views[view] = true + end + end + if views['number'] then + views['integer'] = nil + end + local array = {} + for view in pairs(views) do + array[#array+1] = view + end + if #array == 0 then + return 'unknown' + end + table.sort(array) + return table.concat(array, '|') +end + +return m diff --git a/script/vm/init.lua b/script/vm/init.lua index 79cecbb7..2669c849 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -1,8 +1,8 @@ local vm = require 'vm.vm' require 'vm.manager' +require 'vm.def' require 'vm.getDocs' require 'vm.getLibrary' -require 'vm.getDef' require 'vm.getRef' require 'vm.getLinks' return vm diff --git a/script/vm/node.lua b/script/vm/node.lua index a8a900ec..f9dbdf3e 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -5,6 +5,8 @@ local union = require 'vm.union' ---@class vm.node-manager local m = {} +local DUMMY_FUNCTION = function () end + ---@type table<parser.object, vm.node> m.nodeCache = {} @@ -38,6 +40,9 @@ end ---@return fun():vm.node function m.eachNode(node) + if not node then + return DUMMY_FUNCTION + end if node.type == 'union' then return node:eachNode() end @@ -52,8 +52,8 @@ end local function testAll() test 'basic' test 'definition' - test 'references' test 'type_inference' + test 'references' test 'hover' test 'completion' test 'crossfile' diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 46c5a4f0..07bbf195 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1,6 +1,6 @@ local files = require 'files' local guide = require 'parser.guide' -local infer = require 'core.infer' +local infer = require 'vm.infer' local config = require 'config' local catch = require 'catch' @@ -27,9 +27,9 @@ function TEST(wanted) files.setText('', newScript) local source = getSource(catched['?'][1][1]) assert(source) - local result = infer.searchAndViewInfers(source) + local result = infer.viewType(source) if wanted ~= result then - infer.searchAndViewInfers(source) + infer.viewType(source) end assert(wanted == result) files.remove('') @@ -52,12 +52,17 @@ TEST 'number' [[ local <?var?> = 1.0 ]] +TEST 'number' [[ +local <?var?> +var = 1 +var = 1.0 +]] + TEST 'string' [[ local var = '111' t.<?x?> = var ]] -config.set(nil, 'Lua.IntelliSense.traceLocalSet', true) TEST 'string' [[ local <?var?> var = '111' @@ -68,7 +73,6 @@ local var var = '111' print(<?var?>) ]] -config.set(nil, 'Lua.IntelliSense.traceLocalSet', false) TEST 'function' [[ function <?xx?>() @@ -80,13 +84,11 @@ local function <?xx?>() end ]] -config.set(nil, 'Lua.IntelliSense.traceLocalSet', true) TEST 'function' [[ local xx <?xx?> = function () end ]] -config.set(nil, 'Lua.IntelliSense.traceLocalSet', false) TEST 'table' [[ local <?t?> = {} @@ -100,7 +102,7 @@ TEST 'boolean' [[ <?x?> = not y ]] -TEST 'any' [[ +TEST 'integer' [[ <?x?> = #y ]] @@ -112,7 +114,7 @@ TEST 'integer' [[ <?x?> = #{} ]] -TEST 'any' [[ +TEST 'number' [[ <?x?> = - y ]] @@ -120,7 +122,7 @@ TEST 'number' [[ <?x?> = - 1.0 ]] -TEST 'any' [[ +TEST 'integer' [[ <?x?> = ~ y ]] |