summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md9
-rw-r--r--meta/template/table.lua5
-rw-r--r--script/vm/compiler.lua33
-rw-r--r--script/vm/generic.lua2
-rw-r--r--script/vm/sign.lua3
-rw-r--r--test/type_inference/init.lua23
6 files changed, 66 insertions, 9 deletions
diff --git a/changelog.md b/changelog.md
index fe319441..afb0ba24 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,6 +1,15 @@
# changelog
## 3.1.1
+* `NEW` supports infer of callback parameter
+ ```lua
+ ---@type string[]
+ local t
+
+ table.sort(t, function (a, b)
+ -- `a` and `b` is `string` here
+ end)
+ ```
* `FIX` [#1051](https://github.com/sumneko/lua-language-server/issues/1051)
* `FIX` [#1072](https://github.com/sumneko/lua-language-server/issues/1072)
* `FIX` runtime errors
diff --git a/meta/template/table.lua b/meta/template/table.lua
index 21c8b619..02288342 100644
--- a/meta/template/table.lua
+++ b/meta/template/table.lua
@@ -50,8 +50,9 @@ function table.pack(...) end
function table.remove(list, pos) end
---#DES 'table.sort'
----@param list table
----@param comp fun(a: any, b: any):boolean
+---@generic T
+---@param list T[]
+---@param comp fun(a: T, b: T):boolean
function table.sort(list, comp) end
---@version >5.2, JIT
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index ca01a138..eea6e093 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -289,6 +289,9 @@ local function getObjectSign(source)
end
source._sign = false
if source.type == 'function' then
+ if not source.bindDocs then
+ return false
+ end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.generic' then
if not source._sign then
@@ -323,11 +326,15 @@ local function getObjectSign(source)
source._sign = signMgr()
if source.type == 'doc.type.function' then
for _, arg in ipairs(source.args) do
- local argNode = vm.compileNode(arg.extends)
- if arg.optional then
- argNode:addOptional()
+ if arg.extends then
+ local argNode = vm.compileNode(arg.extends)
+ if arg.optional then
+ argNode:addOptional()
+ end
+ source._sign:addSign(argNode)
+ else
+ source._sign:addSign(vm.createNode())
end
- source._sign:addSign(argNode)
end
end
end
@@ -673,10 +680,21 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
for n in callNode:eachObject() do
if n.type == 'function' then
+ local sign = getObjectSign(n)
local farg = getFuncArg(n, myIndex)
if farg then
for fn in vm.compileNode(farg):eachObject() do
if isValidCallArgNode(arg, fn) then
+ if fn.type == 'doc.type.function' then
+ if sign then
+ local generic = genericMgr(fn, sign)
+ local args = {}
+ for i = fixIndex + 1, myIndex - 1 do
+ args[#args+1] = call.args[i]
+ end
+ fn = generic:resolve(guide.getUri(call), args)
+ end
+ end
vm.setNode(arg, fn)
end
end
@@ -797,7 +815,12 @@ local function compileLocalBase(source)
if n.type == 'doc.type.function' then
for index, arg in ipairs(n.args) do
if func.args[index] == source then
- vm.setNode(source, vm.compileNode(arg))
+ local argNode = vm.compileNode(arg)
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ end
+ end
hasDocArg = true
end
end
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index b3981ff8..b58c7bce 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -114,7 +114,7 @@ end
---@param uri uri
---@param args parser.object
----@return parser.object
+---@return vm.node
function mt:resolve(uri, args)
local resolved = self.sign:resolve(uri, args)
local protoNode = vm.compileNode(self.proto)
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 257166ce..e997624a 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -16,8 +16,9 @@ end
---@param uri uri
---@param args parser.object
+---@param removeGeneric true?
---@return table<string, vm.node>
-function mt:resolve(uri, args)
+function mt:resolve(uri, args, removeGeneric)
if not args then
return nil
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index ca3027fe..8f4b0441 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1127,6 +1127,29 @@ local t = f('')
print(t.<?x?>)
]]
+TEST 'string' [[
+---@generic T
+---@param t T[]
+---@param callback fun(v: T)
+local function f(t, callback) end
+
+---@type string[]
+local t
+
+f(t, function (<?v?>) end)
+]]
+
+TEST 'unknown' [[
+---@generic T
+---@param t T[]
+---@param callback fun(v: T)
+local function f(t, callback) end
+
+local t = {}
+
+f(t, function (<?v?>) end)
+]]
+
TEST 'table' [[
local <?t?> = setmetatable({}, { __index = function () end })
]]