diff options
Diffstat (limited to 'script/vm')
-rw-r--r-- | script/vm/compiler.lua | 43 | ||||
-rw-r--r-- | script/vm/generic.lua | 22 | ||||
-rw-r--r-- | script/vm/infer.lua | 2 | ||||
-rw-r--r-- | script/vm/node.lua | 2 | ||||
-rw-r--r-- | script/vm/sign.lua | 2 |
5 files changed, 37 insertions, 34 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 2393544c..e0d41d48 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -269,6 +269,9 @@ local function getObjectSign(source) return source._sign end +---@param func parser.object +---@param index integer +---@return vm.object? function m.getReturnOfFunction(func, index) if func.type == 'function' then if not func._returns then @@ -281,19 +284,18 @@ function m.getReturnOfFunction(func, index) index = index, } end - return m.compileNode(func._returns[index]) + return func._returns[index] end if func.type == 'doc.type.function' then local rtn = func.returns[index] if not rtn then return nil end - local rtnNode = m.compileNode(rtn) local sign = getObjectSign(func) if not sign then - return rtnNode + return rtn end - return genericMgr(rtnNode, sign) + return genericMgr(rtn, sign) end end @@ -376,7 +378,8 @@ local function getReturn(func, index, args) for cnode in nodeMgr.eachObject(node) do if cnode.type == 'function' or cnode.type == 'doc.type.function' then - local returnNode = m.getReturnOfFunction(cnode, index) + local returnObject = m.getReturnOfFunction(cnode, index) + local returnNode = m.compileNode(returnObject) if returnNode then for rnode in nodeMgr.eachObject(returnNode) do if rnode.type == 'generic' then @@ -580,14 +583,6 @@ end ---@param myIndex integer local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) local valueMgr = require 'vm.value' - if not myIndex then - for i, carg in ipairs(call.args) do - if carg == arg then - myIndex = i - fixIndex - break - end - end - end local eventIndex, eventMap if call.args then @@ -611,9 +606,11 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) end end if n.type == 'doc.type.function' then - local event = m.compileNode(n.args[eventIndex]) + local argNode = m.compileNode(n.args[eventIndex]) + local event = argNode and argNode[1] if not event or not eventMap + or myIndex <= eventIndex or event.type ~= 'doc.type.string' or eventMap[event[1]] then local farg = getFuncArg(n, myIndex) @@ -629,8 +626,17 @@ end ---@param arg parser.object ---@param call parser.position ----@param index integer +---@param index? integer function m.compileCallArg(arg, call, index) + if not index then + for i, carg in ipairs(call.args) do + if carg == arg then + index = i + break + end + end + end + local callNode = m.compileNode(call.node) compileCallArgNode(arg, call, callNode, 0, index) @@ -638,7 +644,7 @@ function m.compileCallArg(arg, call, index) or call.node.special == 'xpcall' then local fixIndex = call.node.special == 'pcall' and 1 or 2 callNode = m.compileNode(call.args[1]) - compileCallArgNode(arg, call, callNode, fixIndex, index) + compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex) end return nodeMgr.getNode(arg) end @@ -889,11 +895,10 @@ local compilerSwitch = util.switch() hasGeneric = true end) end - local rtnNode = m.compileNode(rtn) if hasGeneric then - nodeMgr.setNode(source, genericMgr(rtnNode, sign)) + nodeMgr.setNode(source, genericMgr(rtn, sign)) else - nodeMgr.setNode(source, rtnNode) + nodeMgr.setNode(source, m.compileNode(rtn)) end end end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index b9e01efd..351b820f 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -6,7 +6,7 @@ local union = require 'vm.union' ---@class vm.generic ---@field sign vm.sign ----@field proto vm.node +---@field proto vm.object local mt = {} mt.__index = mt mt.type = 'generic' @@ -97,17 +97,13 @@ local function cloneObject(source, resolved) } for i, arg in ipairs(source.args) do local newObj = cloneObject(arg, resolved) - if arg.optional and newObj.type == 'vm.union' then - newObj:addOptional() - end + newObj.optional = arg.optional newDocFunc.args[i] = newObj end for i, ret in ipairs(source.returns) do local newObj = cloneObject(ret, resolved) - newObj.parent = newDocFunc - if ret.optional and newObj.type == 'vm.union' then - newObj:addOptional() - end + newObj.parent = newDocFunc + newObj.optional = ret.optional newDocFunc.returns[i] = cloneObject(ret, resolved) end return newDocFunc @@ -119,18 +115,20 @@ end ---@param args parser.object ---@return parser.object function mt:resolve(uri, args) - local compiler = require 'vm.compiler' - local resolved = self.sign:resolve(uri, args) + local compiler = require 'vm.compiler' + local resolved = self.sign:resolve(uri, args) + local protoNode = compiler.compileNode(self.proto) local result = union() - for nd in nodeMgr.eachObject(self.proto) do + for nd in nodeMgr.eachObject(protoNode) do local clonedNode = compiler.compileNode(cloneObject(nd, resolved)) result:merge(clonedNode) end return result end ----@param proto vm.node +---@param proto vm.object ---@param sign vm.sign +---@return vm.generic return function (proto, sign) local generic = setmetatable({ sign = sign, diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 1cb57e4c..16ffb4f3 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -102,7 +102,7 @@ local viewNodeSwitch = util.switch() end) : case 'generic' : call(function (source, infer) - return ('<%s>'):format(source.proto[1]) + return m.getInfer(source.proto):view() end) : case 'doc.generic.name' : call(function (source, infer) diff --git a/script/vm/node.lua b/script/vm/node.lua index 12e2fd53..9f998181 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -25,7 +25,7 @@ function m.mergeNode(a, b) end ---@param source vm.object ----@param node vm.node +---@param node vm.node | vm.object ---@param cover? boolean function m.setNode(source, node, cover) if cover then diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 9773627e..2f70437f 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -56,7 +56,7 @@ function mt:resolve(uri, args) for _, ufield in ipairs(typeUnit.fields) do local ufieldNode = compiler.compileNode(ufield.name) local uvalueNode = compiler.compileNode(ufield.extends) - if ufieldNode[1].type == 'doc.generic.name' and uvalueNode.type[1] == 'doc.generic.name' then + if ufieldNode[1].type == 'doc.generic.name' and uvalueNode[1].type == 'doc.generic.name' then -- { [number]: number} -> { [K]: V } local tfieldNode = vm.getTableKey(uri, node, 'any') local tvalueNode = vm.getTableValue(uri, node, 'any') |