diff options
Diffstat (limited to 'script/vm/infer.lua')
-rw-r--r-- | script/vm/infer.lua | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 2a64ed52..7bb581cf 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -144,20 +144,23 @@ local viewNodeSwitch = util.switch() local argView = '' local regView = '' for i, arg in ipairs(source.args) do + local argNode = vm.compileNode(arg) + local isOptional = argNode:isOptional() + if isOptional then + argNode = argNode:copy() + argNode:removeOptional() + end args[i] = string.format('%s%s: %s' , arg.name[1] - , arg.optional and '?' or '' - , m.getInfer(arg):view() + , isOptional and '?' or '' + , m.getInfer(argNode):view() ) end if #args > 0 then argView = table.concat(args, ', ') end for i, ret in ipairs(source.returns) do - rets[i] = string.format('%s%s' - , m.getInfer(ret):view() - , ret.optional and '?' or '' - ) + rets[i] = m.getInfer(ret):view() end if #rets > 0 then regView = ':' .. table.concat(rets, ', ') @@ -165,16 +168,21 @@ local viewNodeSwitch = util.switch() return ('fun(%s)%s'):format(argView, regView) end) ----@param source parser.object +---@param source parser.object | vm.node ---@return vm.infer function m.getInfer(source) - local node = vm.compileNode(source) + local node + if source.type == 'vm.node' then + node = source + else + node = vm.compileNode(source) + end if node.lastInfer then return node.lastInfer end local infer = setmetatable({ node = node, - uri = guide.getUri(source), + uri = source.type ~= 'vm.node' and guide.getUri(source), }, mt) node.lastInfer = infer @@ -298,22 +306,26 @@ function mt:view(default, uri) local max = #array local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') + local view if max > limit then - local view = string.format('%s...(+%d)' + view = string.format('%s...(+%d)' , table.concat(array, '|', 1, limit) , max - limit ) - - self.cachedView = view - - return view else - local view = table.concat(array, '|') - - self.cachedView = view + view = table.concat(array, '|') + end - return view + if self.node:isOptional() then + if max > 1 then + view = '(' .. view .. ')?' + else + view = view .. '?' + end end + self.cachedView = view + + return view end function mt:eachView() |