summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/vm/compiler.lua43
-rw-r--r--script/vm/generic.lua22
-rw-r--r--script/vm/infer.lua2
-rw-r--r--script/vm/node.lua2
-rw-r--r--script/vm/sign.lua2
5 files changed, 37 insertions, 34 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 2393544c..e0d41d48 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -269,6 +269,9 @@ local function getObjectSign(source)
return source._sign
end
+---@param func parser.object
+---@param index integer
+---@return vm.object?
function m.getReturnOfFunction(func, index)
if func.type == 'function' then
if not func._returns then
@@ -281,19 +284,18 @@ function m.getReturnOfFunction(func, index)
index = index,
}
end
- return m.compileNode(func._returns[index])
+ return func._returns[index]
end
if func.type == 'doc.type.function' then
local rtn = func.returns[index]
if not rtn then
return nil
end
- local rtnNode = m.compileNode(rtn)
local sign = getObjectSign(func)
if not sign then
- return rtnNode
+ return rtn
end
- return genericMgr(rtnNode, sign)
+ return genericMgr(rtn, sign)
end
end
@@ -376,7 +378,8 @@ local function getReturn(func, index, args)
for cnode in nodeMgr.eachObject(node) do
if cnode.type == 'function'
or cnode.type == 'doc.type.function' then
- local returnNode = m.getReturnOfFunction(cnode, index)
+ local returnObject = m.getReturnOfFunction(cnode, index)
+ local returnNode = m.compileNode(returnObject)
if returnNode then
for rnode in nodeMgr.eachObject(returnNode) do
if rnode.type == 'generic' then
@@ -580,14 +583,6 @@ end
---@param myIndex integer
local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
local valueMgr = require 'vm.value'
- if not myIndex then
- for i, carg in ipairs(call.args) do
- if carg == arg then
- myIndex = i - fixIndex
- break
- end
- end
- end
local eventIndex, eventMap
if call.args then
@@ -611,9 +606,11 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
end
end
if n.type == 'doc.type.function' then
- local event = m.compileNode(n.args[eventIndex])
+ local argNode = m.compileNode(n.args[eventIndex])
+ local event = argNode and argNode[1]
if not event
or not eventMap
+ or myIndex <= eventIndex
or event.type ~= 'doc.type.string'
or eventMap[event[1]] then
local farg = getFuncArg(n, myIndex)
@@ -629,8 +626,17 @@ end
---@param arg parser.object
---@param call parser.position
----@param index integer
+---@param index? integer
function m.compileCallArg(arg, call, index)
+ if not index then
+ for i, carg in ipairs(call.args) do
+ if carg == arg then
+ index = i
+ break
+ end
+ end
+ end
+
local callNode = m.compileNode(call.node)
compileCallArgNode(arg, call, callNode, 0, index)
@@ -638,7 +644,7 @@ function m.compileCallArg(arg, call, index)
or call.node.special == 'xpcall' then
local fixIndex = call.node.special == 'pcall' and 1 or 2
callNode = m.compileNode(call.args[1])
- compileCallArgNode(arg, call, callNode, fixIndex, index)
+ compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex)
end
return nodeMgr.getNode(arg)
end
@@ -889,11 +895,10 @@ local compilerSwitch = util.switch()
hasGeneric = true
end)
end
- local rtnNode = m.compileNode(rtn)
if hasGeneric then
- nodeMgr.setNode(source, genericMgr(rtnNode, sign))
+ nodeMgr.setNode(source, genericMgr(rtn, sign))
else
- nodeMgr.setNode(source, rtnNode)
+ nodeMgr.setNode(source, m.compileNode(rtn))
end
end
end
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index b9e01efd..351b820f 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -6,7 +6,7 @@ local union = require 'vm.union'
---@class vm.generic
---@field sign vm.sign
----@field proto vm.node
+---@field proto vm.object
local mt = {}
mt.__index = mt
mt.type = 'generic'
@@ -97,17 +97,13 @@ local function cloneObject(source, resolved)
}
for i, arg in ipairs(source.args) do
local newObj = cloneObject(arg, resolved)
- if arg.optional and newObj.type == 'vm.union' then
- newObj:addOptional()
- end
+ newObj.optional = arg.optional
newDocFunc.args[i] = newObj
end
for i, ret in ipairs(source.returns) do
local newObj = cloneObject(ret, resolved)
- newObj.parent = newDocFunc
- if ret.optional and newObj.type == 'vm.union' then
- newObj:addOptional()
- end
+ newObj.parent = newDocFunc
+ newObj.optional = ret.optional
newDocFunc.returns[i] = cloneObject(ret, resolved)
end
return newDocFunc
@@ -119,18 +115,20 @@ end
---@param args parser.object
---@return parser.object
function mt:resolve(uri, args)
- local compiler = require 'vm.compiler'
- local resolved = self.sign:resolve(uri, args)
+ local compiler = require 'vm.compiler'
+ local resolved = self.sign:resolve(uri, args)
+ local protoNode = compiler.compileNode(self.proto)
local result = union()
- for nd in nodeMgr.eachObject(self.proto) do
+ for nd in nodeMgr.eachObject(protoNode) do
local clonedNode = compiler.compileNode(cloneObject(nd, resolved))
result:merge(clonedNode)
end
return result
end
----@param proto vm.node
+---@param proto vm.object
---@param sign vm.sign
+---@return vm.generic
return function (proto, sign)
local generic = setmetatable({
sign = sign,
diff --git a/script/vm/infer.lua b/script/vm/infer.lua
index 1cb57e4c..16ffb4f3 100644
--- a/script/vm/infer.lua
+++ b/script/vm/infer.lua
@@ -102,7 +102,7 @@ local viewNodeSwitch = util.switch()
end)
: case 'generic'
: call(function (source, infer)
- return ('<%s>'):format(source.proto[1])
+ return m.getInfer(source.proto):view()
end)
: case 'doc.generic.name'
: call(function (source, infer)
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 12e2fd53..9f998181 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -25,7 +25,7 @@ function m.mergeNode(a, b)
end
---@param source vm.object
----@param node vm.node
+---@param node vm.node | vm.object
---@param cover? boolean
function m.setNode(source, node, cover)
if cover then
diff --git a/script/vm/sign.lua b/script/vm/sign.lua
index 9773627e..2f70437f 100644
--- a/script/vm/sign.lua
+++ b/script/vm/sign.lua
@@ -56,7 +56,7 @@ function mt:resolve(uri, args)
for _, ufield in ipairs(typeUnit.fields) do
local ufieldNode = compiler.compileNode(ufield.name)
local uvalueNode = compiler.compileNode(ufield.extends)
- if ufieldNode[1].type == 'doc.generic.name' and uvalueNode.type[1] == 'doc.generic.name' then
+ if ufieldNode[1].type == 'doc.generic.name' and uvalueNode[1].type == 'doc.generic.name' then
-- { [number]: number} -> { [K]: V }
local tfieldNode = vm.getTableKey(uri, node, 'any')
local tvalueNode = vm.getTableValue(uri, node, 'any')