diff options
Diffstat (limited to 'script/vm/sign.lua')
-rw-r--r-- | script/vm/sign.lua | 96 |
1 files changed, 79 insertions, 17 deletions
diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 5b97f2b9..ca326965 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -1,5 +1,6 @@ local guide = require 'parser.guide' local vm = require 'vm.vm' +local infer = require 'vm.infer' ---@class vm.sign ---@field parent parser.object @@ -23,12 +24,12 @@ function mt:resolve(uri, args) local globalMgr = require 'vm.global-manager' local resolved = {} - ---@param typeUnit parser.object - ---@param node vm.node - local function resolve(typeUnit, node) - if typeUnit.type == 'doc.generic.name' then - local key = typeUnit[1] - if typeUnit.literal then + ---@param object parser.object + ---@param node vm.node + local function resolve(object, node) + if object.type == 'doc.generic.name' then + local key = object[1] + if object.literal then -- 'number' -> `T` for n in node:eachObject() do if n.type == 'string' then @@ -41,16 +42,16 @@ function mt:resolve(uri, args) resolved[key] = vm.createNode(node, resolved[key]) end end - if typeUnit.type == 'doc.type.array' then + if object.type == 'doc.type.array' then for n in node:eachObject() do if n.type == 'doc.type.array' then -- number[] -> T[] - resolve(typeUnit.node, vm.compileNode(n.node)) + resolve(object.node, vm.compileNode(n.node)) end end end - if typeUnit.type == 'doc.type.table' then - for _, ufield in ipairs(typeUnit.fields) do + if object.type == 'doc.type.table' then + for _, ufield in ipairs(object.fields) do local ufieldNode = vm.compileNode(ufield.name) local uvalueNode = vm.compileNode(ufield.extends) if ufieldNode[1].type == 'doc.generic.name' and uvalueNode[1].type == 'doc.generic.name' then @@ -74,18 +75,79 @@ function mt:resolve(uri, args) end end + ---@param sign vm.node + ---@return table<string, true> + ---@return table<string, true> + local function getSignInfo(sign) + local knownTypes = {} + local genericsNames = {} + for obj in sign:eachObject() do + if obj.type == 'doc.generic.name' then + genericsNames[obj[1]] = true + goto CONTINUE + end + if obj.type == 'doc.type.table' + or obj.type == 'doc.type.function' + or obj.type == 'doc.type.array' then + local hasGeneric + guide.eachSourceType(obj, 'doc.generic.name', function (src) + hasGeneric = true + genericsNames[src[1]] = true + end) + if hasGeneric then + goto CONTINUE + end + end + local view = infer.viewObject(obj) + if view then + knownTypes[view] = true + end + ::CONTINUE:: + end + return knownTypes, genericsNames + end + + -- remove un-generic type + ---@param argNode vm.node + ---@param knownTypes table<string, true> + ---@return vm.node + local function buildArgNode(argNode, knownTypes) + local newArgNode = vm.createNode() + for n in argNode:eachObject() do + if argNode:isOptional() and vm.isFalsy(n) then + goto CONTINUE + end + local view = infer.viewObject(n) + if knownTypes[view] then + goto CONTINUE + end + newArgNode:merge(n) + ::CONTINUE:: + end + return newArgNode + end + + ---@param genericNames table<string, true> + local function isAllResolved(genericNames) + for n in pairs(genericNames) do + if not resolved[n] then + return false + end + end + return true + end + for i, arg in ipairs(args) do local sign = self.signList[i] if not sign then break end - for n in sign:eachObject() do - local argNode = vm.compileNode(arg) - if argNode then - if sign.optional then - argNode:removeOptional() - end - resolve(n, argNode) + local argNode = vm.compileNode(arg) + local knownTypes, genericNames = getSignInfo(sign) + if not isAllResolved(genericNames) then + local newArgNode = buildArgNode(argNode, knownTypes) + for n in sign:eachObject() do + resolve(n, newArgNode) end end end |