summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/plugin.lua57
-rw-r--r--script/vm/compiler.lua108
-rw-r--r--script/workspace/scope.lua5
3 files changed, 113 insertions, 57 deletions
diff --git a/script/plugin.lua b/script/plugin.lua
index 7a661e0d..b297cd9b 100644
--- a/script/plugin.lua
+++ b/script/plugin.lua
@@ -7,6 +7,15 @@ local scope = require 'workspace.scope'
local ws = require 'workspace'
local fs = require 'bee.filesystem'
+---@class pluginInterfaces
+local pluginConfigs = {
+ -- create plugin for vm module
+ VM = {
+ OnCompileFunctionParam = function (next, func, source)
+ end
+ }
+}
+
---@class plugin
local m = {}
@@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...)
return failed == 0, res1, res2
end
+function m.getVmPlugin(uri)
+ local scp = scope.getScope(uri)
+ local interfaces = scp:get('pluginInterfaces')
+ if not interfaces then
+ return
+ end
+ return interfaces.VM
+end
+
---@async
---@param scp scope
local function checkTrustLoad(scp)
@@ -78,6 +96,40 @@ local function checkTrustLoad(scp)
return true
end
+local function createMethodGroup(interfaces, key, methods)
+ local methodGroup = {}
+
+ for method in pairs(methods) do
+ local funcs = setmetatable({}, {
+ __call = function (t, next, ...)
+ if #t == 0 then
+ return next(...)
+ else
+ local result
+ for _, fn in ipairs(t) do
+ result = fn(next, ...)
+ end
+ return result
+ end
+ end
+ })
+ for _, interface in ipairs(interfaces) do
+ local func = interface[method]
+ if not func then
+ local namespace = interface[key]
+ if namespace then
+ func = namespace[method]
+ end
+ end
+ if func then
+ funcs[#funcs+1] = func
+ end
+ end
+ methodGroup[method] = funcs
+ end
+ return #methodGroup>0 and methodGroup or nil
+end
+
---@param uri uri
local function initPlugin(uri)
await.call(function () ---@async
@@ -148,6 +200,11 @@ local function initPlugin(uri)
end
interfaces[#interfaces+1] = interface
end
+
+ for key, config in pairs(pluginConfigs) do
+ interfaces[key] = createMethodGroup(interfaces, key, config)
+ end
+
ws.resetFiles(scp)
end)
end
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 4621006e..51931984 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1032,6 +1032,56 @@ local function compileForVars(source, target)
end
---@param source parser.object
+local function compileFunctionParam(func, source)
+ -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
+ local funcNode = vm.compileNode(func)
+ local hasDocArg
+ for n in funcNode:eachObject() do
+ if n.type == 'doc.type.function' then
+ for index, arg in ipairs(n.args) do
+ if func.args[index] == source then
+ local argNode = vm.compileNode(arg)
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ end
+ end
+ return true
+ end
+ end
+ end
+ end
+ if func.parent.type == 'local' then
+ local refs = func.parent.ref
+ local findCall
+ if refs then
+ for i, ref in ipairs(refs) do
+ if ref.parent.type == 'call' then
+ findCall = ref.parent
+ break
+ end
+ end
+ end
+ if findCall and findCall.args then
+ local index
+ for i, arg in ipairs(source.parent) do
+ if arg == source then
+ index = i
+ break
+ end
+ end
+ if index then
+ local callerArg = findCall.args[index]
+ if callerArg then
+ vm.setNode(source, vm.compileNode(callerArg))
+ return true
+ end
+ end
+ end
+ end
+end
+
+---@param source parser.object
local function compileLocal(source)
local myNode = vm.setNode(source, source)
@@ -1070,63 +1120,11 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(setfield.node))
end
end
-
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
- -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
- local funcNode = vm.compileNode(func)
- local hasDocArg
- for n in funcNode:eachObject() do
- if n.type == 'doc.type.function' then
- for index, arg in ipairs(n.args) do
- if func.args[index] == source then
- local argNode = vm.compileNode(arg)
- for an in argNode:eachObject() do
- if an.type ~= 'doc.generic.name' then
- vm.setNode(source, an)
- end
- end
- hasDocArg = true
- end
- end
- end
- end
- if not hasDocArg
- and func.parent.type == 'local' then
- local refs = func.parent.ref
- local findCall
- if refs then
- for i, ref in ipairs(refs) do
- if ref.parent.type == 'call' then
- findCall = ref.parent
- break
- end
- end
- end
- if findCall and findCall.args then
- local index
- for i, arg in ipairs(source.parent) do
- if arg == source then
- index = i
- break
- end
- end
- if index then
- local callerArg = findCall.args[index]
- if callerArg then
- hasDocArg = true
- vm.setNode(source, vm.compileNode(callerArg))
- end
- end
- end
- end
- if not hasDocArg then
- local suc, node = plugin.dispatch("OnNodeCompileFunctionParam", guide.getUri(source), source)
- if suc and node then
- hasDocArg = true
- vm.setNode(source, node)
- end
- end
+ local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
+ local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
+ or compileFunctionParam(func, source)
if not hasDocArg then
vm.setNode(source, vm.declareGlobal('type', 'any'))
end
diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua
index da72a1eb..a158c8de 100644
--- a/script/workspace/scope.lua
+++ b/script/workspace/scope.lua
@@ -125,8 +125,9 @@ function mt:set(k, v)
return v
end
----@param k string
----@return any
+---@generic T
+---@param k `T`
+---@return T
function mt:get(k)
return self._data[k]
end