summaryrefslogtreecommitdiff
path: root/script/vm/infer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/infer.lua')
-rw-r--r--script/vm/infer.lua115
1 files changed, 63 insertions, 52 deletions
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 2a64ed52..fabc9828 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -1,11 +1,9 @@
local util = require 'utility'
local config = require 'config'
local guide = require 'parser.guide'
+---@class vm
local vm = require 'vm.vm'
----@class vm.infer-manager
-local m = {}
-
---@class vm.infer
---@field views table<string, boolean>
---@field cachedView? string
@@ -21,7 +19,7 @@ mt._hasDocFunction = false
mt._isParam = false
mt._isLocal = false
-m.NULL = setmetatable({}, mt)
+vm.NULL = setmetatable({}, mt)
local inferSorted = {
['boolean'] = - 100,
@@ -52,7 +50,7 @@ local viewNodeSwitch = util.switch()
: call(function (source, infer)
if source.type == 'table' then
if #source == 1 and source[1].type == 'varargs' then
- local node = m.getInfer(source[1]):view()
+ local node = vm.getInfer(source[1]):view()
return ('%s[]'):format(node)
end
end
@@ -90,7 +88,7 @@ local viewNodeSwitch = util.switch()
if source.signs then
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = m.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view()
end
return ('%s<%s>'):format(source[1], table.concat(buf, ', '))
else
@@ -99,7 +97,7 @@ local viewNodeSwitch = util.switch()
end)
: case 'generic'
: call(function (source, infer)
- return m.getInfer(source.proto):view()
+ return vm.getInfer(source.proto):view()
end)
: case 'doc.generic.name'
: call(function (source, infer)
@@ -108,7 +106,7 @@ local viewNodeSwitch = util.switch()
: case 'doc.type.array'
: call(function (source, infer)
infer._hasClass = true
- local view = m.getInfer(source.node):view()
+ local view = vm.getInfer(source.node):view()
if source.node.type == 'doc.type' then
view = '(' .. view .. ')'
end
@@ -119,7 +117,7 @@ local viewNodeSwitch = util.switch()
infer._hasClass = true
local buf = {}
for i, sign in ipairs(source.signs) do
- buf[i] = m.getInfer(sign):view()
+ buf[i] = vm.getInfer(sign):view()
end
return ('%s<%s>'):format(source.node[1], table.concat(buf, ', '))
end)
@@ -144,20 +142,23 @@ local viewNodeSwitch = util.switch()
local argView = ''
local regView = ''
for i, arg in ipairs(source.args) do
+ local argNode = vm.compileNode(arg)
+ local isOptional = argNode:isOptional()
+ if isOptional then
+ argNode = argNode:copy()
+ argNode:removeOptional()
+ end
args[i] = string.format('%s%s: %s'
, arg.name[1]
- , arg.optional and '?' or ''
- , m.getInfer(arg):view()
+ , isOptional and '?' or ''
+ , vm.getInfer(argNode):view()
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
for i, ret in ipairs(source.returns) do
- rets[i] = string.format('%s%s'
- , m.getInfer(ret):view()
- , ret.optional and '?' or ''
- )
+ rets[i] = vm.getInfer(ret):view()
end
if #rets > 0 then
regView = ':' .. table.concat(rets, ', ')
@@ -165,16 +166,21 @@ local viewNodeSwitch = util.switch()
return ('fun(%s)%s'):format(argView, regView)
end)
----@param source parser.object
+---@param source parser.object | vm.node
---@return vm.infer
-function m.getInfer(source)
- local node = vm.compileNode(source)
+function vm.getInfer(source)
+ local node
+ if source.type == 'vm.node' then
+ node = source
+ else
+ node = vm.compileNode(source)
+ end
if node.lastInfer then
return node.lastInfer
end
local infer = setmetatable({
node = node,
- uri = guide.getUri(source),
+ uri = source.type ~= 'vm.node' and guide.getUri(source),
}, mt)
node.lastInfer = infer
@@ -199,24 +205,24 @@ function mt:_trim()
if self._hasTable and not self._hasClass then
self.views['table'] = true
end
- if self._hasClass then
- self:_eraseAlias()
- end
end
-function mt:_eraseAlias()
- local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias')
+---@param uri uri
+---@return table<string, true>
+function mt:_eraseAlias(uri)
+ local drop = {}
+ local expandAlias = config.get(uri, 'Lua.hover.expandAlias')
for n in self.node:eachObject() do
if n.type == 'global' and n.cate == 'type' then
- for _, set in ipairs(n:getSets(self.uri)) do
+ for _, set in ipairs(n:getSets(uri)) do
if set.type == 'doc.alias' then
if expandAlias then
- self.views[n.name] = nil
+ drop[n.name] = true
else
for _, ext in ipairs(set.extends.types) do
local view = viewNodeSwitch(ext.type, ext, {})
if view and view ~= n.name then
- self.views[view] = nil
+ drop[view] = true
end
end
end
@@ -224,6 +230,7 @@ function mt:_eraseAlias()
end
end
end
+ return drop
end
---@param tp string
@@ -273,17 +280,16 @@ function mt:view(default, uri)
return 'any'
end
- if not next(self.views) then
- return default or 'unknown'
- end
-
- if self.cachedView then
- return self.cachedView
+ local drop
+ if self._hasClass then
+ drop = self:_eraseAlias(uri or self.uri)
end
local array = {}
for view in pairs(self.views) do
- array[#array+1] = view
+ if not drop or not drop[view] then
+ array[#array+1] = view
+ end
end
table.sort(array, function (a, b)
@@ -298,22 +304,29 @@ function mt:view(default, uri)
local max = #array
local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit')
- if max > limit then
- local view = string.format('%s...(+%d)'
- , table.concat(array, '|', 1, limit)
- , max - limit
- )
-
- self.cachedView = view
-
- return view
+ local view
+ if #array == 0 then
+ view = default or 'unknown'
else
- local view = table.concat(array, '|')
-
- self.cachedView = view
+ if max > limit then
+ view = string.format('%s...(+%d)'
+ , table.concat(array, '|', 1, limit)
+ , max - limit
+ )
+ else
+ view = table.concat(array, '|')
+ end
+ end
- return view
+ if self.node:isOptional() then
+ if max > 1 then
+ view = '(' .. view .. ')?'
+ else
+ view = view .. '?'
+ end
end
+
+ return view
end
function mt:eachView()
@@ -324,10 +337,10 @@ end
---@param other vm.infer
---@return vm.infer
function mt:merge(other)
- if self == m.NULL then
+ if self == vm.NULL then
return other
end
- if other == m.NULL then
+ if other == vm.NULL then
return self
end
@@ -390,8 +403,6 @@ end
---@param source parser.object
---@return string?
-function m.viewObject(source)
+function vm.viewObject(source)
return viewNodeSwitch(source.type, source, {})
end
-
-return m