diff options
-rw-r--r-- | changelog.md | 9 | ||||
-rw-r--r-- | meta/template/table.lua | 5 | ||||
-rw-r--r-- | script/vm/compiler.lua | 33 | ||||
-rw-r--r-- | script/vm/generic.lua | 2 | ||||
-rw-r--r-- | script/vm/sign.lua | 3 | ||||
-rw-r--r-- | test/type_inference/init.lua | 23 |
6 files changed, 66 insertions, 9 deletions
diff --git a/changelog.md b/changelog.md index fe319441..afb0ba24 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,15 @@ # changelog ## 3.1.1 +* `NEW` supports infer of callback parameter + ```lua + ---@type string[] + local t + + table.sort(t, function (a, b) + -- `a` and `b` is `string` here + end) + ``` * `FIX` [#1051](https://github.com/sumneko/lua-language-server/issues/1051) * `FIX` [#1072](https://github.com/sumneko/lua-language-server/issues/1072) * `FIX` runtime errors diff --git a/meta/template/table.lua b/meta/template/table.lua index 21c8b619..02288342 100644 --- a/meta/template/table.lua +++ b/meta/template/table.lua @@ -50,8 +50,9 @@ function table.pack(...) end function table.remove(list, pos) end ---#DES 'table.sort' ----@param list table ----@param comp fun(a: any, b: any):boolean +---@generic T +---@param list T[] +---@param comp fun(a: T, b: T):boolean function table.sort(list, comp) end ---@version >5.2, JIT diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index ca01a138..eea6e093 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -289,6 +289,9 @@ local function getObjectSign(source) end source._sign = false if source.type == 'function' then + if not source.bindDocs then + return false + end for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.generic' then if not source._sign then @@ -323,11 +326,15 @@ local function getObjectSign(source) source._sign = signMgr() if source.type == 'doc.type.function' then for _, arg in ipairs(source.args) do - local argNode = vm.compileNode(arg.extends) - if arg.optional then - argNode:addOptional() + if arg.extends then + local argNode = vm.compileNode(arg.extends) + if arg.optional then + argNode:addOptional() + end + source._sign:addSign(argNode) + else + source._sign:addSign(vm.createNode()) end - source._sign:addSign(argNode) end end end @@ -673,10 +680,21 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) for n in callNode:eachObject() do if n.type == 'function' then + local sign = getObjectSign(n) local farg = getFuncArg(n, myIndex) if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then + if fn.type == 'doc.type.function' then + if sign then + local generic = genericMgr(fn, sign) + local args = {} + for i = fixIndex + 1, myIndex - 1 do + args[#args+1] = call.args[i] + end + fn = generic:resolve(guide.getUri(call), args) + end + end vm.setNode(arg, fn) end end @@ -797,7 +815,12 @@ local function compileLocalBase(source) if n.type == 'doc.type.function' then for index, arg in ipairs(n.args) do if func.args[index] == source then - vm.setNode(source, vm.compileNode(arg)) + local argNode = vm.compileNode(arg) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end + end hasDocArg = true end end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index b3981ff8..b58c7bce 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -114,7 +114,7 @@ end ---@param uri uri ---@param args parser.object ----@return parser.object +---@return vm.node function mt:resolve(uri, args) local resolved = self.sign:resolve(uri, args) local protoNode = vm.compileNode(self.proto) diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 257166ce..e997624a 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -16,8 +16,9 @@ end ---@param uri uri ---@param args parser.object +---@param removeGeneric true? ---@return table<string, vm.node> -function mt:resolve(uri, args) +function mt:resolve(uri, args, removeGeneric) if not args then return nil end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index ca3027fe..8f4b0441 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1127,6 +1127,29 @@ local t = f('') print(t.<?x?>) ]] +TEST 'string' [[ +---@generic T +---@param t T[] +---@param callback fun(v: T) +local function f(t, callback) end + +---@type string[] +local t + +f(t, function (<?v?>) end) +]] + +TEST 'unknown' [[ +---@generic T +---@param t T[] +---@param callback fun(v: T) +local function f(t, callback) end + +local t = {} + +f(t, function (<?v?>) end) +]] + TEST 'table' [[ local <?t?> = setmetatable({}, { __index = function () end }) ]] |