diff options
Diffstat (limited to 'script/vm/infer.lua')
-rw-r--r-- | script/vm/infer.lua | 115 |
1 files changed, 63 insertions, 52 deletions
diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 2a64ed52..fabc9828 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -1,11 +1,9 @@ local util = require 'utility' local config = require 'config' local guide = require 'parser.guide' +---@class vm local vm = require 'vm.vm' ----@class vm.infer-manager -local m = {} - ---@class vm.infer ---@field views table<string, boolean> ---@field cachedView? string @@ -21,7 +19,7 @@ mt._hasDocFunction = false mt._isParam = false mt._isLocal = false -m.NULL = setmetatable({}, mt) +vm.NULL = setmetatable({}, mt) local inferSorted = { ['boolean'] = - 100, @@ -52,7 +50,7 @@ local viewNodeSwitch = util.switch() : call(function (source, infer) if source.type == 'table' then if #source == 1 and source[1].type == 'varargs' then - local node = m.getInfer(source[1]):view() + local node = vm.getInfer(source[1]):view() return ('%s[]'):format(node) end end @@ -90,7 +88,7 @@ local viewNodeSwitch = util.switch() if source.signs then local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = m.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view() end return ('%s<%s>'):format(source[1], table.concat(buf, ', ')) else @@ -99,7 +97,7 @@ local viewNodeSwitch = util.switch() end) : case 'generic' : call(function (source, infer) - return m.getInfer(source.proto):view() + return vm.getInfer(source.proto):view() end) : case 'doc.generic.name' : call(function (source, infer) @@ -108,7 +106,7 @@ local viewNodeSwitch = util.switch() : case 'doc.type.array' : call(function (source, infer) infer._hasClass = true - local view = m.getInfer(source.node):view() + local view = vm.getInfer(source.node):view() if source.node.type == 'doc.type' then view = '(' .. view .. ')' end @@ -119,7 +117,7 @@ local viewNodeSwitch = util.switch() infer._hasClass = true local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = m.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view() end return ('%s<%s>'):format(source.node[1], table.concat(buf, ', ')) end) @@ -144,20 +142,23 @@ local viewNodeSwitch = util.switch() local argView = '' local regView = '' for i, arg in ipairs(source.args) do + local argNode = vm.compileNode(arg) + local isOptional = argNode:isOptional() + if isOptional then + argNode = argNode:copy() + argNode:removeOptional() + end args[i] = string.format('%s%s: %s' , arg.name[1] - , arg.optional and '?' or '' - , m.getInfer(arg):view() + , isOptional and '?' or '' + , vm.getInfer(argNode):view() ) end if #args > 0 then argView = table.concat(args, ', ') end for i, ret in ipairs(source.returns) do - rets[i] = string.format('%s%s' - , m.getInfer(ret):view() - , ret.optional and '?' or '' - ) + rets[i] = vm.getInfer(ret):view() end if #rets > 0 then regView = ':' .. table.concat(rets, ', ') @@ -165,16 +166,21 @@ local viewNodeSwitch = util.switch() return ('fun(%s)%s'):format(argView, regView) end) ----@param source parser.object +---@param source parser.object | vm.node ---@return vm.infer -function m.getInfer(source) - local node = vm.compileNode(source) +function vm.getInfer(source) + local node + if source.type == 'vm.node' then + node = source + else + node = vm.compileNode(source) + end if node.lastInfer then return node.lastInfer end local infer = setmetatable({ node = node, - uri = guide.getUri(source), + uri = source.type ~= 'vm.node' and guide.getUri(source), }, mt) node.lastInfer = infer @@ -199,24 +205,24 @@ function mt:_trim() if self._hasTable and not self._hasClass then self.views['table'] = true end - if self._hasClass then - self:_eraseAlias() - end end -function mt:_eraseAlias() - local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias') +---@param uri uri +---@return table<string, true> +function mt:_eraseAlias(uri) + local drop = {} + local expandAlias = config.get(uri, 'Lua.hover.expandAlias') for n in self.node:eachObject() do if n.type == 'global' and n.cate == 'type' then - for _, set in ipairs(n:getSets(self.uri)) do + for _, set in ipairs(n:getSets(uri)) do if set.type == 'doc.alias' then if expandAlias then - self.views[n.name] = nil + drop[n.name] = true else for _, ext in ipairs(set.extends.types) do local view = viewNodeSwitch(ext.type, ext, {}) if view and view ~= n.name then - self.views[view] = nil + drop[view] = true end end end @@ -224,6 +230,7 @@ function mt:_eraseAlias() end end end + return drop end ---@param tp string @@ -273,17 +280,16 @@ function mt:view(default, uri) return 'any' end - if not next(self.views) then - return default or 'unknown' - end - - if self.cachedView then - return self.cachedView + local drop + if self._hasClass then + drop = self:_eraseAlias(uri or self.uri) end local array = {} for view in pairs(self.views) do - array[#array+1] = view + if not drop or not drop[view] then + array[#array+1] = view + end end table.sort(array, function (a, b) @@ -298,22 +304,29 @@ function mt:view(default, uri) local max = #array local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') - if max > limit then - local view = string.format('%s...(+%d)' - , table.concat(array, '|', 1, limit) - , max - limit - ) - - self.cachedView = view - - return view + local view + if #array == 0 then + view = default or 'unknown' else - local view = table.concat(array, '|') - - self.cachedView = view + if max > limit then + view = string.format('%s...(+%d)' + , table.concat(array, '|', 1, limit) + , max - limit + ) + else + view = table.concat(array, '|') + end + end - return view + if self.node:isOptional() then + if max > 1 then + view = '(' .. view .. ')?' + else + view = view .. '?' + end end + + return view end function mt:eachView() @@ -324,10 +337,10 @@ end ---@param other vm.infer ---@return vm.infer function mt:merge(other) - if self == m.NULL then + if self == vm.NULL then return other end - if other == m.NULL then + if other == vm.NULL then return self end @@ -390,8 +403,6 @@ end ---@param source parser.object ---@return string? -function m.viewObject(source) +function vm.viewObject(source) return viewNodeSwitch(source.type, source, {}) end - -return m |