summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/completion/completion.lua2
-rw-r--r--script/vm/compiler.lua61
-rw-r--r--script/vm/def.lua30
-rw-r--r--script/vm/field.lua2
-rw-r--r--script/vm/generic.lua108
-rw-r--r--script/vm/global-manager.lua14
-rw-r--r--script/vm/global.lua8
-rw-r--r--script/vm/infer.lua8
-rw-r--r--script/vm/node.lua29
-rw-r--r--script/vm/ref.lua2
-rw-r--r--script/vm/sign.lua38
-rw-r--r--script/vm/type.lua41
-rw-r--r--script/vm/union.lua5
-rw-r--r--script/vm/value.lua22
14 files changed, 199 insertions, 171 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua
index 9f0c6695..f429ad14 100644
--- a/script/core/completion/completion.lua
+++ b/script/core/completion/completion.lua
@@ -1406,7 +1406,7 @@ local function tryCallArg(state, position, results)
end
local node = compiler.compileCallArg({ type = 'dummyarg' }, call, argIndex)
local enums = {}
- for src in nodeMgr.eachNode(node) do
+ for src in nodeMgr.eachObject(node) do
if src.type == 'doc.type.string'
or src.type == 'doc.type.integer'
or src.type == 'doc.type.boolean' then
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 12c2c24c..c74d2780 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -55,7 +55,7 @@ local searchFieldSwitch = util.switch()
end
end)
: case 'global'
- ---@param node vm.node.global
+ ---@param node vm.global
: call(function (suri, node, key, pushResult)
if node.cate == 'variable' then
if key then
@@ -115,7 +115,7 @@ local searchFieldSwitch = util.switch()
local fieldKey = field.name
if fieldKey.type == 'doc.type' then
local fieldNode = m.compileNode(fieldKey)
- for fn in nodeMgr.eachNode(fieldNode) do
+ for fn in nodeMgr.eachObject(fieldNode) do
if fn.type == 'global' and fn.cate == 'type' then
if key == nil
or fn.name == 'any'
@@ -306,7 +306,7 @@ local function getReturnOfSetMetaTable(args)
end
if mt then
m.compileByParentNode(mt, '__index', function (src)
- for n in nodeMgr.eachNode(m.compileNode(src)) do
+ for n in nodeMgr.eachObject(m.compileNode(src)) do
if n.type == 'global'
or n.type == 'local'
or n.type == 'table'
@@ -373,16 +373,26 @@ local function getReturn(func, index, args)
---@type vm.node.union
local result
if node then
- for cnode in nodeMgr.eachNode(node) do
+ for cnode in nodeMgr.eachObject(node) do
if cnode.type == 'function'
or cnode.type == 'doc.type.function' then
local returnNode = m.getReturnOfFunction(cnode, index)
- if returnNode and returnNode.type == 'generic' then
- returnNode = returnNode:resolve(guide.getUri(func), args)
+ if returnNode then
+ for rnode in nodeMgr.eachObject(returnNode) do
+ if rnode.type == 'generic' then
+ returnNode = rnode:resolve(guide.getUri(func), args)
+ break
+ end
+ end
end
- if returnNode and returnNode.type ~= 'doc.generic.name' then
- result = result or union()
- result:merge(m.compileNode(returnNode))
+ if returnNode then
+ for rnode in nodeMgr.eachObject(returnNode) do
+ -- TODO: narrow type
+ if rnode.type ~= 'doc.generic.name' then
+ result = result or union()
+ result:merge(rnode)
+ end
+ end
end
end
end
@@ -467,7 +477,7 @@ function m.compileByParentNode(source, key, pushResult)
return
end
local suri = guide.getUri(source)
- for node in nodeMgr.eachNode(parentNode) do
+ for node in nodeMgr.eachObject(parentNode) do
searchFieldSwitch(node.type, suri, node, key, pushResult)
end
end
@@ -505,7 +515,7 @@ local function selectNode(source, list, index)
-- remove any for returns
local rtnNode = union()
local hasKnownType
- for n in nodeMgr.eachNode(result) do
+ for n in nodeMgr.eachObject(result) do
if guide.isLiteral(n) then
hasKnownType = true
rtnNode:merge(n)
@@ -529,7 +539,7 @@ local function selectNode(source, list, index)
end
---@param source parser.object
----@param node vm.node
+---@param node vm.object
---@return boolean
local function isValidCallArgNode(source, node)
if source.type == 'function' then
@@ -563,6 +573,11 @@ local function getFuncArg(func, index)
return nil
end
+---@param arg parser.object
+---@param call parser.object
+---@param callNode vm.node
+---@param fixIndex integer
+---@param myIndex integer
local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
local valueMgr = require 'vm.value'
if not myIndex then
@@ -586,10 +601,10 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
- for n in nodeMgr.eachNode(callNode) do
+ for n in nodeMgr.eachObject(callNode) do
if n.type == 'function' then
local farg = getFuncArg(n, myIndex)
- for fn in nodeMgr.eachNode(m.compileNode(farg)) do
+ for fn in nodeMgr.eachObject(m.compileNode(farg)) do
if isValidCallArgNode(arg, fn) then
nodeMgr.setNode(arg, fn)
end
@@ -602,7 +617,7 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
or event.type ~= 'doc.type.string'
or eventMap[event[1]] then
local farg = getFuncArg(n, myIndex)
- for fn in nodeMgr.eachNode(m.compileNode(farg)) do
+ for fn in nodeMgr.eachObject(m.compileNode(farg)) do
if isValidCallArgNode(arg, fn) then
nodeMgr.setNode(arg, fn)
end
@@ -612,6 +627,9 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
+---@param arg parser.object
+---@param call parser.position
+---@param index integer
function m.compileCallArg(arg, call, index)
local callNode = m.compileNode(call.node)
compileCallArgNode(arg, call, callNode, 0, index)
@@ -631,7 +649,6 @@ local compilerSwitch = util.switch()
: case 'integer'
: case 'number'
: case 'string'
- : case 'union'
: case 'doc.type.function'
: case 'doc.type.table'
: case 'doc.type.array'
@@ -731,7 +748,7 @@ local compilerSwitch = util.switch()
local func = source.parent.parent
local funcNode = m.compileNode(func)
local hasDocArg
- for n in nodeMgr.eachNode(funcNode) do
+ for n in nodeMgr.eachObject(funcNode) do
if n.type == 'doc.type.function' then
for index, arg in ipairs(n.args) do
if func.args[index] == source then
@@ -905,7 +922,7 @@ local compilerSwitch = util.switch()
if not node then
return
end
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'global'
and n.cate == 'type'
and n.name == '...' then
@@ -928,7 +945,7 @@ local compilerSwitch = util.switch()
if not node then
return
end
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'global'
and n.cate == 'type'
and n.name == '...' then
@@ -1427,12 +1444,12 @@ local compilerSwitch = util.switch()
end
end)
----@param source parser.object
+---@param source vm.object
local function compileByNode(source)
compilerSwitch(source.type, source)
end
----@param source vm.node
+---@param source vm.object
local function compileByGlobal(uri, source)
uri = uri or guide.getUri(source)
if source.type == 'global' then
@@ -1501,7 +1518,7 @@ function m.resumeCache()
end
end
----@param source parser.object
+---@param source vm.object
---@return vm.node
function m.compileNode(source, uri)
if not source then
diff --git a/script/vm/def.lua b/script/vm/def.lua
index 81d23854..6357dee7 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -77,8 +77,8 @@ simpleSwitch = util.switch()
local searchFieldSwitch = util.switch()
: case 'table'
- : call(function (suri, node, key, pushResult)
- for _, field in ipairs(node) do
+ : call(function (suri, obj, key, pushResult)
+ for _, field in ipairs(obj) do
if field.type == 'tablefield'
or field.type == 'tableindex' then
if guide.getKeyName(field) == key then
@@ -88,24 +88,24 @@ local searchFieldSwitch = util.switch()
end
end)
: case 'global'
- ---@param node vm.node
+ ---@param obj vm.object
---@param key string
- : call(function (suri, node, key, pushResult)
- if node.cate == 'variable' then
- local newGlobal = globalMgr.getGlobal('variable', node.name, key)
+ : call(function (suri, obj, key, pushResult)
+ if obj.cate == 'variable' then
+ local newGlobal = globalMgr.getGlobal('variable', obj.name, key)
if newGlobal then
for _, set in ipairs(newGlobal:getSets(suri)) do
pushResult(set)
end
end
end
- if node.cate == 'type' then
- compiler.getClassFields(suri, node, key, pushResult)
+ if obj.cate == 'type' then
+ compiler.getClassFields(suri, obj, key, pushResult)
end
end)
: case 'local'
- : call(function (suri, node, key, pushResult)
- local sources = localID.getSources(node, key)
+ : call(function (suri, obj, key, pushResult)
+ local sources = localID.getSources(obj, key)
if sources then
for _, src in ipairs(sources) do
if guide.isSet(src) then
@@ -115,8 +115,8 @@ local searchFieldSwitch = util.switch()
end
end)
: case 'doc.type.table'
- : call(function (suri, node, key, pushResult)
- for _, field in ipairs(node.fields) do
+ : call(function (suri, obj, key, pushResult)
+ for _, field in ipairs(obj.fields) do
local fieldKey = field.name
if fieldKey.type == 'doc.field.name' then
if fieldKey[1] == key then
@@ -146,7 +146,7 @@ local nodeSwitch = util.switch()
end
local uri = guide.getUri(source)
local key = guide.getKeyName(source)
- for pn in nodeMgr.eachNode(parentNode) do
+ for pn in nodeMgr.eachObject(parentNode) do
searchFieldSwitch(pn.type, uri, pn, key, pushResult)
end
end)
@@ -164,7 +164,7 @@ local nodeSwitch = util.switch()
return
end
local uri = guide.getUri(source)
- for pn in nodeMgr.eachNode(parentNode) do
+ for pn in nodeMgr.eachObject(parentNode) do
searchFieldSwitch(pn.type, uri, pn, source[1], pushResult)
end
end)
@@ -201,7 +201,7 @@ local function searchByNode(source, pushResult)
return
end
local suri = guide.getUri(source)
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'global' then
for _, set in ipairs(n:getSets(suri)) do
pushResult(set)
diff --git a/script/vm/field.lua b/script/vm/field.lua
index d79e3f6a..1bcf0b6b 100644
--- a/script/vm/field.lua
+++ b/script/vm/field.lua
@@ -6,7 +6,7 @@ local guide = require 'parser.guide'
local searchByNodeSwitch = util.switch()
: case 'global'
- ---@param global vm.node.global
+ ---@param global vm.global
: call(function (suri, global, pushResult)
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index bacea1e8..d2627388 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -1,4 +1,5 @@
local nodeMgr = require 'vm.node'
+local union = require 'vm.union'
---@class parser.object
---@field _generic vm.generic
@@ -10,65 +11,68 @@ local mt = {}
mt.__index = mt
mt.type = 'generic'
----@param node vm.node
+---@param source parser.object
---@param resolved? table<string, vm.node>
---@return vm.node
-local function cloneObject(node, resolved)
+local function cloneObject(source, resolved)
if not resolved then
- return node
+ return source
end
- if node.type == 'doc.generic.name' then
- local key = node[1]
- return resolved[key] or node
+ if source.type == 'doc.generic.name' then
+ local key = source[1]
+ if not resolved[key] then
+ return source
+ end
+ return resolved[key] or source
end
- if node.type == 'doc.type' then
+ if source.type == 'doc.type' then
local newType = {
- type = node.type,
- start = node.start,
- finish = node.finish,
- parent = node.parent,
+ type = source.type,
+ start = source.start,
+ finish = source.finish,
+ parent = source.parent,
types = {},
}
- for i, typeUnit in ipairs(node.types) do
+ for i, typeUnit in ipairs(source.types) do
local newObj = cloneObject(typeUnit, resolved)
newObj.parent = newType
newType.types[i] = newObj
end
return newType
end
- if node.type == 'doc.type.arg' then
+ if source.type == 'doc.type.arg' then
local newArg = {
- type = node.type,
- start = node.start,
- finish = node.finish,
- parent = node.parent,
- name = node.name,
- extends = cloneObject(node.extends, resolved)
+ type = source.type,
+ start = source.start,
+ finish = source.finish,
+ parent = source.parent,
+ name = source.name,
+ extends = cloneObject(source.extends, resolved)
}
newArg.name.parent = newArg
newArg.extends.parent = newArg
return newArg
end
- if node.type == 'doc.type.array' then
+ if source.type == 'doc.type.array' then
local newArray = {
- type = node.type,
- start = node.start,
- finish = node.finish,
- parent = node.parent,
- node = cloneObject(node.node, resolved),
+ type = source.type,
+ start = source.start,
+ finish = source.finish,
+ parent = source.parent,
+ node = cloneObject(source.node, resolved),
}
newArray.node.parent = newArray
return newArray
end
- if node.type == 'doc.type.table' then
+ if source.type == 'doc.type.table' then
local newTable = {
- type = node.type,
- start = node.start,
- finish = node.finish,
- parent = node.parent,
+ type = source.type,
+ start = source.start,
+ finish = source.finish,
+ parent = source.parent,
fields = {},
}
- for i, field in ipairs(node.fields) do
+ for i, field in ipairs(source.fields) do
local newField = {
type = field.type,
start = field.start,
@@ -83,22 +87,22 @@ local function cloneObject(node, resolved)
end
return newTable
end
- if node.type == 'doc.type.function' then
+ if source.type == 'doc.type.function' then
local newDocFunc = {
- type = node.type,
- start = node.start,
- finish = node.finish,
- parent = node.parent,
+ type = source.type,
+ start = source.start,
+ finish = source.finish,
+ parent = source.parent,
args = {},
returns = {},
}
- for i, arg in ipairs(node.args) do
+ for i, arg in ipairs(source.args) do
local newObj = cloneObject(arg, resolved)
newObj.parent = newDocFunc
newObj.optional = arg.optional
newDocFunc.args[i] = newObj
end
- for i, ret in ipairs(node.returns) do
+ for i, ret in ipairs(source.returns) do
local newObj = cloneObject(ret, resolved)
newObj.parent = newDocFunc
newObj.optional = ret.optional
@@ -106,37 +110,27 @@ local function cloneObject(node, resolved)
end
return newDocFunc
end
- return node
+ return source
end
---@param uri uri
----@param argNodes vm.node[]
+---@param args parser.object
---@return parser.object
-function mt:resolve(uri, argNodes)
- local resolved = self.sign:resolve(uri, argNodes)
- local newProto = cloneObject(self.proto, resolved)
- return newProto
-end
-
-function mt:eachNode()
- local nodes = {}
- for n in nodeMgr.eachNode(self.proto) do
- nodes[#nodes+1] = n
- end
- local i = 0
- return function ()
- i = i + 1
- return nodes[i], self
+function mt:resolve(uri, args)
+ local resolved = self.sign:resolve(uri, args)
+ local result = union()
+ for nd in nodeMgr.eachObject(self.proto) do
+ result:merge(cloneObject(nd, resolved))
end
+ return result
end
---@param proto vm.node
---@param sign vm.sign
return function (proto, sign)
- local compiler = require 'vm.compiler'
local generic = setmetatable({
sign = sign,
- proto = compiler.compileNode(proto),
+ proto = proto,
}, mt)
return generic
end
diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua
index bd5b8696..dffc5c73 100644
--- a/script/vm/global-manager.lua
+++ b/script/vm/global-manager.lua
@@ -5,11 +5,11 @@ local signMgr = require 'vm.sign'
local genericMgr = require 'vm.generic'
---@class parser.object
----@field _globalNode vm.node.global
+---@field _globalNode vm.global
---@class vm.global-manager
local m = {}
----@type table<string, vm.node.global>
+---@type table<string, vm.global>
m.globals = {}
---@type table<uri, table<string, boolean>>
m.globalSubs = util.multiTable(2)
@@ -209,7 +209,7 @@ local compilerGlobalSwitch = util.switch()
---@param cate vm.global.cate
---@param name string
---@param uri uri
----@return vm.node.global
+---@return vm.global
function m.declareGlobal(cate, name, uri)
local key = cate .. '|' .. name
m.globalSubs[uri][key] = true
@@ -222,7 +222,7 @@ end
---@param cate vm.global.cate
---@param name string
---@param field? string
----@return vm.node.global?
+---@return vm.global?
function m.getGlobal(cate, name, field)
local key = cate .. '|' .. name
if field then
@@ -233,7 +233,7 @@ end
---@param cate vm.global.cate
---@param name string
----@return vm.node.global[]
+---@return vm.global[]
function m.getFields(cate, name)
local globals = {}
local key = cate .. '|' .. name
@@ -252,7 +252,7 @@ function m.getFields(cate, name)
end
---@param cate vm.global.cate
----@return vm.node.global[]
+---@return vm.global[]
function m.getGlobals(cate)
local globals = {}
@@ -327,7 +327,7 @@ function m.compileAst(source)
end)
end
----@return vm.node.global
+---@return vm.global
function m.getNode(source)
if source.type == 'field'
or source.type == 'method' then
diff --git a/script/vm/global.lua b/script/vm/global.lua
index 5ac18ced..9bce3060 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -1,12 +1,12 @@
local util = require 'utility'
local scope= require 'workspace.scope'
----@class vm.node.global.link
+---@class vm.global.link
---@field gets parser.object[]
---@field sets parser.object[]
----@class vm.node.global
----@field links table<uri, vm.node.global.link>
+---@class vm.global
+---@field links table<uri, vm.global.link>
---@field setsCache table<uri, parser.object[]>
---@field getsCache table<uri, parser.object[]>
---@field cate vm.global.cate
@@ -110,7 +110,7 @@ function mt:isAlive()
end
---@param cate vm.global.cate
----@return vm.node.global
+---@return vm.global
return function (name, cate)
return setmetatable({
name = name,
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 1f2e0123..7f7dbb8e 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -203,7 +203,7 @@ end
function mt:_eraseAlias()
local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias')
- for n in nodeMgr.eachNode(self.node) do
+ for n in nodeMgr.eachObject(self.node) do
if n.type == 'global' and n.cate == 'type' then
for _, set in ipairs(n:getSets(self.uri)) do
if set.type == 'doc.alias' then
@@ -250,7 +250,7 @@ function mt:_computeViews()
self.views = {}
- for n in nodeMgr.eachNode(self.node) do
+ for n in nodeMgr.eachObject(self.node) do
local view = viewNodeSwitch(n.type, n, self)
if view then
self.views[view] = true
@@ -343,7 +343,7 @@ function mt:viewLiterals()
end
local mark = {}
local literals = {}
- for n in nodeMgr.eachNode(self.node) do
+ for n in nodeMgr.eachObject(self.node) do
if n.type == 'string'
or n.type == 'number'
or n.type == 'integer'
@@ -369,7 +369,7 @@ function mt:viewClass()
end
local mark = {}
local class = {}
- for n in nodeMgr.eachNode(self.node) do
+ for n in nodeMgr.eachObject(self.node) do
if n.type == 'global' and n.cate == 'type' then
local name = n.name
if not mark[name] then
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 409841fc..eb1ee69d 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -1,14 +1,15 @@
local union = require 'vm.union'
local files = require 'files'
----@alias vm.node parser.object | vm.node.union | vm.node.global | vm.generic
+---@alias vm.node vm.node.union
+---@alias vm.object parser.object | vm.global | vm.generic
---@class vm.node-manager
local m = {}
local DUMMY_FUNCTION = function () end
----@type table<parser.object, vm.node>
+---@type table<vm.object, vm.node>
m.nodeCache = {}
---@param a vm.node
@@ -23,7 +24,7 @@ function m.mergeNode(a, b)
return union(a, b)
end
----@param source parser.object
+---@param source vm.object
---@param node vm.node
---@param cover? boolean
function m.setNode(source, node, cover)
@@ -36,21 +37,23 @@ function m.setNode(source, node, cover)
end
local me = m.nodeCache[source]
if not me then
- m.nodeCache[source] = node
- return
- end
- if me == node then
+ if node.type == 'union' then
+ m.nodeCache[source] = node
+ else
+ m.nodeCache[source] = union(node)
+ end
return
end
- m.nodeCache[source] = m.mergeNode(me, node)
+ m.nodeCache[source] = union(me, node)
end
+---@return vm.node?
function m.getNode(source)
return m.nodeCache[source]
end
----@param node vm.node
----@return vm.node.union
+---@param node vm.node?
+---@return vm.node
function m.addOptional(node)
if not node or node.type ~= 'union' then
node = union(node)
@@ -59,7 +62,7 @@ function m.addOptional(node)
return node
end
----@param node vm.node
+---@param node vm.node?
---@return vm.node.union?
function m.removeOptional(node)
if not node then
@@ -72,8 +75,8 @@ function m.removeOptional(node)
return node
end
----@return fun():vm.node
-function m.eachNode(node)
+---@return fun():vm.object
+function m.eachObject(node)
if not node then
return DUMMY_FUNCTION
end
diff --git a/script/vm/ref.lua b/script/vm/ref.lua
index 59b62c15..93412dee 100644
--- a/script/vm/ref.lua
+++ b/script/vm/ref.lua
@@ -272,7 +272,7 @@ local function searchByNode(source, pushResult)
return
end
local uri = guide.getUri(source)
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'global' then
for _, get in ipairs(n:getGets(uri)) do
pushResult(get)
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index d5168f2c..9773627e 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -16,10 +16,10 @@ function mt:addSign(node)
end
---@param uri uri
----@param argNodes vm.node[]
+---@param args parser.object
---@return table<string, vm.node>
-function mt:resolve(uri, argNodes)
- if not argNodes then
+function mt:resolve(uri, args)
+ if not args then
return nil
end
local compiler = require 'vm.compiler'
@@ -33,7 +33,7 @@ function mt:resolve(uri, argNodes)
local key = typeUnit[1]
if typeUnit.literal then
-- 'number' -> `T`
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'string' then
local type = globalMgr.declareGlobal('type', n[1], guide.getUri(n))
resolved[key] = nodeMgr.mergeNode(type, resolved[key])
@@ -45,10 +45,10 @@ function mt:resolve(uri, argNodes)
end
end
if typeUnit.type == 'doc.type.array' then
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'doc.type.array' then
-- number[] -> T[]
- resolve(typeUnit.node, n.node)
+ resolve(typeUnit.node, compiler.compileNode(n.node))
end
end
end
@@ -56,39 +56,39 @@ function mt:resolve(uri, argNodes)
for _, ufield in ipairs(typeUnit.fields) do
local ufieldNode = compiler.compileNode(ufield.name)
local uvalueNode = compiler.compileNode(ufield.extends)
- if ufieldNode.type == 'doc.generic.name' and uvalueNode.type == 'doc.generic.name' then
+ if ufieldNode[1].type == 'doc.generic.name' and uvalueNode.type[1] == 'doc.generic.name' then
-- { [number]: number} -> { [K]: V }
local tfieldNode = vm.getTableKey(uri, node, 'any')
local tvalueNode = vm.getTableValue(uri, node, 'any')
- resolve(ufieldNode, tfieldNode)
- resolve(uvalueNode, tvalueNode)
+ resolve(ufieldNode[1], tfieldNode)
+ resolve(uvalueNode[1], tvalueNode)
else
- if ufieldNode.type == 'doc.generic.name' then
+ if ufieldNode[1].type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [K]: number }
local tnode = vm.getTableKey(uri, node, uvalueNode)
- resolve(ufieldNode, tnode)
- else
+ resolve(ufieldNode[1], tnode)
+ elseif uvalueNode[1].type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [number]: V }
local tnode = vm.getTableValue(uri, node, ufieldNode)
- resolve(uvalueNode, tnode)
+ resolve(uvalueNode[1], tnode)
end
end
end
end
end
- for i, node in ipairs(argNodes) do
+ for i, arg in ipairs(args) do
local sign = self.signList[i]
if not sign then
break
end
- for n in nodeMgr.eachNode(sign) do
- node = compiler.compileNode(node)
- if node then
+ for n in nodeMgr.eachObject(sign) do
+ local argNode = compiler.compileNode(arg)
+ if argNode then
if sign.optional then
- node = nodeMgr.removeOptional(node)
+ argNode = nodeMgr.removeOptional(argNode)
end
- resolve(n, node)
+ resolve(n, argNode)
end
end
end
diff --git a/script/vm/type.lua b/script/vm/type.lua
index 82c0e3f4..56964df8 100644
--- a/script/vm/type.lua
+++ b/script/vm/type.lua
@@ -1,6 +1,7 @@
local nodeMgr = require 'vm.node'
local compiler = require 'vm.compiler'
local globalMgr = require 'vm.global-manager'
+local union = require 'vm.union'
---@class vm
local vm = require 'vm.vm'
@@ -22,7 +23,7 @@ function vm.isSubType(uri, child, parent, mark)
end
mark = mark or {}
- for childNode in nodeMgr.eachNode(child) do
+ for childNode in nodeMgr.eachObject(child) do
if childNode.type ~= 'global'
or childNode.cate ~= 'type' then
goto CONTINUE_CHILD
@@ -31,7 +32,7 @@ function vm.isSubType(uri, child, parent, mark)
return false
end
mark[childNode.name] = true
- for parentNode in nodeMgr.eachNode(parent) do
+ for parentNode in nodeMgr.eachObject(parent) do
if parentNode.type ~= 'global'
or parentNode.cate ~= 'type' then
goto CONTINUE_PARENT
@@ -73,69 +74,77 @@ end
---@param uri uri
---@param tnode vm.node
---@param knode vm.node
+---@return vm.node.union?
function vm.getTableValue(uri, tnode, knode)
- local result
- for tn in nodeMgr.eachNode(tnode) do
+ local result = union()
+ for tn in nodeMgr.eachObject(tnode) do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
if vm.isSubType(uri, compiler.compileNode(field.name), knode) then
- result = nodeMgr.mergeNode(result, compiler.compileNode(field.extends))
+ result:merge(compiler.compileNode(field.extends))
end
end
end
if tn.type == 'doc.type.array' then
- result = nodeMgr.mergeNode(result, compiler.compileNode(tn.node))
+ result:merge(compiler.compileNode(tn.node))
end
if tn.type == 'table' then
for _, field in ipairs(tn) do
if field.type == 'tableindex' then
- result = nodeMgr.mergeNode(result, compiler.compileNode(field.value))
+ result:merge(compiler.compileNode(field.value))
end
if field.type == 'tablefield' then
if vm.isSubType(uri, knode, 'string') then
- result = nodeMgr.mergeNode(result, compiler.compileNode(field.value))
+ result:merge(compiler.compileNode(field.value))
end
end
if field.type == 'tableexp' then
if vm.isSubType(uri, knode, 'integer') and field.tindex == 1 then
- result = nodeMgr.mergeNode(result, compiler.compileNode(field.value))
+ result:merge(compiler.compileNode(field.value))
end
end
end
end
end
+ if result:isEmpty() then
+ return nil
+ end
return result
end
---@param uri uri
---@param tnode vm.node
---@param vnode vm.node
+---@return vm.node.union?
function vm.getTableKey(uri, tnode, vnode)
- local result
- for tn in nodeMgr.eachNode(tnode) do
+ local result = union()
+ for tn in nodeMgr.eachObject(tnode) do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
if vm.isSubType(uri, compiler.compileNode(field.extends), vnode) then
- result = nodeMgr.mergeNode(result, compiler.compileNode(field.name))
+ result:merge(compiler.compileNode(field.name))
end
end
end
if tn.type == 'doc.type.array' then
- result = nodeMgr.mergeNode(result, globalMgr.getGlobal('type', 'integer'))
+ result:merge(globalMgr.getGlobal('type', 'integer'))
end
if tn.type == 'table' then
for _, field in ipairs(tn) do
if field.type == 'tableindex' then
- result = nodeMgr.mergeNode(result, compiler.compileNode(field.index))
+ result:merge(compiler.compileNode(field.index))
end
if field.type == 'tablefield' then
- result = nodeMgr.mergeNode(result, globalMgr.getGlobal('type', 'string'))
+ result:merge(globalMgr.getGlobal('type', 'string'))
end
if field.type == 'tableexp' then
- result = nodeMgr.mergeNode(result, globalMgr.getGlobal('type', 'integer'))
+ result:merge(globalMgr.getGlobal('type', 'integer'))
end
end
end
end
+ if result:isEmpty() then
+ return nil
+ end
return result
end
diff --git a/script/vm/union.lua b/script/vm/union.lua
index b66b34db..5be52de9 100644
--- a/script/vm/union.lua
+++ b/script/vm/union.lua
@@ -45,6 +45,11 @@ function mt:copy()
return createUnion(self, nil)
end
+---@return boolean
+function mt:isEmpty()
+ return #self == 0
+end
+
---@param source parser.object
function mt:subscribeLocal(source)
for _, c in ipairs(self) do
diff --git a/script/vm/value.lua b/script/vm/value.lua
index 4141ba8a..bf9e1423 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -10,7 +10,7 @@ local m = {}
function m.test(source)
local node = compiler.compileNode(source)
local hasTrue, hasFalse
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'boolean' then
if n[1] == true then
hasTrue = true
@@ -40,7 +40,7 @@ function m.test(source)
end
end
----@param v vm.node
+---@param v vm.object
---@return string?
local function getUnique(v)
if v.type == 'local' then
@@ -83,15 +83,15 @@ function m.equal(a, b)
local nodeA = compiler.compileNode(a)
local nodeB = compiler.compileNode(b)
local mapA = {}
- for n in nodeMgr.eachNode(nodeA) do
- local unique = getUnique(n)
+ for obj in nodeMgr.eachObject(nodeA) do
+ local unique = getUnique(obj)
if not unique then
return nil
end
mapA[unique] = true
end
- for n in nodeMgr.eachNode(nodeB) do
- local unique = getUnique(n)
+ for obj in nodeMgr.eachObject(nodeB) do
+ local unique = getUnique(obj)
if not unique then
return nil
end
@@ -107,7 +107,7 @@ end
function m.getInteger(v)
local node = compiler.compileNode(v)
local result
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'integer' then
if result then
return nil
@@ -135,7 +135,7 @@ end
function m.getString(v)
local node = compiler.compileNode(v)
local result
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'string' then
if result then
return nil
@@ -155,7 +155,7 @@ end
function m.getNumber(v)
local node = compiler.compileNode(v)
local result
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'number'
or n.type == 'integer' then
if result then
@@ -176,7 +176,7 @@ end
function m.getBoolean(v)
local node = compiler.compileNode(v)
local result
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
if n.type == 'boolean' then
if result then
return nil
@@ -196,7 +196,7 @@ end
function m.getLiterals(v)
local map
local node = compiler.compileNode(v)
- for n in nodeMgr.eachNode(node) do
+ for n in nodeMgr.eachObject(node) do
local literal
if n.type == 'boolean'
or n.type == 'string'