summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2024-01-24 14:25:00 +0800
committerGitHub <noreply@github.com>2024-01-24 14:25:00 +0800
commit0a962fcc5e52114d73d65928898e5a32b1e1bc7d (patch)
treec6f4b6b2295888cc64007ba258df8b59d5a3942a
parentea3aed4549900437793ddd7beafa88fa4ce10061 (diff)
parent155f831624639c611891a3c1390ce3b19e92888b (diff)
downloadlua-language-server-0a962fcc5e52114d73d65928898e5a32b1e1bc7d.zip
Merge pull request #2486 from fesily/plugin-OnNodeCompileFunctionParam
Plugin on node compile function param
-rw-r--r--script/plugin.lua57
-rw-r--r--script/plugins/nodeHelper.lua75
-rw-r--r--script/vm/compiler.lua101
-rw-r--r--script/workspace/scope.lua5
-rw-r--r--test/plugins/node/test.lua56
-rw-r--r--test/plugins/test.lua1
6 files changed, 245 insertions, 50 deletions
diff --git a/script/plugin.lua b/script/plugin.lua
index 7a661e0d..b297cd9b 100644
--- a/script/plugin.lua
+++ b/script/plugin.lua
@@ -7,6 +7,15 @@ local scope = require 'workspace.scope'
local ws = require 'workspace'
local fs = require 'bee.filesystem'
+---@class pluginInterfaces
+local pluginConfigs = {
+ -- create plugin for vm module
+ VM = {
+ OnCompileFunctionParam = function (next, func, source)
+ end
+ }
+}
+
---@class plugin
local m = {}
@@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...)
return failed == 0, res1, res2
end
+function m.getVmPlugin(uri)
+ local scp = scope.getScope(uri)
+ local interfaces = scp:get('pluginInterfaces')
+ if not interfaces then
+ return
+ end
+ return interfaces.VM
+end
+
---@async
---@param scp scope
local function checkTrustLoad(scp)
@@ -78,6 +96,40 @@ local function checkTrustLoad(scp)
return true
end
+local function createMethodGroup(interfaces, key, methods)
+ local methodGroup = {}
+
+ for method in pairs(methods) do
+ local funcs = setmetatable({}, {
+ __call = function (t, next, ...)
+ if #t == 0 then
+ return next(...)
+ else
+ local result
+ for _, fn in ipairs(t) do
+ result = fn(next, ...)
+ end
+ return result
+ end
+ end
+ })
+ for _, interface in ipairs(interfaces) do
+ local func = interface[method]
+ if not func then
+ local namespace = interface[key]
+ if namespace then
+ func = namespace[method]
+ end
+ end
+ if func then
+ funcs[#funcs+1] = func
+ end
+ end
+ methodGroup[method] = funcs
+ end
+ return #methodGroup>0 and methodGroup or nil
+end
+
---@param uri uri
local function initPlugin(uri)
await.call(function () ---@async
@@ -148,6 +200,11 @@ local function initPlugin(uri)
end
interfaces[#interfaces+1] = interface
end
+
+ for key, config in pairs(pluginConfigs) do
+ interfaces[key] = createMethodGroup(interfaces, key, config)
+ end
+
ws.resetFiles(scp)
end)
end
diff --git a/script/plugins/nodeHelper.lua b/script/plugins/nodeHelper.lua
new file mode 100644
index 00000000..3f90b152
--- /dev/null
+++ b/script/plugins/nodeHelper.lua
@@ -0,0 +1,75 @@
+local vm = require 'vm'
+local guide = require 'parser.guide'
+
+local _M = {}
+
+---@class node.match.pattern
+---@field next node.match.pattern?
+
+local function deepCompare(source, pattern)
+ local type1, type2 = type(source), type(pattern)
+ if type1 ~= type2 then
+ return false
+ end
+
+ if type1 ~= "table" then
+ return source == pattern
+ end
+
+ for key2, value2 in pairs(pattern) do
+ local value1 = source[key2]
+ if value1 == nil or not deepCompare(value1, value2) then
+ return false
+ end
+ end
+
+ return true
+end
+
+---@param source parser.object
+---@param pattern node.match.pattern
+---@return boolean
+function _M.matchPattern(source, pattern)
+ if source.type == 'local' then
+ if source.parent.type == 'funcargs' and source.parent.parent.type == 'function' then
+ for i, ref in ipairs(source.ref) do
+ if deepCompare(ref, pattern) then
+ return true
+ end
+ end
+ end
+ end
+ return false
+end
+
+local vaildVarRegex = "()([a-zA-Z][a-zA-Z0-9_]*)()"
+---创建类型 *.field.field形式的 pattern
+---@param pattern string
+---@return node.match.pattern?, string?
+function _M.createFieldPattern(pattern)
+ local ret = { next = nil }
+ local next = ret
+ local init = 1
+ while true do
+ local startpos, matched, endpos
+ if pattern:sub(1, 1) == "*" then
+ startpos, matched, endpos = init, "*", init + 1
+ else
+ startpos, matched, endpos = vaildVarRegex:match(pattern, init)
+ end
+ if not startpos then
+ break
+ end
+ if startpos ~= init then
+ return nil, "invalid pattern"
+ end
+ local field = matched == "*" and { next = nil }
+ or { field = { type = 'field', matched }, type = 'getfield', next = nil }
+ next.next = field
+ next = field
+ pattern = pattern:sub(endpos)
+ end
+ return ret
+end
+
+return _M
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 2222fa9b..0fe2efe8 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -5,6 +5,7 @@ local rpath = require 'workspace.require-path'
local files = require 'files'
---@class vm
local vm = require 'vm.vm'
+local plugin = require 'plugin'
---@class parser.object
---@field _compiledNodes boolean
@@ -1031,6 +1032,55 @@ local function compileForVars(source, target)
end
---@param source parser.object
+local function compileFunctionParam(func, source)
+ -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
+ local funcNode = vm.compileNode(func)
+ for n in funcNode:eachObject() do
+ if n.type == 'doc.type.function' then
+ for index, arg in ipairs(n.args) do
+ if func.args[index] == source then
+ local argNode = vm.compileNode(arg)
+ for an in argNode:eachObject() do
+ if an.type ~= 'doc.generic.name' then
+ vm.setNode(source, an)
+ end
+ end
+ return true
+ end
+ end
+ end
+ end
+ if func.parent.type == 'local' then
+ local refs = func.parent.ref
+ local findCall
+ if refs then
+ for i, ref in ipairs(refs) do
+ if ref.parent.type == 'call' then
+ findCall = ref.parent
+ break
+ end
+ end
+ end
+ if findCall and findCall.args then
+ local index
+ for i, arg in ipairs(source.parent) do
+ if arg == source then
+ index = i
+ break
+ end
+ end
+ if index then
+ local callerArg = findCall.args[index]
+ if callerArg then
+ vm.setNode(source, vm.compileNode(callerArg))
+ return true
+ end
+ end
+ end
+ end
+end
+
+---@param source parser.object
local function compileLocal(source)
local myNode = vm.setNode(source, source)
@@ -1069,56 +1119,11 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(setfield.node))
end
end
-
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
- -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
- local funcNode = vm.compileNode(func)
- local hasDocArg
- for n in funcNode:eachObject() do
- if n.type == 'doc.type.function' then
- for index, arg in ipairs(n.args) do
- if func.args[index] == source then
- local argNode = vm.compileNode(arg)
- for an in argNode:eachObject() do
- if an.type ~= 'doc.generic.name' then
- vm.setNode(source, an)
- end
- end
- hasDocArg = true
- end
- end
- end
- end
- if not hasDocArg
- and func.parent.type == 'local' then
- local refs = func.parent.ref
- local findCall
- if refs then
- for i, ref in ipairs(refs) do
- if ref.parent.type == 'call' then
- findCall = ref.parent
- break
- end
- end
- end
- if findCall and findCall.args then
- local index
- for i, arg in ipairs(source.parent) do
- if arg == source then
- index = i
- break
- end
- end
- if index then
- local callerArg = findCall.args[index]
- if callerArg then
- hasDocArg = true
- vm.setNode(source, vm.compileNode(callerArg))
- end
- end
- end
- end
+ local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
+ local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
+ or compileFunctionParam(func, source)
if not hasDocArg then
vm.setNode(source, vm.declareGlobal('type', 'any'))
end
diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua
index da72a1eb..a158c8de 100644
--- a/script/workspace/scope.lua
+++ b/script/workspace/scope.lua
@@ -125,8 +125,9 @@ function mt:set(k, v)
return v
end
----@param k string
----@return any
+---@generic T
+---@param k `T`
+---@return T
function mt:get(k)
return self._data[k]
end
diff --git a/test/plugins/node/test.lua b/test/plugins/node/test.lua
new file mode 100644
index 00000000..15e4d16c
--- /dev/null
+++ b/test/plugins/node/test.lua
@@ -0,0 +1,56 @@
+local files = require 'files'
+local scope = require 'workspace.scope'
+local nodeHelper = require 'plugins.nodeHelper'
+local vm = require 'vm'
+local guide = require 'parser.guide'
+
+
+local pattern, msg = nodeHelper.createFieldPattern("*.components")
+assert(pattern, msg)
+
+---@param source parser.object
+function OnCompileFunctionParam(next, func, source)
+ if next(func, source) then
+ return true
+ end
+ --从该参数的使用模式来推导该类型
+ if nodeHelper.matchPattern(source, pattern) then
+ local type = vm.declareGlobal('type', 'TestClass', TESTURI)
+ vm.setNode(source, vm.createNode(type, source))
+ return true
+ end
+end
+
+local myplugin = { OnCompileFunctionParam = OnCompileFunctionParam }
+
+---@diagnostic disable: await-in-sync
+local function TestPlugin(script)
+ local prefix = [[
+ ---@class TestClass
+ ---@field b string
+ ]]
+ ---@param checker fun(state:parser.state)
+ return function (plugin, checker)
+ files.open(TESTURI)
+ files.setText(TESTURI, prefix .. script, true)
+ scope.getScope(TESTURI):set('pluginInterfaces', plugin)
+ local state = files.getState(TESTURI)
+ assert(state)
+ checker(state)
+ files.remove(TESTURI)
+ end
+end
+
+TestPlugin [[
+ local function t(a)
+ a.components:test()
+ end
+]](myplugin, function (state)
+ guide.eachSourceType(state.ast, 'local', function (src)
+ if guide.getKeyName(src) == 'a' then
+ local node = vm.compileNode(src)
+ assert(node)
+ assert(not vm.isUnknown(node))
+ end
+ end)
+end)
diff --git a/test/plugins/test.lua b/test/plugins/test.lua
index 655d30b8..b6533de9 100644
--- a/test/plugins/test.lua
+++ b/test/plugins/test.lua
@@ -1,2 +1,3 @@
require 'plugins.ast.test'
require 'plugins.ffi.test'
+require 'plugins.node.test'