From 5f0cba4fef00f1c426052af421443bd72d38c8aa Mon Sep 17 00:00:00 2001 From: fesily Date: Wed, 17 Jan 2024 14:54:57 +0800 Subject: plugin interface add OnNodeCompileFunctionParam --- script/plugins/nodeHelper.lua | 75 +++++++++++++++++++++++++++++++++++++++++++ script/vm/compiler.lua | 8 +++++ 2 files changed, 83 insertions(+) create mode 100644 script/plugins/nodeHelper.lua diff --git a/script/plugins/nodeHelper.lua b/script/plugins/nodeHelper.lua new file mode 100644 index 00000000..3f90b152 --- /dev/null +++ b/script/plugins/nodeHelper.lua @@ -0,0 +1,75 @@ +local vm = require 'vm' +local guide = require 'parser.guide' + +local _M = {} + +---@class node.match.pattern +---@field next node.match.pattern? + +local function deepCompare(source, pattern) + local type1, type2 = type(source), type(pattern) + if type1 ~= type2 then + return false + end + + if type1 ~= "table" then + return source == pattern + end + + for key2, value2 in pairs(pattern) do + local value1 = source[key2] + if value1 == nil or not deepCompare(value1, value2) then + return false + end + end + + return true +end + +---@param source parser.object +---@param pattern node.match.pattern +---@return boolean +function _M.matchPattern(source, pattern) + if source.type == 'local' then + if source.parent.type == 'funcargs' and source.parent.parent.type == 'function' then + for i, ref in ipairs(source.ref) do + if deepCompare(ref, pattern) then + return true + end + end + end + end + return false +end + +local vaildVarRegex = "()([a-zA-Z][a-zA-Z0-9_]*)()" +---创建类型 *.field.field形式的 pattern +---@param pattern string +---@return node.match.pattern?, string? +function _M.createFieldPattern(pattern) + local ret = { next = nil } + local next = ret + local init = 1 + while true do + local startpos, matched, endpos + if pattern:sub(1, 1) == "*" then + startpos, matched, endpos = init, "*", init + 1 + else + startpos, matched, endpos = vaildVarRegex:match(pattern, init) + end + if not startpos then + break + end + if startpos ~= init then + return nil, "invalid pattern" + end + local field = matched == "*" and { next = nil } + or { field = { type = 'field', matched }, type = 'getfield', next = nil } + next.next = field + next = field + pattern = pattern:sub(endpos) + end + return ret +end + +return _M diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 8a1fa96a..6b8b76cd 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -5,6 +5,7 @@ local rpath = require 'workspace.require-path' local files = require 'files' ---@class vm local vm = require 'vm.vm' +local plugin = require 'plugin' ---@class parser.object ---@field _compiledNodes boolean @@ -1090,6 +1091,13 @@ local function compileLocal(source) end end end + if not hasDocArg then + local suc, node = plugin.dispatch("OnNodeCompileFunctionParam", guide.getUri(source), source) + if suc and node then + hasDocArg = true + vm.setNode(source, node) + end + end if not hasDocArg then vm.setNode(source, vm.declareGlobal('type', 'any')) end -- cgit debian/1.2.3+git2.25.1-1-2-gaceb0 From d315c3a271ecffe4d98fabf3e4bac58f71aa8677 Mon Sep 17 00:00:00 2001 From: fesily Date: Wed, 17 Jan 2024 17:02:31 +0800 Subject: add test --- test/plugins/node/test.lua | 52 ++++++++++++++++++++++++++++++++++++++++++++++ test/plugins/test.lua | 1 + 2 files changed, 53 insertions(+) create mode 100644 test/plugins/node/test.lua diff --git a/test/plugins/node/test.lua b/test/plugins/node/test.lua new file mode 100644 index 00000000..81d9a1b3 --- /dev/null +++ b/test/plugins/node/test.lua @@ -0,0 +1,52 @@ +local files = require 'files' +local scope = require 'workspace.scope' +local nodeHelper = require 'plugins.nodeHelper' +local vm = require 'vm' +local guide = require 'parser.guide' + + +local pattern, msg = nodeHelper.createFieldPattern("*.components") +assert(pattern, msg) + +---@param source parser.object +function OnNodeCompileFunctionParam(uri, source) + --从该参数的使用模式来推导该类型 + if nodeHelper.matchPattern(source, pattern) then + local type = vm.declareGlobal('type', 'TestClass', TESTURI) + return vm.createNode(type, source) + end +end + +local myplugin = { OnNodeCompileFunctionParam = OnNodeCompileFunctionParam } + +---@diagnostic disable: await-in-sync +local function TestPlugin(script) + local prefix = [[ + ---@class TestClass + ---@field b string + ]] + ---@param checker fun(state:parser.state) + return function (plugin, checker) + files.open(TESTURI) + files.setText(TESTURI, prefix .. script, true) + scope.getScope(TESTURI):set('pluginInterface', plugin) + local state = files.getState(TESTURI) + assert(state) + checker(state) + files.remove(TESTURI) + end +end + +TestPlugin [[ + local function t(a) + a.components:test() + end +]](myplugin, function (state) + guide.eachSourceType(state.ast, 'local', function (src) + if guide.getKeyName(src) == 'a' then + local node = vm.compileNode(src) + assert(node) + assert(not vm.isUnknown(node)) + end + end) +end) diff --git a/test/plugins/test.lua b/test/plugins/test.lua index 655d30b8..53a92cc8 100644 --- a/test/plugins/test.lua +++ b/test/plugins/test.lua @@ -1,2 +1,3 @@ require 'plugins.ast.test' require 'plugins.ffi.test' +require 'plugins.node.test' \ No newline at end of file -- cgit debian/1.2.3+git2.25.1-1-2-gaceb0 From 856c27caa41f11cef728e45ff356be98023e201b Mon Sep 17 00:00:00 2001 From: fesily Date: Wed, 17 Jan 2024 17:05:27 +0800 Subject: format --- test/plugins/test.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/plugins/test.lua b/test/plugins/test.lua index 53a92cc8..b6533de9 100644 --- a/test/plugins/test.lua +++ b/test/plugins/test.lua @@ -1,3 +1,3 @@ require 'plugins.ast.test' require 'plugins.ffi.test' -require 'plugins.node.test' \ No newline at end of file +require 'plugins.node.test' -- cgit debian/1.2.3+git2.25.1-1-2-gaceb0 From 326a033a816f0bd5dcca6aeea1443935ee5a9e3c Mon Sep 17 00:00:00 2001 From: fesily Date: Wed, 17 Jan 2024 17:26:55 +0800 Subject: fix test --- test/plugins/node/test.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/plugins/node/test.lua b/test/plugins/node/test.lua index 81d9a1b3..466d0189 100644 --- a/test/plugins/node/test.lua +++ b/test/plugins/node/test.lua @@ -29,7 +29,7 @@ local function TestPlugin(script) return function (plugin, checker) files.open(TESTURI) files.setText(TESTURI, prefix .. script, true) - scope.getScope(TESTURI):set('pluginInterface', plugin) + scope.getScope(TESTURI):set('pluginInterfaces', plugin) local state = files.getState(TESTURI) assert(state) checker(state) -- cgit debian/1.2.3+git2.25.1-1-2-gaceb0 From 82c004e0ba166deaca282d2d6cfb8819eb830830 Mon Sep 17 00:00:00 2001 From: fesily Date: Fri, 19 Jan 2024 12:47:06 +0800 Subject: recode plugin interface --- script/plugin.lua | 57 ++++++++++++++++++++++++ script/vm/compiler.lua | 108 ++++++++++++++++++++++----------------------- script/workspace/scope.lua | 5 ++- test/plugins/node/test.lua | 10 +++-- 4 files changed, 120 insertions(+), 60 deletions(-) diff --git a/script/plugin.lua b/script/plugin.lua index 7a661e0d..b297cd9b 100644 --- a/script/plugin.lua +++ b/script/plugin.lua @@ -7,6 +7,15 @@ local scope = require 'workspace.scope' local ws = require 'workspace' local fs = require 'bee.filesystem' +---@class pluginInterfaces +local pluginConfigs = { + -- create plugin for vm module + VM = { + OnCompileFunctionParam = function (next, func, source) + end + } +} + ---@class plugin local m = {} @@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...) return failed == 0, res1, res2 end +function m.getVmPlugin(uri) + local scp = scope.getScope(uri) + local interfaces = scp:get('pluginInterfaces') + if not interfaces then + return + end + return interfaces.VM +end + ---@async ---@param scp scope local function checkTrustLoad(scp) @@ -78,6 +96,40 @@ local function checkTrustLoad(scp) return true end +local function createMethodGroup(interfaces, key, methods) + local methodGroup = {} + + for method in pairs(methods) do + local funcs = setmetatable({}, { + __call = function (t, next, ...) + if #t == 0 then + return next(...) + else + local result + for _, fn in ipairs(t) do + result = fn(next, ...) + end + return result + end + end + }) + for _, interface in ipairs(interfaces) do + local func = interface[method] + if not func then + local namespace = interface[key] + if namespace then + func = namespace[method] + end + end + if func then + funcs[#funcs+1] = func + end + end + methodGroup[method] = funcs + end + return #methodGroup>0 and methodGroup or nil +end + ---@param uri uri local function initPlugin(uri) await.call(function () ---@async @@ -148,6 +200,11 @@ local function initPlugin(uri) end interfaces[#interfaces+1] = interface end + + for key, config in pairs(pluginConfigs) do + interfaces[key] = createMethodGroup(interfaces, key, config) + end + ws.resetFiles(scp) end) end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 4621006e..51931984 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1031,6 +1031,56 @@ local function compileForVars(source, target) return false end +---@param source parser.object +local function compileFunctionParam(func, source) + -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number + local funcNode = vm.compileNode(func) + local hasDocArg + for n in funcNode:eachObject() do + if n.type == 'doc.type.function' then + for index, arg in ipairs(n.args) do + if func.args[index] == source then + local argNode = vm.compileNode(arg) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end + end + return true + end + end + end + end + if func.parent.type == 'local' then + local refs = func.parent.ref + local findCall + if refs then + for i, ref in ipairs(refs) do + if ref.parent.type == 'call' then + findCall = ref.parent + break + end + end + end + if findCall and findCall.args then + local index + for i, arg in ipairs(source.parent) do + if arg == source then + index = i + break + end + end + if index then + local callerArg = findCall.args[index] + if callerArg then + vm.setNode(source, vm.compileNode(callerArg)) + return true + end + end + end + end +end + ---@param source parser.object local function compileLocal(source) local myNode = vm.setNode(source, source) @@ -1070,63 +1120,11 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(setfield.node)) end end - if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then local func = source.parent.parent - -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number - local funcNode = vm.compileNode(func) - local hasDocArg - for n in funcNode:eachObject() do - if n.type == 'doc.type.function' then - for index, arg in ipairs(n.args) do - if func.args[index] == source then - local argNode = vm.compileNode(arg) - for an in argNode:eachObject() do - if an.type ~= 'doc.generic.name' then - vm.setNode(source, an) - end - end - hasDocArg = true - end - end - end - end - if not hasDocArg - and func.parent.type == 'local' then - local refs = func.parent.ref - local findCall - if refs then - for i, ref in ipairs(refs) do - if ref.parent.type == 'call' then - findCall = ref.parent - break - end - end - end - if findCall and findCall.args then - local index - for i, arg in ipairs(source.parent) do - if arg == source then - index = i - break - end - end - if index then - local callerArg = findCall.args[index] - if callerArg then - hasDocArg = true - vm.setNode(source, vm.compileNode(callerArg)) - end - end - end - end - if not hasDocArg then - local suc, node = plugin.dispatch("OnNodeCompileFunctionParam", guide.getUri(source), source) - if suc and node then - hasDocArg = true - vm.setNode(source, node) - end - end + local vmPlugin = plugin.getVmPlugin(guide.getUri(source)) + local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source) + or compileFunctionParam(func, source) if not hasDocArg then vm.setNode(source, vm.declareGlobal('type', 'any')) end diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua index da72a1eb..a158c8de 100644 --- a/script/workspace/scope.lua +++ b/script/workspace/scope.lua @@ -125,8 +125,9 @@ function mt:set(k, v) return v end ----@param k string ----@return any +---@generic T +---@param k `T` +---@return T function mt:get(k) return self._data[k] end diff --git a/test/plugins/node/test.lua b/test/plugins/node/test.lua index 466d0189..15e4d16c 100644 --- a/test/plugins/node/test.lua +++ b/test/plugins/node/test.lua @@ -9,15 +9,19 @@ local pattern, msg = nodeHelper.createFieldPattern("*.components") assert(pattern, msg) ---@param source parser.object -function OnNodeCompileFunctionParam(uri, source) +function OnCompileFunctionParam(next, func, source) + if next(func, source) then + return true + end --从该参数的使用模式来推导该类型 if nodeHelper.matchPattern(source, pattern) then local type = vm.declareGlobal('type', 'TestClass', TESTURI) - return vm.createNode(type, source) + vm.setNode(source, vm.createNode(type, source)) + return true end end -local myplugin = { OnNodeCompileFunctionParam = OnNodeCompileFunctionParam } +local myplugin = { OnCompileFunctionParam = OnCompileFunctionParam } ---@diagnostic disable: await-in-sync local function TestPlugin(script) -- cgit debian/1.2.3+git2.25.1-1-2-gaceb0 From 155f831624639c611891a3c1390ce3b19e92888b Mon Sep 17 00:00:00 2001 From: fesily Date: Sat, 20 Jan 2024 08:44:03 +0800 Subject: remove unused --- script/vm/compiler.lua | 1 - 1 file changed, 1 deletion(-) diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 51931984..0fe2efe8 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1035,7 +1035,6 @@ end local function compileFunctionParam(func, source) -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number local funcNode = vm.compileNode(func) - local hasDocArg for n in funcNode:eachObject() do if n.type == 'doc.type.function' then for index, arg in ipairs(n.args) do -- cgit debian/1.2.3+git2.25.1-1-2-gaceb0