diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2024-09-06 17:09:33 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2024-09-06 17:09:33 +0800 |
commit | 7c481f57c407f7e4d0b359a3cfbce66add99ec2f (patch) | |
tree | d9e9a8ca7aba039f9324a21eaea8a584b590cb20 | |
parent | 1ea4c04bdc392db66a36f521a5c256dbd837583b (diff) | |
download | lua-language-server-7c481f57c407f7e4d0b359a3cfbce66add99ec2f.zip |
Infer the parameter types of a same-named function in the subclass based on the parameter types in the superclass function.
-rw-r--r-- | changelog.md | 1 | ||||
-rw-r--r-- | script/vm/compiler.lua | 70 | ||||
-rw-r--r-- | test/type_inference/common.lua | 13 |
3 files changed, 67 insertions, 17 deletions
diff --git a/changelog.md b/changelog.md index 3efffd98..38d9aeb7 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ <!-- Add all new changes here. They will be moved under a version at release --> * `NEW` Custom documentation exporter * `NEW` Setting: `Lua.docScriptPath`: Path to a script that overrides `cli.doc.export`, allowing user-specified documentation exporting. +* `NEW` Infer the parameter types of a same-named function in the subclass based on the parameter types in the superclass function. * `FIX` Fix `VM.OnCompileFunctionParam` function in plugins * `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter * `FIX` Improve type narrow by checking exact match on literal type params diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 86a5b2d0..9632ebbf 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1088,21 +1088,26 @@ end ---@param func parser.object ---@param source parser.object local function compileFunctionParam(func, source) + local aindex + for index, arg in ipairs(func.args) do + if arg == source then + aindex = index + break + end + end + ---@cast aindex integer + -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number local funcNode = vm.compileNode(func) for n in funcNode:eachObject() do if n.type == 'doc.type.function' then - for index, arg in ipairs(n.args) do - if func.args[index] == source then - local argNode = vm.compileNode(arg) - for an in argNode:eachObject() do - if an.type ~= 'doc.generic.name' then - vm.setNode(source, an) - end - end - return true + local argNode = vm.compileNode(n.args[aindex]) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) end end + return true end end @@ -1118,19 +1123,50 @@ local function compileFunctionParam(func, source) if not caller.args then goto continue end - for index, arg in ipairs(source.parent) do - if arg == source then - local callerArg = caller.args[index] - if callerArg then - vm.setNode(source, vm.compileNode(callerArg)) - found = true - end - end + local callerArg = caller.args[aindex] + if callerArg then + vm.setNode(source, vm.compileNode(callerArg)) + found = true end ::continue:: end return found end + + do + local parent = func.parent + local key = vm.getKeyName(parent) + local classDef = vm.getParentClass(parent) + local suri = guide.getUri(func) + if classDef and key then + local found + for _, set in ipairs(classDef:getSets(suri)) do + if set.type == 'doc.class' and set.extends then + for _, ext in ipairs(set.extends) do + local extClass = vm.getGlobal('type', ext[1]) + if extClass then + vm.getClassFields(suri, extClass, key, function (field, isMark) + for n in vm.compileNode(field):eachObject() do + if n.type == 'function' then + local argNode = vm.compileNode(n.args[aindex]) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + found = true + end + end + end + end + end) + end + end + end + end + if found then + return true + end + end + end end ---@param source parser.object diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 4a29454a..11fa39b8 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4428,3 +4428,16 @@ TEST 'A' [[ local x local <?y?> = 1 >> x ]] + +TEST 'number' [[ +---@class A +local A = {} + +---@param x number +function A:func(x) end + +---@class B: A +local B = {} + +function B:func(<?x?>) end +]] |