summaryrefslogtreecommitdiff
path: root/script/vm/sign.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/sign.lua')
-rw-r--r--script/vm/sign.lua96
1 files changed, 79 insertions, 17 deletions
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 5b97f2b9..ca326965 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -1,5 +1,6 @@
local guide = require 'parser.guide'
local vm = require 'vm.vm'
+local infer = require 'vm.infer'
---@class vm.sign
---@field parent parser.object
@@ -23,12 +24,12 @@ function mt:resolve(uri, args)
local globalMgr = require 'vm.global-manager'
local resolved = {}
- ---@param typeUnit parser.object
- ---@param node vm.node
- local function resolve(typeUnit, node)
- if typeUnit.type == 'doc.generic.name' then
- local key = typeUnit[1]
- if typeUnit.literal then
+ ---@param object parser.object
+ ---@param node vm.node
+ local function resolve(object, node)
+ if object.type == 'doc.generic.name' then
+ local key = object[1]
+ if object.literal then
-- 'number' -> `T`
for n in node:eachObject() do
if n.type == 'string' then
@@ -41,16 +42,16 @@ function mt:resolve(uri, args)
resolved[key] = vm.createNode(node, resolved[key])
end
end
- if typeUnit.type == 'doc.type.array' then
+ if object.type == 'doc.type.array' then
for n in node:eachObject() do
if n.type == 'doc.type.array' then
-- number[] -> T[]
- resolve(typeUnit.node, vm.compileNode(n.node))
+ resolve(object.node, vm.compileNode(n.node))
end
end
end
- if typeUnit.type == 'doc.type.table' then
- for _, ufield in ipairs(typeUnit.fields) do
+ if object.type == 'doc.type.table' then
+ for _, ufield in ipairs(object.fields) do
local ufieldNode = vm.compileNode(ufield.name)
local uvalueNode = vm.compileNode(ufield.extends)
if ufieldNode[1].type == 'doc.generic.name' and uvalueNode[1].type == 'doc.generic.name' then
@@ -74,18 +75,79 @@ function mt:resolve(uri, args)
end
end
+ ---@param sign vm.node
+ ---@return table<string, true>
+ ---@return table<string, true>
+ local function getSignInfo(sign)
+ local knownTypes = {}
+ local genericsNames = {}
+ for obj in sign:eachObject() do
+ if obj.type == 'doc.generic.name' then
+ genericsNames[obj[1]] = true
+ goto CONTINUE
+ end
+ if obj.type == 'doc.type.table'
+ or obj.type == 'doc.type.function'
+ or obj.type == 'doc.type.array' then
+ local hasGeneric
+ guide.eachSourceType(obj, 'doc.generic.name', function (src)
+ hasGeneric = true
+ genericsNames[src[1]] = true
+ end)
+ if hasGeneric then
+ goto CONTINUE
+ end
+ end
+ local view = infer.viewObject(obj)
+ if view then
+ knownTypes[view] = true
+ end
+ ::CONTINUE::
+ end
+ return knownTypes, genericsNames
+ end
+
+ -- remove un-generic type
+ ---@param argNode vm.node
+ ---@param knownTypes table<string, true>
+ ---@return vm.node
+ local function buildArgNode(argNode, knownTypes)
+ local newArgNode = vm.createNode()
+ for n in argNode:eachObject() do
+ if argNode:isOptional() and vm.isFalsy(n) then
+ goto CONTINUE
+ end
+ local view = infer.viewObject(n)
+ if knownTypes[view] then
+ goto CONTINUE
+ end
+ newArgNode:merge(n)
+ ::CONTINUE::
+ end
+ return newArgNode
+ end
+
+ ---@param genericNames table<string, true>
+ local function isAllResolved(genericNames)
+ for n in pairs(genericNames) do
+ if not resolved[n] then
+ return false
+ end
+ end
+ return true
+ end
+
for i, arg in ipairs(args) do
local sign = self.signList[i]
if not sign then
break
end
- for n in sign:eachObject() do
- local argNode = vm.compileNode(arg)
- if argNode then
- if sign.optional then
- argNode:removeOptional()
- end
- resolve(n, argNode)
+ local argNode = vm.compileNode(arg)
+ local knownTypes, genericNames = getSignInfo(sign)
+ if not isAllResolved(genericNames) then
+ local newArgNode = buildArgNode(argNode, knownTypes)
+ for n in sign:eachObject() do
+ resolve(n, newArgNode)
end
end
end