diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2022-10-23 02:21:17 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2022-10-23 02:21:17 +0800 |
commit | 96b320247788163f7d014404b7b1f089aed82dff (patch) | |
tree | 888558c3cfdc44e427e0f558d071897e38368f18 /script | |
parent | b6e7d271fcb9c6d807a196d333cd27b920d71f9e (diff) | |
download | lua-language-server-96b320247788163f7d014404b7b1f089aed82dff.zip |
infer type by function as parameters
resolve #1153
Diffstat (limited to 'script')
-rw-r--r-- | script/vm/sign.lua | 64 |
1 files changed, 59 insertions, 5 deletions
diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 7c95fd08..21044a28 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -22,12 +22,26 @@ function mt:resolve(uri, args, removeGeneric) if not args then return nil end + + ---@type table<string, vm.node> local resolved = {} - ---@param object vm.node.object + ---@param object vm.node|vm.node.object ---@param node vm.node local function resolve(object, node) + if object.type == 'vm.node' then + for o in object:eachObject() do + resolve(o, node) + end + return + end + if object.type == 'doc.type' then + ---@cast object parser.object + resolve(vm.compileNode(object), node) + return + end if object.type == 'doc.generic.name' then + ---@type string local key = object[1] if object.literal then -- 'number' -> `T` @@ -40,8 +54,21 @@ function mt:resolve(uri, args, removeGeneric) end else -- number -> T - resolved[key] = vm.createNode(node, resolved[key]) + for n in node:eachObject() do + if n.type ~= 'doc.generic.name' + and n.type ~= 'generic' then + if resolved[key] then + resolved[key]:merge(n) + else + resolved[key] = vm.createNode(n) + end + end + end + if resolved[key] and node:isOptional() then + resolved[key]:addOptional() + end end + return end if object.type == 'doc.type.array' then for n in node:eachObject() do @@ -68,6 +95,7 @@ function mt:resolve(uri, args, removeGeneric) resolve(object.node, vm.compileNode(n[1])) end end + return end if object.type == 'doc.type.table' then for _, ufield in ipairs(object.fields) do @@ -105,6 +133,34 @@ function mt:resolve(uri, args, removeGeneric) end ::CONTINUE:: end + return + end + if object.type == 'doc.type.function' then + for i, arg in ipairs(object.args) do + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local farg = n.args and n.args[i] + if farg then + resolve(arg.extends, vm.compileNode(farg)) + end + end + end + end + for i, ret in ipairs(object.returns) do + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local fret = vm.getReturnOfFunction(n, i) + if fret then + resolve(ret, vm.compileNode(fret)) + end + end + end + end + return end end @@ -190,9 +246,7 @@ function mt:resolve(uri, args, removeGeneric) local knownTypes, genericNames = getSignInfo(sign) if not isAllResolved(genericNames) then local newArgNode = buildArgNode(argNode,sign, knownTypes) - for n in sign:eachObject() do - resolve(n, newArgNode) - end + resolve(sign, newArgNode) end end |