summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsumneko <sumneko@hotmail.com>2022-03-11 04:04:08 +0800
committersumneko <sumneko@hotmail.com>2022-03-11 04:04:08 +0800
commit36976b3df2d19d6b2593a25527516218049d630a (patch)
tree154152c906b96eecbe9237bc0504115d96264693
parentbb02090e9f63355cda301a66571da9d2cd141334 (diff)
downloadlua-language-server-36976b3df2d19d6b2593a25527516218049d630a.zip
update
-rw-r--r--script/vm/compiler.lua12
-rw-r--r--script/vm/generic.lua2
-rw-r--r--script/vm/node.lua13
-rw-r--r--script/vm/sign.lua6
-rw-r--r--script/vm/union.lua95
-rw-r--r--test/type_inference/init.lua93
6 files changed, 178 insertions, 43 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index eda85f83..6b73b197 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -167,7 +167,11 @@ local function getObjectSign(source)
end
if source.args then
for _, arg in ipairs(source.args) do
- source._sign:addSign(m.compileNode(arg))
+ local argNode = m.compileNode(arg)
+ if arg.optional then
+ argNode = nodeMgr.addOptional(argNode)
+ end
+ source._sign:addSign(argNode)
end
end
end
@@ -183,7 +187,11 @@ local function getObjectSign(source)
source._sign = signMgr()
if source.type == 'doc.type.function' then
for _, arg in ipairs(source.args) do
- source._sign:addSign(m.compileNode(arg.extends))
+ local argNode = m.compileNode(arg.extends)
+ if arg.optional then
+ argNode = nodeMgr.addOptional(argNode)
+ end
+ source._sign:addSign(argNode)
end
end
end
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index aa1b76af..28caa2db 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -95,11 +95,13 @@ local function cloneObject(node, resolved)
for i, arg in ipairs(node.args) do
local newObj = cloneObject(arg, resolved)
newObj.parent = newDocFunc
+ newObj.optional = arg.optional
newDocFunc.args[i] = newObj
end
for i, ret in ipairs(node.returns) do
local newObj = cloneObject(ret, resolved)
newObj.parent = newDocFunc
+ newObj.optional = ret.optional
newDocFunc.returns[i] = cloneObject(ret, resolved)
end
return newDocFunc
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 8766d347..3368ab05 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -40,10 +40,21 @@ end
---@param node vm.node
---@return vm.node.union
-function m.setFalsy(node)
+function m.addOptional(node)
if node.type ~= 'union' then
node = union(node)
end
+ node = node:addOptional()
+ return node
+end
+
+---@param node vm.node
+---@return vm.node.union
+function m.removeOptional(node)
+ if node.type ~= 'union' then
+ node = union(node)
+ end
+ node = node:removeOptional()
return node
end
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 74b59347..059d82bb 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -81,7 +81,11 @@ function mt:resolve(argNodes)
break
end
for n in nodeMgr.eachNode(sign) do
- resolve(n, compiler.compileNode(node))
+ node = compiler.compileNode(node)
+ if sign.optional then
+ node = nodeMgr.removeOptional(node)
+ end
+ resolve(n, node)
end
end
diff --git a/script/vm/union.lua b/script/vm/union.lua
index 9f0cb767..183f3440 100644
--- a/script/vm/union.lua
+++ b/script/vm/union.lua
@@ -4,9 +4,19 @@ local localMgr = require 'vm.local-manager'
local mt = {}
mt.__index = mt
mt.type = 'union'
-mt.falsy = nil
+mt.optional = nil
mt.lastViews = nil
+---@param me parser.object
+---@param node vm.node
+---@return vm.node.union
+local function createUnion(me, node)
+ local union = setmetatable({}, mt)
+ union:merge(me)
+ union:merge(node)
+ return union
+end
+
---@param node vm.node
function mt:merge(node)
if not node then
@@ -19,10 +29,8 @@ function mt:merge(node)
self[#self+1] = c
end
end
- if node:isFalsy() then
- self:setFalsy()
- else
- self.falsy = nil
+ if node:isOptional() then
+ self.optional = true
end
else
if not self[node] then
@@ -47,48 +55,65 @@ function mt:eachNode()
end
end
-function mt:setFalsy()
- self.falsy = true
+---@return vm.node.union
+function mt:addOptional()
+ if self:isOptional() then
+ return self
+ end
+ self.optional = true
+ return self
end
-function mt:setTruthy()
- self.falsy = false
+---@return vm.node.union
+function mt:removeOptional()
+ self.optional = nil
+ if not self:isOptional() then
+ return self
+ end
+ -- copy union
+ local newUnion = createUnion()
+ for _, n in ipairs(self) do
+ if n.type == 'nil' then
+ goto CONTINUE
+ end
+ if n.type == 'boolean' then
+ if n[1] == false then
+ goto CONTINUE
+ end
+ end
+ if n.type == 'false' then
+ goto CONTINUE
+ end
+ newUnion[#newUnion+1] = n
+ ::CONTINUE::
+ end
+ newUnion.optional = false
+ return newUnion
end
-function mt:checkFalsy()
- if self.falsy ~= nil then
- return
+---@return boolean
+function mt:isOptional()
+ if self.optional ~= nil then
+ return self.optional
end
for _, c in ipairs(self) do
if c.type == 'nil' then
- self:setFalsy()
- return
+ self.optional = true
+ return true
end
if c.type == 'boolean' then
if c[1] == false then
- self:setFalsy()
- return
+ self.optional = true
+ return true
end
end
+ if c.type == 'false' then
+ self.optional = true
+ return true
+ end
end
+ self.optional = false
+ return false
end
-function mt:isFalsy()
- self:checkFalsy()
- return self.falsy == true
-end
-
-function mt:isTruthy()
- self:checkFalsy()
- return self.falsy == false
-end
-
----@param me parser.object
----@param node vm.node
----@return vm.node.union
-return function (me, node)
- local union = setmetatable({}, mt)
- union:merge(me)
- union:merge(node)
- return union
-end
+return createUnion
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 7721e776..46fd2cad 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -540,6 +540,22 @@ print(t.<?a?>)
]]
TEST 'integer' [[
+---@generic K
+---@type fun(a?: K):K
+local f
+
+local <?n?> = f(1)
+]]
+
+TEST 'unknown' [[
+---@generic K
+---@type fun(a?: K):K
+local f
+
+local <?n?> = f(nil)
+]]
+
+TEST 'integer' [[
---@class integer
---@generic T: table, V
@@ -685,22 +701,91 @@ local f2 = f(1)
local i, <?v?> = f2(true)
]]
+TEST 'fun(table: table<<K>, <V>>, index?: <K>):<K>, <V>' [[
+---@generic T: table, K, V
+---@param t T
+---@return fun(table: table<K, V>, index?: K):K, V
+---@return T
+---@return nil
+local function pairs(t) end
+
+local <?next?> = pairs(dummy)
+]]
+
TEST 'string' [[
----@class string
+---@generic T: table, K, V
+---@param t T
+---@return fun(table: table<K, V>, index?: K):K, V
+---@return T
+---@return nil
+local function pairs(t) end
+local next = pairs(dummy)
+
+---@type table<string, boolean>
+local t
+local <?k?>, v = next(t)
+]]
+
+TEST 'boolean' [[
---@generic T: table, K, V
---@param t T
----@return fun(table: table<K, V>, index: K):K, V
+---@return fun(table: table<K, V>, index?: K):K, V
---@return T
---@return nil
local function pairs(t) end
-local f = pairs(t)
+local next = pairs(dummy)
+
+---@type table<string, boolean>
+local t
+local k, <?v?> = next(t)
+]]
+
+TEST 'string' [[
+---@generic T: table, K, V
+---@param t T
+---@return fun(table: table<K, V>, index?: K):K, V
+---@return T
+---@return nil
+local function pairs(t) end
+
+local next = pairs(dummy)
+
+---@type table<string, boolean>
+local t
+local <?k?>, v = next(t, nil)
+]]
+
+TEST 'boolean' [[
+---@generic T: table, K, V
+---@param t T
+---@return fun(table: table<K, V>, index?: K):K, V
+---@return T
+---@return nil
+local function pairs(t) end
+
+local next = pairs(dummy)
+
+---@type table<string, boolean>
+local t
+local k, <?v?> = next(t, nil)
+]]
+
+TEST 'string' [[
+---@generic T: table, K, V
+---@param t T
+---@return fun(table: table<K, V>, index?: K):K, V
+---@return T
+---@return nil
+local function pairs(t) end
+
+local next = pairs(dummy)
---@type table<string, boolean>
local t
-for <?k?>, v in f, t do
+for <?k?>, v in next, t do
end
]]