diff options
-rw-r--r-- | script/core/completion/completion.lua | 6 | ||||
-rw-r--r-- | script/core/diagnostics/no-implicit-any.lua | 2 | ||||
-rw-r--r-- | script/core/hint.lua | 3 | ||||
-rw-r--r-- | script/core/hover/arg.lua | 10 | ||||
-rw-r--r-- | script/core/hover/init.lua | 2 | ||||
-rw-r--r-- | script/core/hover/label.lua | 10 | ||||
-rw-r--r-- | script/core/hover/return.lua | 4 | ||||
-rw-r--r-- | script/vm/field.lua | 5 | ||||
-rw-r--r-- | script/vm/infer.lua | 161 | ||||
-rw-r--r-- | script/vm/union.lua | 2 | ||||
-rw-r--r-- | test/type_inference/init.lua | 4 |
11 files changed, 112 insertions, 97 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index f54bd2d1..133b262d 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -167,7 +167,7 @@ local function buildDetail(source) if source.type == 'dummy' then return end - local types = infer.viewType(source) + local types = infer.getInfer(source):view() local literals = infer.viewLiterals(source) if literals then return types .. ' = ' .. literals @@ -1819,14 +1819,14 @@ local function buildluaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = infer.viewType(arg) + args[#args+1] = infer.getInfer(arg):view() end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = infer.viewType(rtns[n]) + returns[n] = infer.getInfer(rtns[n]):view() end end end diff --git a/script/core/diagnostics/no-implicit-any.lua b/script/core/diagnostics/no-implicit-any.lua index 47f1b997..5c14d211 100644 --- a/script/core/diagnostics/no-implicit-any.lua +++ b/script/core/diagnostics/no-implicit-any.lua @@ -20,7 +20,7 @@ return function (uri, callback) and source.type ~= 'tableindex' then return end - if infer.viewType(source) == 'any' then + if infer.getInfer(source):view() == 'unknown' then callback { start = source.start, finish = source.finish, diff --git a/script/core/hint.lua b/script/core/hint.lua index 3b5db3e5..51842126 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -41,8 +41,9 @@ local function typeHint(uri, results, start, finish) end end await.delay() - local view = infer.viewType(source) + local view = infer.getInfer(source):view() if view == 'any' + or view == 'unknown' or view == 'nil' then return end diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua index c9c81a85..7611a895 100644 --- a/script/core/hover/arg.lua +++ b/script/core/hover/arg.lua @@ -21,7 +21,7 @@ local function asFunction(source, oop) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(infer.viewType(parent.node)) + args[#args+1] = ('self: %s'):format(infer.getInfer(parent.node)) end if source.args then for i = 1, #source.args do @@ -34,15 +34,15 @@ local function asFunction(source, oop) args[#args+1] = ('%s%s: %s'):format( name, optionalArg(arg) and '?' or '', - infer.viewType(arg, 'any') + infer.getInfer(arg):view 'any' ) elseif arg.type == '...' then args[#args+1] = ('%s: %s'):format( '...', - infer.viewType(arg, 'any') + infer.getInfer(arg):view 'any' ) else - args[#args+1] = ('%s'):format(infer.viewType(arg, 'any')) + args[#args+1] = ('%s'):format(infer.getInfer(arg):view 'any') end ::CONTINUE:: end @@ -65,7 +65,7 @@ local function asDocFunction(source, oop) args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - arg.extends and infer.viewType(arg.extends) or 'any' + arg.extends and infer.getInfer(arg.extends):view 'any' or 'any' ) end if oop then diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index fdfbd73d..bc2f40eb 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -40,7 +40,7 @@ local function getHover(source) end local oop - if infer.viewType(source) == 'function' then + if infer.getInfer(source):view() == 'function' then local hasFunc for _, def in ipairs(vm.getDefs(source)) do if guide.isOOP(def) then diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 01dd1143..150f0f24 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -35,7 +35,7 @@ local function asDocTypeName(source) end if doc.type == 'doc.alias.name' then local extends = doc.parent.extends - return lang.script('HOVER_EXTENDS', infer.viewType(extends)) + return lang.script('HOVER_EXTENDS', infer.getInfer(extends):view()) end end end @@ -43,12 +43,12 @@ end ---@async local function asValue(source, title) local name = buildName(source, false) or '' - local type = infer.viewType(source) + local type = infer.getInfer(source):view() local literal = infer.viewLiterals(source) local cont - if not infer.hasType(source, 'string') + if not infer.getInfer(source):hasType 'string' and not type:find('%[%]$') then - if infer.hasType(source, 'table') then + if infer.getInfer(source):hasType 'table' then cont = buildTable(source) end end @@ -131,7 +131,7 @@ local function asDocFieldName(source) break end end - local view = infer.viewType(docField.extends) + local view = infer.getInfer(docField.extends):view() if not class then return ('field ?.%s: %s'):format(name, view) end diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index e48febf3..cb8fa76f 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -67,7 +67,7 @@ local function asFunction(source) local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ') local text = ('%s%s%s'):format( name or '', - infer.viewType(rtn), + infer.getInfer(rtn):view(), doc and doc.optional and '?' or '' ) if i == 1 then @@ -87,7 +87,7 @@ local function asDocFunction(source) local returns = {} for i, rtn in ipairs(source.returns) do local rtnText = ('%s%s'):format( - infer.viewType(rtn), + infer.getInfer(rtn):view(), rtn.optional and '?' or '' ) if i == 1 then diff --git a/script/vm/field.lua b/script/vm/field.lua index 92448bb3..c30e112d 100644 --- a/script/vm/field.lua +++ b/script/vm/field.lua @@ -23,7 +23,10 @@ local function searchByNode(source, pushResult) if not node then return end - searchNodeSwitch(node.type, node, pushResult) + + for n in nodeMgr.eachNode(node) do + searchNodeSwitch(n.type, n, pushResult) + end end ---@param source parser.object diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 6457696a..31a08b74 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -7,6 +7,22 @@ local compiler = require 'vm.compiler' ---@class vm.infer-manager local m = {} +---@class vm.infer +---@field views table<string, boolean> +---@field source? parser.object +---@field cachedView? string +local mt = {} +mt.__index = mt +mt.hasNumber = false +mt.hasTable = false +mt.hasClass = false +mt.isParam = false +mt.isLocal = false +mt.hasDocFunction = false +mt.expandAlias = false + +local nullInfer = setmetatable({ views = {} }, mt) + local inferSorted = { ['boolean'] = - 100, ['string'] = - 99, @@ -23,44 +39,44 @@ local viewNodeSwitch = util.switch() : case 'string' : case 'function' : case 'integer' - : call(function (source, options) + : call(function (source, infer) return source.type end) : case 'number' - : call(function (source, options) - options['hasNumber'] = true + : call(function (source, infer) + infer.hasNumber = true return source.type end) : case 'table' - : call(function (source, options) - options['hasTable'] = true + : call(function (source, infer) + infer.hasTable = true end) : case 'local' - : call(function (source, options) + : call(function (source, infer) if source.parent == 'funcargs' then - options['isParam'] = true + infer.isParam = true else - options['isLocal'] = true + infer.isLocal = true end end) : case 'global' - : call(function (source, options) + : call(function (source, infer) if source.cate == 'type' then - options['hasClass'] = true + infer.hasClass = true return source.name end end) : case 'doc.type.integer' - : call(function (source, options) + : call(function (source, infer) return ('%d'):format(source[1]) end) : case 'doc.type.name' - : call(function (source, options) - options['hasClass'] = true + : call(function (source, infer) + infer.hasClass = true if source.signs then local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = m.viewType(sign) + buf[i] = m.getInfer(sign):view() end return ('%s<%s>'):format(source[1], table.concat(buf, ', ')) else @@ -68,25 +84,25 @@ local viewNodeSwitch = util.switch() end end) : case 'doc.generic.name' - : call(function (source, options) + : call(function (source, infer) return ('<%s>'):format(source[1]) end) : case 'doc.type.array' - : call(function (source, options) - options['hasClass'] = true - return m.viewType(source.node) .. '[]' + : call(function (source, infer) + infer.hasClass = true + return m.getInfer(source.node):view() .. '[]' end) : case 'doc.type.table' - : call(function (source, options) - options['hasTable'] = true + : call(function (source, infer) + infer.hasTable = true end) : case 'doc.type.string' - : call(function (source, options) + : call(function (source, infer) return ('%q'):format(source[1]) end) : case 'doc.type.function' - : call(function (source, options) - options['hasDocFunction'] = true + : call(function (source, infer) + infer.hasDocFunction = true local args = {} local rets = {} local argView = '' @@ -95,7 +111,7 @@ local viewNodeSwitch = util.switch() args[i] = string.format('%s%s: %s' , arg.name[1] , arg.optional and '?' or '' - , m.viewType(arg) + , m.getInfer(arg):view() ) end if #args > 0 then @@ -103,7 +119,7 @@ local viewNodeSwitch = util.switch() end for i, ret in ipairs(source.returns) do rets[i] = string.format('%s%s' - , m.viewType(ret) + , m.getInfer(ret):view() , ret.optional and '?' or '' ) end @@ -113,24 +129,20 @@ local viewNodeSwitch = util.switch() return ('fun(%s)%s'):format(argView, regView) end) ----@param node vm.node ----@return string? -local function viewNode(node, options) - return viewNodeSwitch(node.type, node, options) -end - -local function eraseAlias(node, viewMap, options) +---@param infer vm.infer +local function eraseAlias(infer) + local node = compiler.compileNode(infer.source) for n in nodeMgr.eachNode(node) do if n.type == 'global' and n.cate == 'type' then for _, set in ipairs(n:getSets()) do if set.type == 'doc.alias' then - if options['expandAlias'] then - viewMap[n.name] = nil + if infer.expandAlias then + infer.views[n.name] = nil else for _, ext in ipairs(set.extends.types) do - local view = viewNode(ext, {}) + local view = viewNodeSwitch(ext.type, ext, {}) if view and view ~= n.name then - viewMap[view] = nil + infer.views[view] = nil end end end @@ -141,71 +153,67 @@ local function eraseAlias(node, viewMap, options) end ---@param source parser.object ----@return table<string, boolean> ----@return table<string, boolean> -function m.getViews(source) +---@return vm.infer +function m.getInfer(source) local node = compiler.compileNode(source) if not node then - return {} + return nullInfer end - if node.type == 'union' and node.lastViews then - return node.lastViews + if node.type == 'union' and node.lastInfer then + return node.lastInfer end - local views = {} - local options = {} - options['expandAlias'] = config.get(guide.getUri(source), 'Lua.hover.expandAlias') + local infer = setmetatable({ + source = source, + views = {} + }, mt) + infer.expandAlias = config.get(guide.getUri(source), 'Lua.hover.expandAlias') if node.type == 'union' then - node.lastViews = views + node.lastInfer = infer end for n in nodeMgr.eachNode(node) do - local view = viewNode(n, options) + local view = viewNodeSwitch(n.type, n, infer) if view then - views[view] = true + infer.views[view] = true end end - if options['hasNumber'] then - views['integer'] = nil + if infer.hasNumber then + infer.views['integer'] = nil end - if options['hasDocFunction'] then - views['function'] = nil + if infer.hasDocFunction then + infer.views['function'] = nil end - if options['hasTable'] and not options['hasClass'] then - views['table'] = true + if infer.hasTable and not infer.hasClass then + infer.views['table'] = true end - if options['hasClass'] then - eraseAlias(node, views, options) + if infer.hasClass then + eraseAlias(infer) end - return views, options + return infer end ----@param source parser.object ---@param tp string ---@return boolean -function m.hasType(source, tp) - local views = m.getViews(source) - - if views[tp] then - return true - end - - return false +function mt:hasType(tp) + return self.views[tp] == true end ----@param source parser.object +---@param default? string ---@return string -function m.viewType(source, default) - local views = m.getViews(source) - - if views['any'] then +function mt:view(default) + if self.views['any'] then return 'any' end - if not next(views) then + if not next(self.views) then return default or 'unknown' end + if self.cachedView then + return self.cachedView + end + local array = {} - for view in pairs(views) do + for view in pairs(self.views) do array[#array+1] = view end @@ -219,7 +227,7 @@ function m.viewType(source, default) end) local max = #array - local limit = config.get(guide.getUri(source), 'Lua.hover.enumsLimit') + local limit = config.get(guide.getUri(self.source), 'Lua.hover.enumsLimit') if max > limit then local view = string.format('%s...(+%d)' @@ -227,13 +235,16 @@ function m.viewType(source, default) , max - limit ) + self.cachedView = view + return view else local view = table.concat(array, '|') + self.cachedView = view + return view end - end ---@param source parser.object diff --git a/script/vm/union.lua b/script/vm/union.lua index 183f3440..10aebda1 100644 --- a/script/vm/union.lua +++ b/script/vm/union.lua @@ -5,7 +5,7 @@ local mt = {} mt.__index = mt mt.type = 'union' mt.optional = nil -mt.lastViews = nil +mt.lastInfer = nil ---@param me parser.object ---@param node vm.node diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index e687afe1..1a2c4593 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -31,9 +31,9 @@ function TEST(wanted) files.setText('', newScript) local source = getSource(catched['?'][1][1]) assert(source) - local result = infer.viewType(source) + local result = infer.getInfer(source):view() if wanted ~= result then - infer.viewType(source) + infer.getInfer(source):view() end assert(wanted == result) files.remove('') |