diff options
Diffstat (limited to 'script')
-rw-r--r-- | script/plugin.lua | 57 | ||||
-rw-r--r-- | script/vm/compiler.lua | 108 | ||||
-rw-r--r-- | script/workspace/scope.lua | 5 |
3 files changed, 113 insertions, 57 deletions
diff --git a/script/plugin.lua b/script/plugin.lua index 7a661e0d..b297cd9b 100644 --- a/script/plugin.lua +++ b/script/plugin.lua @@ -7,6 +7,15 @@ local scope = require 'workspace.scope' local ws = require 'workspace' local fs = require 'bee.filesystem' +---@class pluginInterfaces +local pluginConfigs = { + -- create plugin for vm module + VM = { + OnCompileFunctionParam = function (next, func, source) + end + } +} + ---@class plugin local m = {} @@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...) return failed == 0, res1, res2 end +function m.getVmPlugin(uri) + local scp = scope.getScope(uri) + local interfaces = scp:get('pluginInterfaces') + if not interfaces then + return + end + return interfaces.VM +end + ---@async ---@param scp scope local function checkTrustLoad(scp) @@ -78,6 +96,40 @@ local function checkTrustLoad(scp) return true end +local function createMethodGroup(interfaces, key, methods) + local methodGroup = {} + + for method in pairs(methods) do + local funcs = setmetatable({}, { + __call = function (t, next, ...) + if #t == 0 then + return next(...) + else + local result + for _, fn in ipairs(t) do + result = fn(next, ...) + end + return result + end + end + }) + for _, interface in ipairs(interfaces) do + local func = interface[method] + if not func then + local namespace = interface[key] + if namespace then + func = namespace[method] + end + end + if func then + funcs[#funcs+1] = func + end + end + methodGroup[method] = funcs + end + return #methodGroup>0 and methodGroup or nil +end + ---@param uri uri local function initPlugin(uri) await.call(function () ---@async @@ -148,6 +200,11 @@ local function initPlugin(uri) end interfaces[#interfaces+1] = interface end + + for key, config in pairs(pluginConfigs) do + interfaces[key] = createMethodGroup(interfaces, key, config) + end + ws.resetFiles(scp) end) end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 4621006e..51931984 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1032,6 +1032,56 @@ local function compileForVars(source, target) end ---@param source parser.object +local function compileFunctionParam(func, source) + -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number + local funcNode = vm.compileNode(func) + local hasDocArg + for n in funcNode:eachObject() do + if n.type == 'doc.type.function' then + for index, arg in ipairs(n.args) do + if func.args[index] == source then + local argNode = vm.compileNode(arg) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end + end + return true + end + end + end + end + if func.parent.type == 'local' then + local refs = func.parent.ref + local findCall + if refs then + for i, ref in ipairs(refs) do + if ref.parent.type == 'call' then + findCall = ref.parent + break + end + end + end + if findCall and findCall.args then + local index + for i, arg in ipairs(source.parent) do + if arg == source then + index = i + break + end + end + if index then + local callerArg = findCall.args[index] + if callerArg then + vm.setNode(source, vm.compileNode(callerArg)) + return true + end + end + end + end +end + +---@param source parser.object local function compileLocal(source) local myNode = vm.setNode(source, source) @@ -1070,63 +1120,11 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(setfield.node)) end end - if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then local func = source.parent.parent - -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number - local funcNode = vm.compileNode(func) - local hasDocArg - for n in funcNode:eachObject() do - if n.type == 'doc.type.function' then - for index, arg in ipairs(n.args) do - if func.args[index] == source then - local argNode = vm.compileNode(arg) - for an in argNode:eachObject() do - if an.type ~= 'doc.generic.name' then - vm.setNode(source, an) - end - end - hasDocArg = true - end - end - end - end - if not hasDocArg - and func.parent.type == 'local' then - local refs = func.parent.ref - local findCall - if refs then - for i, ref in ipairs(refs) do - if ref.parent.type == 'call' then - findCall = ref.parent - break - end - end - end - if findCall and findCall.args then - local index - for i, arg in ipairs(source.parent) do - if arg == source then - index = i - break - end - end - if index then - local callerArg = findCall.args[index] - if callerArg then - hasDocArg = true - vm.setNode(source, vm.compileNode(callerArg)) - end - end - end - end - if not hasDocArg then - local suc, node = plugin.dispatch("OnNodeCompileFunctionParam", guide.getUri(source), source) - if suc and node then - hasDocArg = true - vm.setNode(source, node) - end - end + local vmPlugin = plugin.getVmPlugin(guide.getUri(source)) + local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source) + or compileFunctionParam(func, source) if not hasDocArg then vm.setNode(source, vm.declareGlobal('type', 'any')) end diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua index da72a1eb..a158c8de 100644 --- a/script/workspace/scope.lua +++ b/script/workspace/scope.lua @@ -125,8 +125,9 @@ function mt:set(k, v) return v end ----@param k string ----@return any +---@generic T +---@param k `T` +---@return T function mt:get(k) return self._data[k] end |