summaryrefslogtreecommitdiff
path: root/script/vm/infer.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/infer.lua')
-rw-r--r--script/vm/infer.lua48
1 files changed, 30 insertions, 18 deletions
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 2a64ed52..7bb581cf 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -144,20 +144,23 @@ local viewNodeSwitch = util.switch()
local argView = ''
local regView = ''
for i, arg in ipairs(source.args) do
+ local argNode = vm.compileNode(arg)
+ local isOptional = argNode:isOptional()
+ if isOptional then
+ argNode = argNode:copy()
+ argNode:removeOptional()
+ end
args[i] = string.format('%s%s: %s'
, arg.name[1]
- , arg.optional and '?' or ''
- , m.getInfer(arg):view()
+ , isOptional and '?' or ''
+ , m.getInfer(argNode):view()
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
for i, ret in ipairs(source.returns) do
- rets[i] = string.format('%s%s'
- , m.getInfer(ret):view()
- , ret.optional and '?' or ''
- )
+ rets[i] = m.getInfer(ret):view()
end
if #rets > 0 then
regView = ':' .. table.concat(rets, ', ')
@@ -165,16 +168,21 @@ local viewNodeSwitch = util.switch()
return ('fun(%s)%s'):format(argView, regView)
end)
----@param source parser.object
+---@param source parser.object | vm.node
---@return vm.infer
function m.getInfer(source)
- local node = vm.compileNode(source)
+ local node
+ if source.type == 'vm.node' then
+ node = source
+ else
+ node = vm.compileNode(source)
+ end
if node.lastInfer then
return node.lastInfer
end
local infer = setmetatable({
node = node,
- uri = guide.getUri(source),
+ uri = source.type ~= 'vm.node' and guide.getUri(source),
}, mt)
node.lastInfer = infer
@@ -298,22 +306,26 @@ function mt:view(default, uri)
local max = #array
local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit')
+ local view
if max > limit then
- local view = string.format('%s...(+%d)'
+ view = string.format('%s...(+%d)'
, table.concat(array, '|', 1, limit)
, max - limit
)
-
- self.cachedView = view
-
- return view
else
- local view = table.concat(array, '|')
-
- self.cachedView = view
+ view = table.concat(array, '|')
+ end
- return view
+ if self.node:isOptional() then
+ if max > 1 then
+ view = '(' .. view .. ')?'
+ else
+ view = view .. '?'
+ end
end
+ self.cachedView = view
+
+ return view
end
function mt:eachView()