diff options
-rw-r--r-- | changelog.md | 13 | ||||
-rw-r--r-- | script/vm/sign.lua | 64 | ||||
-rw-r--r-- | test/type_inference/init.lua | 55 |
3 files changed, 127 insertions, 5 deletions
diff --git a/changelog.md b/changelog.md index 2329cdca..2ce0f9c3 100644 --- a/changelog.md +++ b/changelog.md @@ -46,6 +46,18 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`. print(obj.initValue) --> `obj.initValue` is integer ``` +* `CHG` [#1153] infer type by generic parameters or returns of function + ```lua + ---@generic T + ---@param f fun(x: T) + ---@return T[] + local function x(f) end + + ---@type fun(x: integer) + local cb + + local arr = x(cb) --> `arr` is inferred as `integer[]` + ``` * `FIX` [#1567] * `FIX` [#1593] * `FIX` [#1595] @@ -56,6 +68,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`. * `FIX` [#1640] * `FIX` [#1642] +[#1153]: https://github.com/sumneko/lua-language-server/issues/1153 [#1177]: https://github.com/sumneko/lua-language-server/issues/1177 [#1458]: https://github.com/sumneko/lua-language-server/issues/1458 [#1557]: https://github.com/sumneko/lua-language-server/issues/1557 diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 7c95fd08..21044a28 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -22,12 +22,26 @@ function mt:resolve(uri, args, removeGeneric) if not args then return nil end + + ---@type table<string, vm.node> local resolved = {} - ---@param object vm.node.object + ---@param object vm.node|vm.node.object ---@param node vm.node local function resolve(object, node) + if object.type == 'vm.node' then + for o in object:eachObject() do + resolve(o, node) + end + return + end + if object.type == 'doc.type' then + ---@cast object parser.object + resolve(vm.compileNode(object), node) + return + end if object.type == 'doc.generic.name' then + ---@type string local key = object[1] if object.literal then -- 'number' -> `T` @@ -40,8 +54,21 @@ function mt:resolve(uri, args, removeGeneric) end else -- number -> T - resolved[key] = vm.createNode(node, resolved[key]) + for n in node:eachObject() do + if n.type ~= 'doc.generic.name' + and n.type ~= 'generic' then + if resolved[key] then + resolved[key]:merge(n) + else + resolved[key] = vm.createNode(n) + end + end + end + if resolved[key] and node:isOptional() then + resolved[key]:addOptional() + end end + return end if object.type == 'doc.type.array' then for n in node:eachObject() do @@ -68,6 +95,7 @@ function mt:resolve(uri, args, removeGeneric) resolve(object.node, vm.compileNode(n[1])) end end + return end if object.type == 'doc.type.table' then for _, ufield in ipairs(object.fields) do @@ -105,6 +133,34 @@ function mt:resolve(uri, args, removeGeneric) end ::CONTINUE:: end + return + end + if object.type == 'doc.type.function' then + for i, arg in ipairs(object.args) do + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local farg = n.args and n.args[i] + if farg then + resolve(arg.extends, vm.compileNode(farg)) + end + end + end + end + for i, ret in ipairs(object.returns) do + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local fret = vm.getReturnOfFunction(n, i) + if fret then + resolve(ret, vm.compileNode(fret)) + end + end + end + end + return end end @@ -190,9 +246,7 @@ function mt:resolve(uri, args, removeGeneric) local knownTypes, genericNames = getSignInfo(sign) if not isAllResolved(genericNames) then local newArgNode = buildArgNode(argNode,sign, knownTypes) - for n in sign:eachObject() do - resolve(n, newArgNode) - end + resolve(sign, newArgNode) end end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 43177dc7..de42a91d 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -3803,3 +3803,58 @@ local class class.has.nested.<?fn?>() ]] + +TEST 'integer[]' [[ +---@generic T +---@param f fun(x: T) +---@return T[] +local function x(f) end + +---@param x integer +local <?arr?> = x(function (x) end) +]] + +TEST 'integer[]' [[ +---@generic T +---@param f fun():T +---@return T[] +local function x(f) end + +local <?arr?> = x(function () + return 1 +end) +]] + +TEST 'integer[]' [[ +---@generic T +---@param f fun():T +---@return T[] +local function x(f) end + +---@return integer +local <?arr?> = x(function () end) +]] + +TEST 'integer[]' [[ +---@generic T +---@param f fun(x: T) +---@return T[] +local function x(f) end + +---@type fun(x: integer) +local cb + +local <?arr?> = x(cb) +]] + +TEST 'integer[]' [[ +---@generic T +---@param f fun():T +---@return T[] +local function x(f) end + +---@type fun(): integer +local cb + +local <?arr?> = x(cb) +]] |