diff options
-rw-r--r-- | script/core/completion/completion.lua | 2 | ||||
-rw-r--r-- | script/vm/compiler.lua | 61 | ||||
-rw-r--r-- | script/vm/def.lua | 30 | ||||
-rw-r--r-- | script/vm/field.lua | 2 | ||||
-rw-r--r-- | script/vm/generic.lua | 108 | ||||
-rw-r--r-- | script/vm/global-manager.lua | 14 | ||||
-rw-r--r-- | script/vm/global.lua | 8 | ||||
-rw-r--r-- | script/vm/infer.lua | 8 | ||||
-rw-r--r-- | script/vm/node.lua | 29 | ||||
-rw-r--r-- | script/vm/ref.lua | 2 | ||||
-rw-r--r-- | script/vm/sign.lua | 38 | ||||
-rw-r--r-- | script/vm/type.lua | 41 | ||||
-rw-r--r-- | script/vm/union.lua | 5 | ||||
-rw-r--r-- | script/vm/value.lua | 22 |
14 files changed, 199 insertions, 171 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index 9f0c6695..f429ad14 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -1406,7 +1406,7 @@ local function tryCallArg(state, position, results) end local node = compiler.compileCallArg({ type = 'dummyarg' }, call, argIndex) local enums = {} - for src in nodeMgr.eachNode(node) do + for src in nodeMgr.eachObject(node) do if src.type == 'doc.type.string' or src.type == 'doc.type.integer' or src.type == 'doc.type.boolean' then diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 12c2c24c..c74d2780 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -55,7 +55,7 @@ local searchFieldSwitch = util.switch() end end) : case 'global' - ---@param node vm.node.global + ---@param node vm.global : call(function (suri, node, key, pushResult) if node.cate == 'variable' then if key then @@ -115,7 +115,7 @@ local searchFieldSwitch = util.switch() local fieldKey = field.name if fieldKey.type == 'doc.type' then local fieldNode = m.compileNode(fieldKey) - for fn in nodeMgr.eachNode(fieldNode) do + for fn in nodeMgr.eachObject(fieldNode) do if fn.type == 'global' and fn.cate == 'type' then if key == nil or fn.name == 'any' @@ -306,7 +306,7 @@ local function getReturnOfSetMetaTable(args) end if mt then m.compileByParentNode(mt, '__index', function (src) - for n in nodeMgr.eachNode(m.compileNode(src)) do + for n in nodeMgr.eachObject(m.compileNode(src)) do if n.type == 'global' or n.type == 'local' or n.type == 'table' @@ -373,16 +373,26 @@ local function getReturn(func, index, args) ---@type vm.node.union local result if node then - for cnode in nodeMgr.eachNode(node) do + for cnode in nodeMgr.eachObject(node) do if cnode.type == 'function' or cnode.type == 'doc.type.function' then local returnNode = m.getReturnOfFunction(cnode, index) - if returnNode and returnNode.type == 'generic' then - returnNode = returnNode:resolve(guide.getUri(func), args) + if returnNode then + for rnode in nodeMgr.eachObject(returnNode) do + if rnode.type == 'generic' then + returnNode = rnode:resolve(guide.getUri(func), args) + break + end + end end - if returnNode and returnNode.type ~= 'doc.generic.name' then - result = result or union() - result:merge(m.compileNode(returnNode)) + if returnNode then + for rnode in nodeMgr.eachObject(returnNode) do + -- TODO: narrow type + if rnode.type ~= 'doc.generic.name' then + result = result or union() + result:merge(rnode) + end + end end end end @@ -467,7 +477,7 @@ function m.compileByParentNode(source, key, pushResult) return end local suri = guide.getUri(source) - for node in nodeMgr.eachNode(parentNode) do + for node in nodeMgr.eachObject(parentNode) do searchFieldSwitch(node.type, suri, node, key, pushResult) end end @@ -505,7 +515,7 @@ local function selectNode(source, list, index) -- remove any for returns local rtnNode = union() local hasKnownType - for n in nodeMgr.eachNode(result) do + for n in nodeMgr.eachObject(result) do if guide.isLiteral(n) then hasKnownType = true rtnNode:merge(n) @@ -529,7 +539,7 @@ local function selectNode(source, list, index) end ---@param source parser.object ----@param node vm.node +---@param node vm.object ---@return boolean local function isValidCallArgNode(source, node) if source.type == 'function' then @@ -563,6 +573,11 @@ local function getFuncArg(func, index) return nil end +---@param arg parser.object +---@param call parser.object +---@param callNode vm.node +---@param fixIndex integer +---@param myIndex integer local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) local valueMgr = require 'vm.value' if not myIndex then @@ -586,10 +601,10 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) end end - for n in nodeMgr.eachNode(callNode) do + for n in nodeMgr.eachObject(callNode) do if n.type == 'function' then local farg = getFuncArg(n, myIndex) - for fn in nodeMgr.eachNode(m.compileNode(farg)) do + for fn in nodeMgr.eachObject(m.compileNode(farg)) do if isValidCallArgNode(arg, fn) then nodeMgr.setNode(arg, fn) end @@ -602,7 +617,7 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) or event.type ~= 'doc.type.string' or eventMap[event[1]] then local farg = getFuncArg(n, myIndex) - for fn in nodeMgr.eachNode(m.compileNode(farg)) do + for fn in nodeMgr.eachObject(m.compileNode(farg)) do if isValidCallArgNode(arg, fn) then nodeMgr.setNode(arg, fn) end @@ -612,6 +627,9 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) end end +---@param arg parser.object +---@param call parser.position +---@param index integer function m.compileCallArg(arg, call, index) local callNode = m.compileNode(call.node) compileCallArgNode(arg, call, callNode, 0, index) @@ -631,7 +649,6 @@ local compilerSwitch = util.switch() : case 'integer' : case 'number' : case 'string' - : case 'union' : case 'doc.type.function' : case 'doc.type.table' : case 'doc.type.array' @@ -731,7 +748,7 @@ local compilerSwitch = util.switch() local func = source.parent.parent local funcNode = m.compileNode(func) local hasDocArg - for n in nodeMgr.eachNode(funcNode) do + for n in nodeMgr.eachObject(funcNode) do if n.type == 'doc.type.function' then for index, arg in ipairs(n.args) do if func.args[index] == source then @@ -905,7 +922,7 @@ local compilerSwitch = util.switch() if not node then return end - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'global' and n.cate == 'type' and n.name == '...' then @@ -928,7 +945,7 @@ local compilerSwitch = util.switch() if not node then return end - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'global' and n.cate == 'type' and n.name == '...' then @@ -1427,12 +1444,12 @@ local compilerSwitch = util.switch() end end) ----@param source parser.object +---@param source vm.object local function compileByNode(source) compilerSwitch(source.type, source) end ----@param source vm.node +---@param source vm.object local function compileByGlobal(uri, source) uri = uri or guide.getUri(source) if source.type == 'global' then @@ -1501,7 +1518,7 @@ function m.resumeCache() end end ----@param source parser.object +---@param source vm.object ---@return vm.node function m.compileNode(source, uri) if not source then diff --git a/script/vm/def.lua b/script/vm/def.lua index 81d23854..6357dee7 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -77,8 +77,8 @@ simpleSwitch = util.switch() local searchFieldSwitch = util.switch() : case 'table' - : call(function (suri, node, key, pushResult) - for _, field in ipairs(node) do + : call(function (suri, obj, key, pushResult) + for _, field in ipairs(obj) do if field.type == 'tablefield' or field.type == 'tableindex' then if guide.getKeyName(field) == key then @@ -88,24 +88,24 @@ local searchFieldSwitch = util.switch() end end) : case 'global' - ---@param node vm.node + ---@param obj vm.object ---@param key string - : call(function (suri, node, key, pushResult) - if node.cate == 'variable' then - local newGlobal = globalMgr.getGlobal('variable', node.name, key) + : call(function (suri, obj, key, pushResult) + if obj.cate == 'variable' then + local newGlobal = globalMgr.getGlobal('variable', obj.name, key) if newGlobal then for _, set in ipairs(newGlobal:getSets(suri)) do pushResult(set) end end end - if node.cate == 'type' then - compiler.getClassFields(suri, node, key, pushResult) + if obj.cate == 'type' then + compiler.getClassFields(suri, obj, key, pushResult) end end) : case 'local' - : call(function (suri, node, key, pushResult) - local sources = localID.getSources(node, key) + : call(function (suri, obj, key, pushResult) + local sources = localID.getSources(obj, key) if sources then for _, src in ipairs(sources) do if guide.isSet(src) then @@ -115,8 +115,8 @@ local searchFieldSwitch = util.switch() end end) : case 'doc.type.table' - : call(function (suri, node, key, pushResult) - for _, field in ipairs(node.fields) do + : call(function (suri, obj, key, pushResult) + for _, field in ipairs(obj.fields) do local fieldKey = field.name if fieldKey.type == 'doc.field.name' then if fieldKey[1] == key then @@ -146,7 +146,7 @@ local nodeSwitch = util.switch() end local uri = guide.getUri(source) local key = guide.getKeyName(source) - for pn in nodeMgr.eachNode(parentNode) do + for pn in nodeMgr.eachObject(parentNode) do searchFieldSwitch(pn.type, uri, pn, key, pushResult) end end) @@ -164,7 +164,7 @@ local nodeSwitch = util.switch() return end local uri = guide.getUri(source) - for pn in nodeMgr.eachNode(parentNode) do + for pn in nodeMgr.eachObject(parentNode) do searchFieldSwitch(pn.type, uri, pn, source[1], pushResult) end end) @@ -201,7 +201,7 @@ local function searchByNode(source, pushResult) return end local suri = guide.getUri(source) - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'global' then for _, set in ipairs(n:getSets(suri)) do pushResult(set) diff --git a/script/vm/field.lua b/script/vm/field.lua index d79e3f6a..1bcf0b6b 100644 --- a/script/vm/field.lua +++ b/script/vm/field.lua @@ -6,7 +6,7 @@ local guide = require 'parser.guide' local searchByNodeSwitch = util.switch() : case 'global' - ---@param global vm.node.global + ---@param global vm.global : call(function (suri, global, pushResult) for _, set in ipairs(global:getSets(suri)) do pushResult(set) diff --git a/script/vm/generic.lua b/script/vm/generic.lua index bacea1e8..d2627388 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -1,4 +1,5 @@ local nodeMgr = require 'vm.node' +local union = require 'vm.union' ---@class parser.object ---@field _generic vm.generic @@ -10,65 +11,68 @@ local mt = {} mt.__index = mt mt.type = 'generic' ----@param node vm.node +---@param source parser.object ---@param resolved? table<string, vm.node> ---@return vm.node -local function cloneObject(node, resolved) +local function cloneObject(source, resolved) if not resolved then - return node + return source end - if node.type == 'doc.generic.name' then - local key = node[1] - return resolved[key] or node + if source.type == 'doc.generic.name' then + local key = source[1] + if not resolved[key] then + return source + end + return resolved[key] or source end - if node.type == 'doc.type' then + if source.type == 'doc.type' then local newType = { - type = node.type, - start = node.start, - finish = node.finish, - parent = node.parent, + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, types = {}, } - for i, typeUnit in ipairs(node.types) do + for i, typeUnit in ipairs(source.types) do local newObj = cloneObject(typeUnit, resolved) newObj.parent = newType newType.types[i] = newObj end return newType end - if node.type == 'doc.type.arg' then + if source.type == 'doc.type.arg' then local newArg = { - type = node.type, - start = node.start, - finish = node.finish, - parent = node.parent, - name = node.name, - extends = cloneObject(node.extends, resolved) + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, + name = source.name, + extends = cloneObject(source.extends, resolved) } newArg.name.parent = newArg newArg.extends.parent = newArg return newArg end - if node.type == 'doc.type.array' then + if source.type == 'doc.type.array' then local newArray = { - type = node.type, - start = node.start, - finish = node.finish, - parent = node.parent, - node = cloneObject(node.node, resolved), + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, + node = cloneObject(source.node, resolved), } newArray.node.parent = newArray return newArray end - if node.type == 'doc.type.table' then + if source.type == 'doc.type.table' then local newTable = { - type = node.type, - start = node.start, - finish = node.finish, - parent = node.parent, + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, fields = {}, } - for i, field in ipairs(node.fields) do + for i, field in ipairs(source.fields) do local newField = { type = field.type, start = field.start, @@ -83,22 +87,22 @@ local function cloneObject(node, resolved) end return newTable end - if node.type == 'doc.type.function' then + if source.type == 'doc.type.function' then local newDocFunc = { - type = node.type, - start = node.start, - finish = node.finish, - parent = node.parent, + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, args = {}, returns = {}, } - for i, arg in ipairs(node.args) do + for i, arg in ipairs(source.args) do local newObj = cloneObject(arg, resolved) newObj.parent = newDocFunc newObj.optional = arg.optional newDocFunc.args[i] = newObj end - for i, ret in ipairs(node.returns) do + for i, ret in ipairs(source.returns) do local newObj = cloneObject(ret, resolved) newObj.parent = newDocFunc newObj.optional = ret.optional @@ -106,37 +110,27 @@ local function cloneObject(node, resolved) end return newDocFunc end - return node + return source end ---@param uri uri ----@param argNodes vm.node[] +---@param args parser.object ---@return parser.object -function mt:resolve(uri, argNodes) - local resolved = self.sign:resolve(uri, argNodes) - local newProto = cloneObject(self.proto, resolved) - return newProto -end - -function mt:eachNode() - local nodes = {} - for n in nodeMgr.eachNode(self.proto) do - nodes[#nodes+1] = n - end - local i = 0 - return function () - i = i + 1 - return nodes[i], self +function mt:resolve(uri, args) + local resolved = self.sign:resolve(uri, args) + local result = union() + for nd in nodeMgr.eachObject(self.proto) do + result:merge(cloneObject(nd, resolved)) end + return result end ---@param proto vm.node ---@param sign vm.sign return function (proto, sign) - local compiler = require 'vm.compiler' local generic = setmetatable({ sign = sign, - proto = compiler.compileNode(proto), + proto = proto, }, mt) return generic end diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua index bd5b8696..dffc5c73 100644 --- a/script/vm/global-manager.lua +++ b/script/vm/global-manager.lua @@ -5,11 +5,11 @@ local signMgr = require 'vm.sign' local genericMgr = require 'vm.generic' ---@class parser.object ----@field _globalNode vm.node.global +---@field _globalNode vm.global ---@class vm.global-manager local m = {} ----@type table<string, vm.node.global> +---@type table<string, vm.global> m.globals = {} ---@type table<uri, table<string, boolean>> m.globalSubs = util.multiTable(2) @@ -209,7 +209,7 @@ local compilerGlobalSwitch = util.switch() ---@param cate vm.global.cate ---@param name string ---@param uri uri ----@return vm.node.global +---@return vm.global function m.declareGlobal(cate, name, uri) local key = cate .. '|' .. name m.globalSubs[uri][key] = true @@ -222,7 +222,7 @@ end ---@param cate vm.global.cate ---@param name string ---@param field? string ----@return vm.node.global? +---@return vm.global? function m.getGlobal(cate, name, field) local key = cate .. '|' .. name if field then @@ -233,7 +233,7 @@ end ---@param cate vm.global.cate ---@param name string ----@return vm.node.global[] +---@return vm.global[] function m.getFields(cate, name) local globals = {} local key = cate .. '|' .. name @@ -252,7 +252,7 @@ function m.getFields(cate, name) end ---@param cate vm.global.cate ----@return vm.node.global[] +---@return vm.global[] function m.getGlobals(cate) local globals = {} @@ -327,7 +327,7 @@ function m.compileAst(source) end) end ----@return vm.node.global +---@return vm.global function m.getNode(source) if source.type == 'field' or source.type == 'method' then diff --git a/script/vm/global.lua b/script/vm/global.lua index 5ac18ced..9bce3060 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -1,12 +1,12 @@ local util = require 'utility' local scope= require 'workspace.scope' ----@class vm.node.global.link +---@class vm.global.link ---@field gets parser.object[] ---@field sets parser.object[] ----@class vm.node.global ----@field links table<uri, vm.node.global.link> +---@class vm.global +---@field links table<uri, vm.global.link> ---@field setsCache table<uri, parser.object[]> ---@field getsCache table<uri, parser.object[]> ---@field cate vm.global.cate @@ -110,7 +110,7 @@ function mt:isAlive() end ---@param cate vm.global.cate ----@return vm.node.global +---@return vm.global return function (name, cate) return setmetatable({ name = name, diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 1f2e0123..7f7dbb8e 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -203,7 +203,7 @@ end function mt:_eraseAlias() local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias') - for n in nodeMgr.eachNode(self.node) do + for n in nodeMgr.eachObject(self.node) do if n.type == 'global' and n.cate == 'type' then for _, set in ipairs(n:getSets(self.uri)) do if set.type == 'doc.alias' then @@ -250,7 +250,7 @@ function mt:_computeViews() self.views = {} - for n in nodeMgr.eachNode(self.node) do + for n in nodeMgr.eachObject(self.node) do local view = viewNodeSwitch(n.type, n, self) if view then self.views[view] = true @@ -343,7 +343,7 @@ function mt:viewLiterals() end local mark = {} local literals = {} - for n in nodeMgr.eachNode(self.node) do + for n in nodeMgr.eachObject(self.node) do if n.type == 'string' or n.type == 'number' or n.type == 'integer' @@ -369,7 +369,7 @@ function mt:viewClass() end local mark = {} local class = {} - for n in nodeMgr.eachNode(self.node) do + for n in nodeMgr.eachObject(self.node) do if n.type == 'global' and n.cate == 'type' then local name = n.name if not mark[name] then diff --git a/script/vm/node.lua b/script/vm/node.lua index 409841fc..eb1ee69d 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -1,14 +1,15 @@ local union = require 'vm.union' local files = require 'files' ----@alias vm.node parser.object | vm.node.union | vm.node.global | vm.generic +---@alias vm.node vm.node.union +---@alias vm.object parser.object | vm.global | vm.generic ---@class vm.node-manager local m = {} local DUMMY_FUNCTION = function () end ----@type table<parser.object, vm.node> +---@type table<vm.object, vm.node> m.nodeCache = {} ---@param a vm.node @@ -23,7 +24,7 @@ function m.mergeNode(a, b) return union(a, b) end ----@param source parser.object +---@param source vm.object ---@param node vm.node ---@param cover? boolean function m.setNode(source, node, cover) @@ -36,21 +37,23 @@ function m.setNode(source, node, cover) end local me = m.nodeCache[source] if not me then - m.nodeCache[source] = node - return - end - if me == node then + if node.type == 'union' then + m.nodeCache[source] = node + else + m.nodeCache[source] = union(node) + end return end - m.nodeCache[source] = m.mergeNode(me, node) + m.nodeCache[source] = union(me, node) end +---@return vm.node? function m.getNode(source) return m.nodeCache[source] end ----@param node vm.node ----@return vm.node.union +---@param node vm.node? +---@return vm.node function m.addOptional(node) if not node or node.type ~= 'union' then node = union(node) @@ -59,7 +62,7 @@ function m.addOptional(node) return node end ----@param node vm.node +---@param node vm.node? ---@return vm.node.union? function m.removeOptional(node) if not node then @@ -72,8 +75,8 @@ function m.removeOptional(node) return node end ----@return fun():vm.node -function m.eachNode(node) +---@return fun():vm.object +function m.eachObject(node) if not node then return DUMMY_FUNCTION end diff --git a/script/vm/ref.lua b/script/vm/ref.lua index 59b62c15..93412dee 100644 --- a/script/vm/ref.lua +++ b/script/vm/ref.lua @@ -272,7 +272,7 @@ local function searchByNode(source, pushResult) return end local uri = guide.getUri(source) - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'global' then for _, get in ipairs(n:getGets(uri)) do pushResult(get) diff --git a/script/vm/sign.lua b/script/vm/sign.lua index d5168f2c..9773627e 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -16,10 +16,10 @@ function mt:addSign(node) end ---@param uri uri ----@param argNodes vm.node[] +---@param args parser.object ---@return table<string, vm.node> -function mt:resolve(uri, argNodes) - if not argNodes then +function mt:resolve(uri, args) + if not args then return nil end local compiler = require 'vm.compiler' @@ -33,7 +33,7 @@ function mt:resolve(uri, argNodes) local key = typeUnit[1] if typeUnit.literal then -- 'number' -> `T` - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'string' then local type = globalMgr.declareGlobal('type', n[1], guide.getUri(n)) resolved[key] = nodeMgr.mergeNode(type, resolved[key]) @@ -45,10 +45,10 @@ function mt:resolve(uri, argNodes) end end if typeUnit.type == 'doc.type.array' then - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'doc.type.array' then -- number[] -> T[] - resolve(typeUnit.node, n.node) + resolve(typeUnit.node, compiler.compileNode(n.node)) end end end @@ -56,39 +56,39 @@ function mt:resolve(uri, argNodes) for _, ufield in ipairs(typeUnit.fields) do local ufieldNode = compiler.compileNode(ufield.name) local uvalueNode = compiler.compileNode(ufield.extends) - if ufieldNode.type == 'doc.generic.name' and uvalueNode.type == 'doc.generic.name' then + if ufieldNode[1].type == 'doc.generic.name' and uvalueNode.type[1] == 'doc.generic.name' then -- { [number]: number} -> { [K]: V } local tfieldNode = vm.getTableKey(uri, node, 'any') local tvalueNode = vm.getTableValue(uri, node, 'any') - resolve(ufieldNode, tfieldNode) - resolve(uvalueNode, tvalueNode) + resolve(ufieldNode[1], tfieldNode) + resolve(uvalueNode[1], tvalueNode) else - if ufieldNode.type == 'doc.generic.name' then + if ufieldNode[1].type == 'doc.generic.name' then -- { [number]: number}|number[] -> { [K]: number } local tnode = vm.getTableKey(uri, node, uvalueNode) - resolve(ufieldNode, tnode) - else + resolve(ufieldNode[1], tnode) + elseif uvalueNode[1].type == 'doc.generic.name' then -- { [number]: number}|number[] -> { [number]: V } local tnode = vm.getTableValue(uri, node, ufieldNode) - resolve(uvalueNode, tnode) + resolve(uvalueNode[1], tnode) end end end end end - for i, node in ipairs(argNodes) do + for i, arg in ipairs(args) do local sign = self.signList[i] if not sign then break end - for n in nodeMgr.eachNode(sign) do - node = compiler.compileNode(node) - if node then + for n in nodeMgr.eachObject(sign) do + local argNode = compiler.compileNode(arg) + if argNode then if sign.optional then - node = nodeMgr.removeOptional(node) + argNode = nodeMgr.removeOptional(argNode) end - resolve(n, node) + resolve(n, argNode) end end end diff --git a/script/vm/type.lua b/script/vm/type.lua index 82c0e3f4..56964df8 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -1,6 +1,7 @@ local nodeMgr = require 'vm.node' local compiler = require 'vm.compiler' local globalMgr = require 'vm.global-manager' +local union = require 'vm.union' ---@class vm local vm = require 'vm.vm' @@ -22,7 +23,7 @@ function vm.isSubType(uri, child, parent, mark) end mark = mark or {} - for childNode in nodeMgr.eachNode(child) do + for childNode in nodeMgr.eachObject(child) do if childNode.type ~= 'global' or childNode.cate ~= 'type' then goto CONTINUE_CHILD @@ -31,7 +32,7 @@ function vm.isSubType(uri, child, parent, mark) return false end mark[childNode.name] = true - for parentNode in nodeMgr.eachNode(parent) do + for parentNode in nodeMgr.eachObject(parent) do if parentNode.type ~= 'global' or parentNode.cate ~= 'type' then goto CONTINUE_PARENT @@ -73,69 +74,77 @@ end ---@param uri uri ---@param tnode vm.node ---@param knode vm.node +---@return vm.node.union? function vm.getTableValue(uri, tnode, knode) - local result - for tn in nodeMgr.eachNode(tnode) do + local result = union() + for tn in nodeMgr.eachObject(tnode) do if tn.type == 'doc.type.table' then for _, field in ipairs(tn.fields) do if vm.isSubType(uri, compiler.compileNode(field.name), knode) then - result = nodeMgr.mergeNode(result, compiler.compileNode(field.extends)) + result:merge(compiler.compileNode(field.extends)) end end end if tn.type == 'doc.type.array' then - result = nodeMgr.mergeNode(result, compiler.compileNode(tn.node)) + result:merge(compiler.compileNode(tn.node)) end if tn.type == 'table' then for _, field in ipairs(tn) do if field.type == 'tableindex' then - result = nodeMgr.mergeNode(result, compiler.compileNode(field.value)) + result:merge(compiler.compileNode(field.value)) end if field.type == 'tablefield' then if vm.isSubType(uri, knode, 'string') then - result = nodeMgr.mergeNode(result, compiler.compileNode(field.value)) + result:merge(compiler.compileNode(field.value)) end end if field.type == 'tableexp' then if vm.isSubType(uri, knode, 'integer') and field.tindex == 1 then - result = nodeMgr.mergeNode(result, compiler.compileNode(field.value)) + result:merge(compiler.compileNode(field.value)) end end end end end + if result:isEmpty() then + return nil + end return result end ---@param uri uri ---@param tnode vm.node ---@param vnode vm.node +---@return vm.node.union? function vm.getTableKey(uri, tnode, vnode) - local result - for tn in nodeMgr.eachNode(tnode) do + local result = union() + for tn in nodeMgr.eachObject(tnode) do if tn.type == 'doc.type.table' then for _, field in ipairs(tn.fields) do if vm.isSubType(uri, compiler.compileNode(field.extends), vnode) then - result = nodeMgr.mergeNode(result, compiler.compileNode(field.name)) + result:merge(compiler.compileNode(field.name)) end end end if tn.type == 'doc.type.array' then - result = nodeMgr.mergeNode(result, globalMgr.getGlobal('type', 'integer')) + result:merge(globalMgr.getGlobal('type', 'integer')) end if tn.type == 'table' then for _, field in ipairs(tn) do if field.type == 'tableindex' then - result = nodeMgr.mergeNode(result, compiler.compileNode(field.index)) + result:merge(compiler.compileNode(field.index)) end if field.type == 'tablefield' then - result = nodeMgr.mergeNode(result, globalMgr.getGlobal('type', 'string')) + result:merge(globalMgr.getGlobal('type', 'string')) end if field.type == 'tableexp' then - result = nodeMgr.mergeNode(result, globalMgr.getGlobal('type', 'integer')) + result:merge(globalMgr.getGlobal('type', 'integer')) end end end end + if result:isEmpty() then + return nil + end return result end diff --git a/script/vm/union.lua b/script/vm/union.lua index b66b34db..5be52de9 100644 --- a/script/vm/union.lua +++ b/script/vm/union.lua @@ -45,6 +45,11 @@ function mt:copy() return createUnion(self, nil) end +---@return boolean +function mt:isEmpty() + return #self == 0 +end + ---@param source parser.object function mt:subscribeLocal(source) for _, c in ipairs(self) do diff --git a/script/vm/value.lua b/script/vm/value.lua index 4141ba8a..bf9e1423 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -10,7 +10,7 @@ local m = {} function m.test(source) local node = compiler.compileNode(source) local hasTrue, hasFalse - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'boolean' then if n[1] == true then hasTrue = true @@ -40,7 +40,7 @@ function m.test(source) end end ----@param v vm.node +---@param v vm.object ---@return string? local function getUnique(v) if v.type == 'local' then @@ -83,15 +83,15 @@ function m.equal(a, b) local nodeA = compiler.compileNode(a) local nodeB = compiler.compileNode(b) local mapA = {} - for n in nodeMgr.eachNode(nodeA) do - local unique = getUnique(n) + for obj in nodeMgr.eachObject(nodeA) do + local unique = getUnique(obj) if not unique then return nil end mapA[unique] = true end - for n in nodeMgr.eachNode(nodeB) do - local unique = getUnique(n) + for obj in nodeMgr.eachObject(nodeB) do + local unique = getUnique(obj) if not unique then return nil end @@ -107,7 +107,7 @@ end function m.getInteger(v) local node = compiler.compileNode(v) local result - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'integer' then if result then return nil @@ -135,7 +135,7 @@ end function m.getString(v) local node = compiler.compileNode(v) local result - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'string' then if result then return nil @@ -155,7 +155,7 @@ end function m.getNumber(v) local node = compiler.compileNode(v) local result - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'number' or n.type == 'integer' then if result then @@ -176,7 +176,7 @@ end function m.getBoolean(v) local node = compiler.compileNode(v) local result - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do if n.type == 'boolean' then if result then return nil @@ -196,7 +196,7 @@ end function m.getLiterals(v) local map local node = compiler.compileNode(v) - for n in nodeMgr.eachNode(node) do + for n in nodeMgr.eachObject(node) do local literal if n.type == 'boolean' or n.type == 'string' |