summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/hover/return.lua10
-rw-r--r--script/vm/compiler.lua9
-rw-r--r--script/vm/infer.lua48
-rw-r--r--script/vm/node.lua34
-rw-r--r--script/vm/sign.lua2
-rw-r--r--script/vm/value.lua19
-rw-r--r--test/hover/init.lua2
-rw-r--r--test/type_inference/init.lua7
8 files changed, 64 insertions, 67 deletions
diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua
index 77710148..93ff1f6c 100644
--- a/script/core/hover/return.lua
+++ b/script/core/hover/return.lua
@@ -65,10 +65,9 @@ local function asFunction(source)
local rtn = vm.getReturnOfFunction(source, i)
local doc = docs[i]
local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ')
- local text = ('%s%s%s'):format(
+ local text = ('%s%s'):format(
name or '',
- infer.getInfer(rtn):view(),
- doc and doc.optional and '?' or ''
+ infer.getInfer(rtn):view()
)
if i == 1 then
returns[i] = (' -> %s'):format(text)
@@ -86,10 +85,7 @@ local function asDocFunction(source)
end
local returns = {}
for i, rtn in ipairs(source.returns) do
- local rtnText = ('%s%s'):format(
- infer.getInfer(rtn):view(),
- rtn.optional and '?' or ''
- )
+ local rtnText = infer.getInfer(rtn):view()
if i == 1 then
returns[#returns+1] = (' -> %s'):format(rtnText)
else
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 8126f393..8df28e7b 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -455,6 +455,9 @@ local function getReturn(func, index, args)
result:merge(rnode)
end
end
+ if result and returnNode:isOptional() then
+ result:addOptional()
+ end
end
end
end
@@ -1115,6 +1118,9 @@ local compilerSwitch = util.switch()
for _, typeUnit in ipairs(source.types) do
vm.setNode(source, vm.compileNode(typeUnit))
end
+ if source.optional then
+ vm.getNode(source):addOptional()
+ end
end)
: case 'doc.type.integer'
: case 'doc.type.string'
@@ -1220,6 +1226,9 @@ local compilerSwitch = util.switch()
else
vm.setNode(source, globalMgr.getGlobal('type', 'any'))
end
+ if source.optional then
+ vm.getNode(source):addOptional()
+ end
end)
: case 'generic'
: call(function (source)
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 2a64ed52..7bb581cf 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -144,20 +144,23 @@ local viewNodeSwitch = util.switch()
local argView = ''
local regView = ''
for i, arg in ipairs(source.args) do
+ local argNode = vm.compileNode(arg)
+ local isOptional = argNode:isOptional()
+ if isOptional then
+ argNode = argNode:copy()
+ argNode:removeOptional()
+ end
args[i] = string.format('%s%s: %s'
, arg.name[1]
- , arg.optional and '?' or ''
- , m.getInfer(arg):view()
+ , isOptional and '?' or ''
+ , m.getInfer(argNode):view()
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
for i, ret in ipairs(source.returns) do
- rets[i] = string.format('%s%s'
- , m.getInfer(ret):view()
- , ret.optional and '?' or ''
- )
+ rets[i] = m.getInfer(ret):view()
end
if #rets > 0 then
regView = ':' .. table.concat(rets, ', ')
@@ -165,16 +168,21 @@ local viewNodeSwitch = util.switch()
return ('fun(%s)%s'):format(argView, regView)
end)
----@param source parser.object
+---@param source parser.object | vm.node
---@return vm.infer
function m.getInfer(source)
- local node = vm.compileNode(source)
+ local node
+ if source.type == 'vm.node' then
+ node = source
+ else
+ node = vm.compileNode(source)
+ end
if node.lastInfer then
return node.lastInfer
end
local infer = setmetatable({
node = node,
- uri = guide.getUri(source),
+ uri = source.type ~= 'vm.node' and guide.getUri(source),
}, mt)
node.lastInfer = infer
@@ -298,22 +306,26 @@ function mt:view(default, uri)
local max = #array
local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit')
+ local view
if max > limit then
- local view = string.format('%s...(+%d)'
+ view = string.format('%s...(+%d)'
, table.concat(array, '|', 1, limit)
, max - limit
)
-
- self.cachedView = view
-
- return view
else
- local view = table.concat(array, '|')
-
- self.cachedView = view
+ view = table.concat(array, '|')
+ end
- return view
+ if self.node:isOptional() then
+ if max > 1 then
+ view = '(' .. view .. ')?'
+ else
+ view = view .. '?'
+ end
end
+ self.cachedView = view
+
+ return view
end
function mt:eachView()
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 3145a34d..6ef7c9d5 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -146,43 +146,30 @@ function mt:getData(k)
end
function mt:addOptional()
- if self:isOptional() then
- return self
- end
self.optional = true
end
function mt:removeOptional()
- if not self:isOptional() then
- return self
- end
- self:_expand()
- for i = #self, 1, -1 do
- local n = self[i]
- if n.type == 'nil'
- or (n.type == 'boolean' and n[1] == false)
- or (n.type == 'doc.type.boolean' and n[1] == false) then
- self[i] = self[#self]
- self[#self] = nil
- end
- end
+ self.optional = false
end
---@return boolean
function mt:isOptional()
- if self.optional ~= nil then
- return self.optional
+ return self.optional == true
+end
+
+---@return boolean
+function mt:isFalsy()
+ if self.optional then
+ return true
end
- self:_expand()
for _, c in ipairs(self) do
if c.type == 'nil'
or (c.type == 'boolean' and c[1] == false)
or (c.type == 'doc.type.boolean' and c[1] == false) then
- self.optional = true
return true
end
end
- self.optional = false
return false
end
@@ -196,6 +183,11 @@ function mt:eachObject()
end
end
+---@return vm.node
+function mt:copy()
+ return vm.createNode(self)
+end
+
---@param source parser.object | vm.generic
---@param node vm.node | vm.object
---@param cover? boolean
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 2d45a5a7..257166ce 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -114,7 +114,7 @@ function mt:resolve(uri, args)
local function buildArgNode(argNode, knownTypes)
local newArgNode = vm.createNode()
for n in argNode:eachObject() do
- if argNode:isOptional() and vm.isFalsy(n) then
+ if argNode:isFalsy() then
goto CONTINUE
end
local view = infer.viewObject(n)
diff --git a/script/vm/value.lua b/script/vm/value.lua
index a784be2a..1a2b5722 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -41,28 +41,9 @@ function vm.test(source)
end
end
----@param source parser.object
----@return boolean
-function vm.isFalsy(source)
- if source.type == 'nil' then
- return true
- end
- if source.type == 'boolean'
- or source.type == 'doc.type.boolean' then
- return source[1] == false
- end
- return false
-end
-
---@param v vm.object
---@return string?
local function getUnique(v)
- if v.type == 'local' then
- return ('loc:%s@%d'):format(guide.getUri(v), v.start)
- end
- if v.type == 'global' then
- return ('%s:%s'):format(v.cate, v.name)
- end
if v.type == 'boolean' then
if v[1] == nil then
return false
diff --git a/test/hover/init.lua b/test/hover/init.lua
index ee66ef2b..1a05d78e 100644
--- a/test/hover/init.lua
+++ b/test/hover/init.lua
@@ -772,7 +772,7 @@ local <?t?> = {
]]
[[
local t: {
- f: file*,
+ f: file*?,
}
]]
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 9ead2861..ca3027fe 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1482,3 +1482,10 @@ function mt:f()
print(<?self?>)
end
]]
+
+TEST 'string?' [[
+---@return string?
+local function f() end
+
+local <?x?> = f()
+]]