diff options
-rw-r--r-- | script/core/hover/return.lua | 10 | ||||
-rw-r--r-- | script/vm/compiler.lua | 9 | ||||
-rw-r--r-- | script/vm/infer.lua | 48 | ||||
-rw-r--r-- | script/vm/node.lua | 34 | ||||
-rw-r--r-- | script/vm/sign.lua | 2 | ||||
-rw-r--r-- | script/vm/value.lua | 19 | ||||
-rw-r--r-- | test/hover/init.lua | 2 | ||||
-rw-r--r-- | test/type_inference/init.lua | 7 |
8 files changed, 64 insertions, 67 deletions
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index 77710148..93ff1f6c 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -65,10 +65,9 @@ local function asFunction(source) local rtn = vm.getReturnOfFunction(source, i) local doc = docs[i] local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ') - local text = ('%s%s%s'):format( + local text = ('%s%s'):format( name or '', - infer.getInfer(rtn):view(), - doc and doc.optional and '?' or '' + infer.getInfer(rtn):view() ) if i == 1 then returns[i] = (' -> %s'):format(text) @@ -86,10 +85,7 @@ local function asDocFunction(source) end local returns = {} for i, rtn in ipairs(source.returns) do - local rtnText = ('%s%s'):format( - infer.getInfer(rtn):view(), - rtn.optional and '?' or '' - ) + local rtnText = infer.getInfer(rtn):view() if i == 1 then returns[#returns+1] = (' -> %s'):format(rtnText) else diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 8126f393..8df28e7b 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -455,6 +455,9 @@ local function getReturn(func, index, args) result:merge(rnode) end end + if result and returnNode:isOptional() then + result:addOptional() + end end end end @@ -1115,6 +1118,9 @@ local compilerSwitch = util.switch() for _, typeUnit in ipairs(source.types) do vm.setNode(source, vm.compileNode(typeUnit)) end + if source.optional then + vm.getNode(source):addOptional() + end end) : case 'doc.type.integer' : case 'doc.type.string' @@ -1220,6 +1226,9 @@ local compilerSwitch = util.switch() else vm.setNode(source, globalMgr.getGlobal('type', 'any')) end + if source.optional then + vm.getNode(source):addOptional() + end end) : case 'generic' : call(function (source) diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 2a64ed52..7bb581cf 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -144,20 +144,23 @@ local viewNodeSwitch = util.switch() local argView = '' local regView = '' for i, arg in ipairs(source.args) do + local argNode = vm.compileNode(arg) + local isOptional = argNode:isOptional() + if isOptional then + argNode = argNode:copy() + argNode:removeOptional() + end args[i] = string.format('%s%s: %s' , arg.name[1] - , arg.optional and '?' or '' - , m.getInfer(arg):view() + , isOptional and '?' or '' + , m.getInfer(argNode):view() ) end if #args > 0 then argView = table.concat(args, ', ') end for i, ret in ipairs(source.returns) do - rets[i] = string.format('%s%s' - , m.getInfer(ret):view() - , ret.optional and '?' or '' - ) + rets[i] = m.getInfer(ret):view() end if #rets > 0 then regView = ':' .. table.concat(rets, ', ') @@ -165,16 +168,21 @@ local viewNodeSwitch = util.switch() return ('fun(%s)%s'):format(argView, regView) end) ----@param source parser.object +---@param source parser.object | vm.node ---@return vm.infer function m.getInfer(source) - local node = vm.compileNode(source) + local node + if source.type == 'vm.node' then + node = source + else + node = vm.compileNode(source) + end if node.lastInfer then return node.lastInfer end local infer = setmetatable({ node = node, - uri = guide.getUri(source), + uri = source.type ~= 'vm.node' and guide.getUri(source), }, mt) node.lastInfer = infer @@ -298,22 +306,26 @@ function mt:view(default, uri) local max = #array local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') + local view if max > limit then - local view = string.format('%s...(+%d)' + view = string.format('%s...(+%d)' , table.concat(array, '|', 1, limit) , max - limit ) - - self.cachedView = view - - return view else - local view = table.concat(array, '|') - - self.cachedView = view + view = table.concat(array, '|') + end - return view + if self.node:isOptional() then + if max > 1 then + view = '(' .. view .. ')?' + else + view = view .. '?' + end end + self.cachedView = view + + return view end function mt:eachView() diff --git a/script/vm/node.lua b/script/vm/node.lua index 3145a34d..6ef7c9d5 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -146,43 +146,30 @@ function mt:getData(k) end function mt:addOptional() - if self:isOptional() then - return self - end self.optional = true end function mt:removeOptional() - if not self:isOptional() then - return self - end - self:_expand() - for i = #self, 1, -1 do - local n = self[i] - if n.type == 'nil' - or (n.type == 'boolean' and n[1] == false) - or (n.type == 'doc.type.boolean' and n[1] == false) then - self[i] = self[#self] - self[#self] = nil - end - end + self.optional = false end ---@return boolean function mt:isOptional() - if self.optional ~= nil then - return self.optional + return self.optional == true +end + +---@return boolean +function mt:isFalsy() + if self.optional then + return true end - self:_expand() for _, c in ipairs(self) do if c.type == 'nil' or (c.type == 'boolean' and c[1] == false) or (c.type == 'doc.type.boolean' and c[1] == false) then - self.optional = true return true end end - self.optional = false return false end @@ -196,6 +183,11 @@ function mt:eachObject() end end +---@return vm.node +function mt:copy() + return vm.createNode(self) +end + ---@param source parser.object | vm.generic ---@param node vm.node | vm.object ---@param cover? boolean diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 2d45a5a7..257166ce 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -114,7 +114,7 @@ function mt:resolve(uri, args) local function buildArgNode(argNode, knownTypes) local newArgNode = vm.createNode() for n in argNode:eachObject() do - if argNode:isOptional() and vm.isFalsy(n) then + if argNode:isFalsy() then goto CONTINUE end local view = infer.viewObject(n) diff --git a/script/vm/value.lua b/script/vm/value.lua index a784be2a..1a2b5722 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -41,28 +41,9 @@ function vm.test(source) end end ----@param source parser.object ----@return boolean -function vm.isFalsy(source) - if source.type == 'nil' then - return true - end - if source.type == 'boolean' - or source.type == 'doc.type.boolean' then - return source[1] == false - end - return false -end - ---@param v vm.object ---@return string? local function getUnique(v) - if v.type == 'local' then - return ('loc:%s@%d'):format(guide.getUri(v), v.start) - end - if v.type == 'global' then - return ('%s:%s'):format(v.cate, v.name) - end if v.type == 'boolean' then if v[1] == nil then return false diff --git a/test/hover/init.lua b/test/hover/init.lua index ee66ef2b..1a05d78e 100644 --- a/test/hover/init.lua +++ b/test/hover/init.lua @@ -772,7 +772,7 @@ local <?t?> = { ]] [[ local t: { - f: file*, + f: file*?, } ]] diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 9ead2861..ca3027fe 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1482,3 +1482,10 @@ function mt:f() print(<?self?>) end ]] + +TEST 'string?' [[ +---@return string? +local function f() end + +local <?x?> = f() +]] |