summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2024-09-06 17:09:33 +0800
committer最萌小汐 <sumneko@hotmail.com>2024-09-06 17:09:33 +0800
commit7c481f57c407f7e4d0b359a3cfbce66add99ec2f (patch)
treed9e9a8ca7aba039f9324a21eaea8a584b590cb20
parent1ea4c04bdc392db66a36f521a5c256dbd837583b (diff)
downloadlua-language-server-7c481f57c407f7e4d0b359a3cfbce66add99ec2f.zip
Infer the parameter types of a same-named function in the subclass based on the parameter types in the superclass function.
-rw-r--r--changelog.md1
-rw-r--r--script/vm/compiler.lua70
-rw-r--r--test/type_inference/common.lua13
3 files changed, 67 insertions, 17 deletions
diff --git a/changelog.md b/changelog.md
index 3efffd98..38d9aeb7 100644
--- a/changelog.md
+++ b/changelog.md
@@ -4,6 +4,7 @@
<!-- Add all new changes here. They will be moved under a version at release -->
* `NEW` Custom documentation exporter
* `NEW` Setting: `Lua.docScriptPath`: Path to a script that overrides `cli.doc.export`, allowing user-specified documentation exporting.
+* `NEW` Infer the parameter types of a same-named function in the subclass based on the parameter types in the superclass function.
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
* `FIX` Improve type narrow by checking exact match on literal type params
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 86a5b2d0..9632ebbf 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1088,21 +1088,26 @@ end
---@param func parser.object
---@param source parser.object
local function compileFunctionParam(func, source)
+ local aindex
+ for index, arg in ipairs(func.args) do
+ if arg == source then
+ aindex = index
+ break
+ end
+ end
+ ---@cast aindex integer
+
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' then
- for index, arg in ipairs(n.args) do
- if func.args[index] == source then
- local argNode = vm.compileNode(arg)
- for an in argNode:eachObject() do
- if an.type ~= 'doc.generic.name' then
- vm.setNode(source, an)
- end
- end
- return true
+ local argNode = vm.compileNode(n.args[aindex])
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
end
end
+ return true
end
end
@@ -1118,19 +1123,50 @@ local function compileFunctionParam(func, source)
if not caller.args then
goto continue
end
- for index, arg in ipairs(source.parent) do
- if arg == source then
- local callerArg = caller.args[index]
- if callerArg then
- vm.setNode(source, vm.compileNode(callerArg))
- found = true
- end
- end
+ local callerArg = caller.args[aindex]
+ if callerArg then
+ vm.setNode(source, vm.compileNode(callerArg))
+ found = true
end
::continue::
end
return found
end
+
+ do
+ local parent = func.parent
+ local key = vm.getKeyName(parent)
+ local classDef = vm.getParentClass(parent)
+ local suri = guide.getUri(func)
+ if classDef and key then
+ local found
+ for _, set in ipairs(classDef:getSets(suri)) do
+ if set.type == 'doc.class' and set.extends then
+ for _, ext in ipairs(set.extends) do
+ local extClass = vm.getGlobal('type', ext[1])
+ if extClass then
+ vm.getClassFields(suri, extClass, key, function (field, isMark)
+ for n in vm.compileNode(field):eachObject() do
+ if n.type == 'function' then
+ local argNode = vm.compileNode(n.args[aindex])
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ found = true
+ end
+ end
+ end
+ end
+ end)
+ end
+ end
+ end
+ end
+ if found then
+ return true
+ end
+ end
+ end
end
---@param source parser.object
diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua
index 4a29454a..11fa39b8 100644
--- a/test/type_inference/common.lua
+++ b/test/type_inference/common.lua
@@ -4428,3 +4428,16 @@ TEST 'A' [[
local x
local <?y?> = 1 >> x
]]
+
+TEST 'number' [[
+---@class A
+local A = {}
+
+---@param x number
+function A:func(x) end
+
+---@class B: A
+local B = {}
+
+function B:func(<?x?>) end
+]]