diff options
Diffstat (limited to 'script/vm/infer.lua')
-rw-r--r-- | script/vm/infer.lua | 211 |
1 files changed, 161 insertions, 50 deletions
diff --git a/script/vm/infer.lua b/script/vm/infer.lua index fabc9828..263b2500 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -8,10 +8,9 @@ local vm = require 'vm.vm' ---@field views table<string, boolean> ---@field cachedView? string ---@field node? vm.node ----@field uri? uri +---@field _drop table local mt = {} mt.__index = mt -mt._hasNumber = false mt._hasTable = false mt._hasClass = false mt._hasFunctionDef = false @@ -21,6 +20,8 @@ mt._isLocal = false vm.NULL = setmetatable({}, mt) +local LOCK = {} + local inferSorted = { ['boolean'] = - 100, ['string'] = - 99, @@ -43,14 +44,13 @@ local viewNodeSwitch = util.switch() end) : case 'number' : call(function (source, infer) - infer._hasNumber = true return source.type end) : case 'table' - : call(function (source, infer) + : call(function (source, infer, uri) if source.type == 'table' then if #source == 1 and source[1].type == 'varargs' then - local node = vm.getInfer(source[1]):view() + local node = vm.getInfer(source[1]):view(uri) return ('%s[]'):format(node) end end @@ -76,19 +76,18 @@ local viewNodeSwitch = util.switch() : case 'global' : call(function (source, infer) if source.cate == 'type' then - infer._hasClass = true - if source.name == 'number' then - infer._hasNumber = true + if not guide.isBasicType(source.name) then + infer._hasClass = true end return source.name end end) : case 'doc.type.name' - : call(function (source, infer) + : call(function (source, infer, uri) if source.signs then local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = vm.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view(uri) end return ('%s<%s>'):format(source[1], table.concat(buf, ', ')) else @@ -96,34 +95,68 @@ local viewNodeSwitch = util.switch() end end) : case 'generic' - : call(function (source, infer) - return vm.getInfer(source.proto):view() + : call(function (source, infer, uri) + return vm.getInfer(source.proto):view(uri) end) : case 'doc.generic.name' : call(function (source, infer) return ('<%s>'):format(source[1]) end) : case 'doc.type.array' - : call(function (source, infer) + : call(function (source, infer, uri) infer._hasClass = true - local view = vm.getInfer(source.node):view() + local view = vm.getInfer(source.node):view(uri) if source.node.type == 'doc.type' then view = '(' .. view .. ')' end return view .. '[]' end) : case 'doc.type.sign' - : call(function (source, infer) + : call(function (source, infer, uri) infer._hasClass = true local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = vm.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view(uri) + end + if infer._drop then + local node = vm.compileNode(source) + for c in node:eachObject() do + if guide.isLiteral(c) then + infer._drop[c] = true + end + end end return ('%s<%s>'):format(source.node[1], table.concat(buf, ', ')) end) : case 'doc.type.table' - : call(function (source, infer) - infer._hasTable = true + : call(function (source, infer, uri) + if #source.fields == 0 then + infer._hasTable = true + return + end + if infer._drop and infer._drop[source] then + infer._hasTable = true + return + end + infer._hasClass = true + local buf = {} + buf[#buf+1] = '{ ' + for i, field in ipairs(source.fields) do + if i > 1 then + buf[#buf+1] = ', ' + end + local key = field.name + if key.type == 'doc.type' then + buf[#buf+1] = ('[%s]: '):format(vm.getInfer(key):view(uri)) + elseif type(key[1]) == 'string' then + buf[#buf+1] = key[1] .. ': ' + else + buf[#buf+1] = ('[%q]: '):format(key[1]) + end + buf[#buf+1] = vm.getInfer(field.extends):view(uri) + end + buf[#buf+1] = ' }' + return table.concat(buf) end) : case 'doc.type.string' : call(function (source, infer) @@ -134,8 +167,12 @@ local viewNodeSwitch = util.switch() : call(function (source, infer) return ('%q'):format(source[1]) end) - : case 'doc.type.function' + : case 'doc.type.code' : call(function (source, infer) + return ('`%s`'):format(source[1]) + end) + : case 'doc.type.function' + : call(function (source, infer, uri) infer._hasDocFunction = true local args = {} local rets = {} @@ -148,31 +185,53 @@ local viewNodeSwitch = util.switch() argNode = argNode:copy() argNode:removeOptional() end - args[i] = string.format('%s%s: %s' + args[i] = string.format('%s%s%s%s' , arg.name[1] , isOptional and '?' or '' - , vm.getInfer(argNode):view() + , arg.name[1] == '...' and '' or ': ' + , vm.getInfer(argNode):view(uri) ) end if #args > 0 then argView = table.concat(args, ', ') end + local needReturnParen for i, ret in ipairs(source.returns) do - rets[i] = vm.getInfer(ret):view() + local retType = vm.getInfer(ret):view(uri) + if ret.name then + if ret.name[1] == '...' then + rets[i] = ('%s%s'):format(ret.name[1], retType) + else + needReturnParen = true + rets[i] = ('%s: %s'):format(ret.name[1], retType) + end + else + rets[i] = retType + end end if #rets > 0 then - regView = ':' .. table.concat(rets, ', ') + if needReturnParen then + regView = (':(%s)'):format(table.concat(rets, ', ')) + else + regView = (':%s'):format(table.concat(rets, ', ')) + end end return ('fun(%s)%s'):format(argView, regView) end) ----@param source parser.object | vm.node +---@class vm.node +---@field lastInfer? vm.infer + +---@param source vm.object | vm.node ---@return vm.infer function vm.getInfer(source) + ---@type vm.node local node if source.type == 'vm.node' then + ---@cast source vm.node node = source else + ---@cast source vm.object node = vm.compileNode(source) end if node.lastInfer then @@ -180,7 +239,7 @@ function vm.getInfer(source) end local infer = setmetatable({ node = node, - uri = source.type ~= 'vm.node' and guide.getUri(source), + _drop = {}, }, mt) node.lastInfer = infer @@ -188,9 +247,6 @@ function vm.getInfer(source) end function mt:_trim() - if self._hasNumber then - self.views['integer'] = nil - end if self._hasDocFunction then if self._hasFunctionDef then for view in pairs(self.views) do @@ -205,6 +261,13 @@ function mt:_trim() if self._hasTable and not self._hasClass then self.views['table'] = true end + if self.views['number'] then + self.views['integer'] = nil + end + if self.views['boolean'] then + self.views['true'] = nil + self.views['false'] = nil + end end ---@param uri uri @@ -214,46 +277,86 @@ function mt:_eraseAlias(uri) local expandAlias = config.get(uri, 'Lua.hover.expandAlias') for n in self.node:eachObject() do if n.type == 'global' and n.cate == 'type' then + if LOCK[n.name] then + goto CONTINUE + end + LOCK[n.name] = true for _, set in ipairs(n:getSets(uri)) do if set.type == 'doc.alias' then if expandAlias then drop[n.name] = true + local newInfer = {} + for _, ext in ipairs(set.extends.types) do + viewNodeSwitch(ext.type, ext, newInfer, uri) + end + if newInfer._hasTable then + self.views['table'] = true + end else for _, ext in ipairs(set.extends.types) do - local view = viewNodeSwitch(ext.type, ext, {}) + local view = viewNodeSwitch(ext.type, ext, {}, uri) if view and view ~= n.name then drop[view] = true end end end end + if set.type == 'doc.class' then + if set.extends then + for _, ext in ipairs(set.extends) do + if ext.type == 'doc.extends.name' then + local view = ext[1] + drop[view] = true + end + end + end + end end + LOCK[n.name] = nil + ::CONTINUE:: end end return drop end +---@param uri uri ---@param tp string ---@return boolean -function mt:hasType(tp) - self:_computeViews() +function mt:hasType(uri, tp) + self:_computeViews(uri) return self.views[tp] == true end +---@param uri uri +function mt:hasUnknown(uri) + self:_computeViews(uri) + return not next(self.views) + or self.views['unknown'] == true +end + +---@param uri uri +function mt:hasAny(uri) + self:_computeViews(uri) + return self.views['any'] == true +end + +---@param uri uri ---@return boolean -function mt:hasClass() - self:_computeViews() +function mt:hasClass(uri) + self:_computeViews(uri) return self._hasClass == true end +---@param uri uri ---@return boolean -function mt:hasFunction() - self:_computeViews() +function mt:hasFunction(uri) + self:_computeViews(uri) return self.views['function'] == true or self._hasDocFunction == true end -function mt:_computeViews() +---@param uri uri +function mt:_computeViews(uri) if self.views then return end @@ -261,7 +364,7 @@ function mt:_computeViews() self.views = {} for n in self.node:eachObject() do - local view = viewNodeSwitch(n.type, n, self) + local view = viewNodeSwitch(n.type, n, self, uri) if view then self.views[view] = true end @@ -270,11 +373,11 @@ function mt:_computeViews() self:_trim() end +---@param uri uri ---@param default? string ----@param uri? uri ---@return string -function mt:view(default, uri) - self:_computeViews() +function mt:view(uri, default) + self:_computeViews(uri) if self.views['any'] then return 'any' @@ -282,7 +385,7 @@ function mt:view(default, uri) local drop if self._hasClass then - drop = self:_eraseAlias(uri or self.uri) + drop = self:_eraseAlias(uri) end local array = {} @@ -302,7 +405,7 @@ function mt:view(default, uri) end) local max = #array - local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') + local limit = config.get(uri, 'Lua.hover.enumsLimit') local view if #array == 0 then @@ -329,8 +432,9 @@ function mt:view(default, uri) return view end -function mt:eachView() - self:_computeViews() +---@param uri uri +function mt:eachView(uri) + self:_computeViews(uri) return next, self.views end @@ -346,7 +450,6 @@ function mt:merge(other) local infer = setmetatable({ node = vm.createNode(self.node, other.node), - uri = self.uri, }, mt) return infer @@ -365,7 +468,7 @@ function mt:viewLiterals() or n.type == 'integer' or n.type == 'boolean' then local literal = util.viewLiteral(n[1]) - if not mark[literal] then + if literal and not mark[literal] then literals[#literals+1] = literal mark[literal] = true end @@ -374,7 +477,14 @@ function mt:viewLiterals() if #literals == 0 then return nil end - table.sort(literals) + table.sort(literals, function (a, b) + local sa = inferSorted[a] or 0 + local sb = inferSorted[b] or 0 + if sa == sb then + return a < b + end + return sa < sb + end) return table.concat(literals, '|') end @@ -401,8 +511,9 @@ function mt:viewClass() return table.concat(class, '|') end ----@param source parser.object +---@param source vm.node.object +---@param uri uri ---@return string? -function vm.viewObject(source) - return viewNodeSwitch(source.type, source, {}) +function vm.viewObject(source, uri) + return viewNodeSwitch(source.type, source, {}, uri) end |