summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md13
-rw-r--r--script/vm/sign.lua64
-rw-r--r--test/type_inference/init.lua55
3 files changed, 127 insertions, 5 deletions
diff --git a/changelog.md b/changelog.md
index 2329cdca..2ce0f9c3 100644
--- a/changelog.md
+++ b/changelog.md
@@ -46,6 +46,18 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
print(obj.initValue) --> `obj.initValue` is integer
```
+* `CHG` [#1153] infer type by generic parameters or returns of function
+ ```lua
+ ---@generic T
+ ---@param f fun(x: T)
+ ---@return T[]
+ local function x(f) end
+
+ ---@type fun(x: integer)
+ local cb
+
+ local arr = x(cb) --> `arr` is inferred as `integer[]`
+ ```
* `FIX` [#1567]
* `FIX` [#1593]
* `FIX` [#1595]
@@ -56,6 +68,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
* `FIX` [#1640]
* `FIX` [#1642]
+[#1153]: https://github.com/sumneko/lua-language-server/issues/1153
[#1177]: https://github.com/sumneko/lua-language-server/issues/1177
[#1458]: https://github.com/sumneko/lua-language-server/issues/1458
[#1557]: https://github.com/sumneko/lua-language-server/issues/1557
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 7c95fd08..21044a28 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -22,12 +22,26 @@ function mt:resolve(uri, args, removeGeneric)
if not args then
return nil
end
+
+ ---@type table<string, vm.node>
local resolved = {}
- ---@param object vm.node.object
+ ---@param object vm.node|vm.node.object
---@param node vm.node
local function resolve(object, node)
+ if object.type == 'vm.node' then
+ for o in object:eachObject() do
+ resolve(o, node)
+ end
+ return
+ end
+ if object.type == 'doc.type' then
+ ---@cast object parser.object
+ resolve(vm.compileNode(object), node)
+ return
+ end
if object.type == 'doc.generic.name' then
+ ---@type string
local key = object[1]
if object.literal then
-- 'number' -> `T`
@@ -40,8 +54,21 @@ function mt:resolve(uri, args, removeGeneric)
end
else
-- number -> T
- resolved[key] = vm.createNode(node, resolved[key])
+ for n in node:eachObject() do
+ if n.type ~= 'doc.generic.name'
+ and n.type ~= 'generic' then
+ if resolved[key] then
+ resolved[key]:merge(n)
+ else
+ resolved[key] = vm.createNode(n)
+ end
+ end
+ end
+ if resolved[key] and node:isOptional() then
+ resolved[key]:addOptional()
+ end
end
+ return
end
if object.type == 'doc.type.array' then
for n in node:eachObject() do
@@ -68,6 +95,7 @@ function mt:resolve(uri, args, removeGeneric)
resolve(object.node, vm.compileNode(n[1]))
end
end
+ return
end
if object.type == 'doc.type.table' then
for _, ufield in ipairs(object.fields) do
@@ -105,6 +133,34 @@ function mt:resolve(uri, args, removeGeneric)
end
::CONTINUE::
end
+ return
+ end
+ if object.type == 'doc.type.function' then
+ for i, arg in ipairs(object.args) do
+ for n in node:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ ---@cast n parser.object
+ local farg = n.args and n.args[i]
+ if farg then
+ resolve(arg.extends, vm.compileNode(farg))
+ end
+ end
+ end
+ end
+ for i, ret in ipairs(object.returns) do
+ for n in node:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ ---@cast n parser.object
+ local fret = vm.getReturnOfFunction(n, i)
+ if fret then
+ resolve(ret, vm.compileNode(fret))
+ end
+ end
+ end
+ end
+ return
end
end
@@ -190,9 +246,7 @@ function mt:resolve(uri, args, removeGeneric)
local knownTypes, genericNames = getSignInfo(sign)
if not isAllResolved(genericNames) then
local newArgNode = buildArgNode(argNode,sign, knownTypes)
- for n in sign:eachObject() do
- resolve(n, newArgNode)
- end
+ resolve(sign, newArgNode)
end
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 43177dc7..de42a91d 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -3803,3 +3803,58 @@ local class
class.has.nested.<?fn?>()
]]
+
+TEST 'integer[]' [[
+---@generic T
+---@param f fun(x: T)
+---@return T[]
+local function x(f) end
+
+---@param x integer
+local <?arr?> = x(function (x) end)
+]]
+
+TEST 'integer[]' [[
+---@generic T
+---@param f fun():T
+---@return T[]
+local function x(f) end
+
+local <?arr?> = x(function ()
+ return 1
+end)
+]]
+
+TEST 'integer[]' [[
+---@generic T
+---@param f fun():T
+---@return T[]
+local function x(f) end
+
+---@return integer
+local <?arr?> = x(function () end)
+]]
+
+TEST 'integer[]' [[
+---@generic T
+---@param f fun(x: T)
+---@return T[]
+local function x(f) end
+
+---@type fun(x: integer)
+local cb
+
+local <?arr?> = x(cb)
+]]
+
+TEST 'integer[]' [[
+---@generic T
+---@param f fun():T
+---@return T[]
+local function x(f) end
+
+---@type fun(): integer
+local cb
+
+local <?arr?> = x(cb)
+]]