summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md11
-rw-r--r--script/vm/compiler.lua17
-rw-r--r--test/type_inference/init.lua8
3 files changed, 35 insertions, 1 deletions
diff --git a/changelog.md b/changelog.md
index 2ce0f9c3..2e247a54 100644
--- a/changelog.md
+++ b/changelog.md
@@ -28,7 +28,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
---@type myClass
local class
- print(class.a.b.c.e.f.g) --> infered as integer
+ print(class.a.b.c.e.f.g) --> inferred as integer
```
* `CHG` [#1582] the following diagnostics consider `overload`
* `missing-return`
@@ -58,6 +58,14 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
local arr = x(cb) --> `arr` is inferred as `integer[]`
```
+* `CHG` [#1202] infer parameter type by expected returned function of parent function
+ ```lua
+ ---@return fun(x: integer)
+ local function f()
+ return function (x) --> `x` is inferred as `integer`
+ end
+ end
+ ```
* `FIX` [#1567]
* `FIX` [#1593]
* `FIX` [#1595]
@@ -70,6 +78,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
[#1153]: https://github.com/sumneko/lua-language-server/issues/1153
[#1177]: https://github.com/sumneko/lua-language-server/issues/1177
+[#1202]: https://github.com/sumneko/lua-language-server/issues/1202
[#1458]: https://github.com/sumneko/lua-language-server/issues/1458
[#1557]: https://github.com/sumneko/lua-language-server/issues/1557
[#1558]: https://github.com/sumneko/lua-language-server/issues/1558
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index ad1379fc..599b2c4a 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1059,6 +1059,7 @@ local function compileLocal(source)
end
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
+ -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
local hasDocArg
for n in funcNode:eachObject() do
@@ -1158,6 +1159,22 @@ local compilerSwitch = util.switch()
local call = source.parent.parent
vm.compileCallArg(source, call)
end
+
+ -- function f() return function (<?x?>) end end
+ if source.parent.type == 'return' then
+ for i, ret in ipairs(source.parent) do
+ if ret == source then
+ local func = guide.getParentFunction(source.parent)
+ if func then
+ local returnObj = vm.getReturnOfFunction(func, i)
+ if returnObj then
+ vm.setNode(source, vm.compileNode(returnObj))
+ end
+ end
+ break
+ end
+ end
+ end
end)
: case 'paren'
: call(function (source)
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index de42a91d..33704c1a 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -3858,3 +3858,11 @@ local cb
local <?arr?> = x(cb)
]]
+
+TEST 'integer' [[
+---@return fun(x: integer)
+local function f()
+ return function (<?x?>)
+ end
+end
+]]