summaryrefslogtreecommitdiff
path: root/script/vm/compiler.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/vm/compiler.lua')
-rw-r--r--script/vm/compiler.lua62
1 files changed, 23 insertions, 39 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 2253c83a..cb1fa663 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1031,6 +1031,7 @@ local function compileForVars(source, target)
return false
end
+---@param func parser.object
---@param source parser.object
local function compileFunctionParam(func, source)
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
@@ -1050,33 +1051,31 @@ local function compileFunctionParam(func, source)
end
end
end
- if func.parent.type == 'local' then
+
+ local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')
+ if derviationParam and func.parent.type == 'local' and func.parent.ref then
local refs = func.parent.ref
- local findCall
- if refs then
- for i, ref in ipairs(refs) do
- if ref.parent.type == 'call' then
- findCall = ref.parent
- break
- end
+ local finded
+ for _, ref in ipairs(refs) do
+ if ref.parent.type ~= 'call' then
+ goto continue
end
- end
- if findCall and findCall.args then
- local index
- for i, arg in ipairs(source.parent) do
- if arg == source then
- index = i
- break
- end
+ local caller = ref.parent
+ if not caller.args then
+ goto continue
end
- if index then
- local callerArg = findCall.args[index]
- if callerArg then
- vm.setNode(source, vm.compileNode(callerArg))
- return true
+ for index, arg in ipairs(source.parent) do
+ if arg == source then
+ local callerArg = caller.args[index]
+ if callerArg then
+ vm.setNode(source, vm.compileNode(callerArg))
+ finded = true
+ end
end
end
+ ::continue::
end
+ return finded
end
end
@@ -1121,24 +1120,9 @@ local function compileLocal(source)
end
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
- -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
- local funcNode = vm.compileNode(func)
- local hasDocArg
- for n in funcNode:eachObject() do
- if n.type == 'doc.type.function' then
- for index, arg in ipairs(n.args) do
- if func.args[index] == source then
- local argNode = vm.compileNode(arg)
- for an in argNode:eachObject() do
- if an.type ~= 'doc.generic.name' then
- vm.setNode(source, an)
- end
- end
- hasDocArg = true
- end
- end
- end
- end
+ local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
+ local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
+ or compileFunctionParam(func, source)
if not hasDocArg then
vm.setNode(source, vm.declareGlobal('type', 'any'))
end