summaryrefslogtreecommitdiff
path: root/script/vm/compiler.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/compiler.lua')
-rw-r--r--script/vm/compiler.lua108
1 files changed, 53 insertions, 55 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 4621006e..51931984 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1032,6 +1032,56 @@ local function compileForVars(source, target)
end
---@param source parser.object
+local function compileFunctionParam(func, source)
+ -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
+ local funcNode = vm.compileNode(func)
+ local hasDocArg
+ for n in funcNode:eachObject() do
+ if n.type == 'doc.type.function' then
+ for index, arg in ipairs(n.args) do
+ if func.args[index] == source then
+ local argNode = vm.compileNode(arg)
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ end
+ end
+ return true
+ end
+ end
+ end
+ end
+ if func.parent.type == 'local' then
+ local refs = func.parent.ref
+ local findCall
+ if refs then
+ for i, ref in ipairs(refs) do
+ if ref.parent.type == 'call' then
+ findCall = ref.parent
+ break
+ end
+ end
+ end
+ if findCall and findCall.args then
+ local index
+ for i, arg in ipairs(source.parent) do
+ if arg == source then
+ index = i
+ break
+ end
+ end
+ if index then
+ local callerArg = findCall.args[index]
+ if callerArg then
+ vm.setNode(source, vm.compileNode(callerArg))
+ return true
+ end
+ end
+ end
+ end
+end
+
+---@param source parser.object
local function compileLocal(source)
local myNode = vm.setNode(source, source)
@@ -1070,63 +1120,11 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(setfield.node))
end
end
-
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
- -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
- local funcNode = vm.compileNode(func)
- local hasDocArg
- for n in funcNode:eachObject() do
- if n.type == 'doc.type.function' then
- for index, arg in ipairs(n.args) do
- if func.args[index] == source then
- local argNode = vm.compileNode(arg)
- for an in argNode:eachObject() do
- if an.type ~= 'doc.generic.name' then
- vm.setNode(source, an)
- end
- end
- hasDocArg = true
- end
- end
- end
- end
- if not hasDocArg
- and func.parent.type == 'local' then
- local refs = func.parent.ref
- local findCall
- if refs then
- for i, ref in ipairs(refs) do
- if ref.parent.type == 'call' then
- findCall = ref.parent
- break
- end
- end
- end
- if findCall and findCall.args then
- local index
- for i, arg in ipairs(source.parent) do
- if arg == source then
- index = i
- break
- end
- end
- if index then
- local callerArg = findCall.args[index]
- if callerArg then
- hasDocArg = true
- vm.setNode(source, vm.compileNode(callerArg))
- end
- end
- end
- end
- if not hasDocArg then
- local suc, node = plugin.dispatch("OnNodeCompileFunctionParam", guide.getUri(source), source)
- if suc and node then
- hasDocArg = true
- vm.setNode(source, node)
- end
- end
+ local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
+ local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
+ or compileFunctionParam(func, source)
if not hasDocArg then
vm.setNode(source, vm.declareGlobal('type', 'any'))
end