From 96b320247788163f7d014404b7b1f089aed82dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Sun, 23 Oct 2022 02:21:17 +0800 Subject: infer type by function as parameters resolve #1153 --- script/vm/sign.lua | 64 +++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 5 deletions(-) (limited to 'script/vm/sign.lua') diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 7c95fd08..21044a28 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -22,12 +22,26 @@ function mt:resolve(uri, args, removeGeneric) if not args then return nil end + + ---@type table local resolved = {} - ---@param object vm.node.object + ---@param object vm.node|vm.node.object ---@param node vm.node local function resolve(object, node) + if object.type == 'vm.node' then + for o in object:eachObject() do + resolve(o, node) + end + return + end + if object.type == 'doc.type' then + ---@cast object parser.object + resolve(vm.compileNode(object), node) + return + end if object.type == 'doc.generic.name' then + ---@type string local key = object[1] if object.literal then -- 'number' -> `T` @@ -40,8 +54,21 @@ function mt:resolve(uri, args, removeGeneric) end else -- number -> T - resolved[key] = vm.createNode(node, resolved[key]) + for n in node:eachObject() do + if n.type ~= 'doc.generic.name' + and n.type ~= 'generic' then + if resolved[key] then + resolved[key]:merge(n) + else + resolved[key] = vm.createNode(n) + end + end + end + if resolved[key] and node:isOptional() then + resolved[key]:addOptional() + end end + return end if object.type == 'doc.type.array' then for n in node:eachObject() do @@ -68,6 +95,7 @@ function mt:resolve(uri, args, removeGeneric) resolve(object.node, vm.compileNode(n[1])) end end + return end if object.type == 'doc.type.table' then for _, ufield in ipairs(object.fields) do @@ -105,6 +133,34 @@ function mt:resolve(uri, args, removeGeneric) end ::CONTINUE:: end + return + end + if object.type == 'doc.type.function' then + for i, arg in ipairs(object.args) do + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local farg = n.args and n.args[i] + if farg then + resolve(arg.extends, vm.compileNode(farg)) + end + end + end + end + for i, ret in ipairs(object.returns) do + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local fret = vm.getReturnOfFunction(n, i) + if fret then + resolve(ret, vm.compileNode(fret)) + end + end + end + end + return end end @@ -190,9 +246,7 @@ function mt:resolve(uri, args, removeGeneric) local knownTypes, genericNames = getSignInfo(sign) if not isAllResolved(genericNames) then local newArgNode = buildArgNode(argNode,sign, knownTypes) - for n in sign:eachObject() do - resolve(n, newArgNode) - end + resolve(sign, newArgNode) end end -- cgit v1.2.3