summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/generic.lua8
-rw-r--r--script/core/noder.lua212
-rw-r--r--script/core/searcher.lua25
-rw-r--r--test/definition/init.lua1
4 files changed, 141 insertions, 105 deletions
diff --git a/script/core/generic.lua b/script/core/generic.lua
index c8dcf66d..78c3eff9 100644
--- a/script/core/generic.lua
+++ b/script/core/generic.lua
@@ -52,7 +52,7 @@ local function createValue(closure, proto, callback, road)
end
local value = instantValue(closure, proto)
value.types = types
- noder.compileNode(value)
+ noder.compileNode(noder.getNoders(proto), value)
return value
end
if proto.type == 'doc.type.name' then
@@ -64,7 +64,7 @@ local function createValue(closure, proto, callback, road)
if callback then
callback(road, key, proto)
end
- noder.compileNode(value)
+ noder.compileNode(noder.getNoders(proto), value)
return value
end
if proto.type == 'doc.type.function' then
@@ -92,7 +92,7 @@ local function createValue(closure, proto, callback, road)
value.args = args
value.returns = returns
value.isGeneric = true
- noder.pushSource(value)
+ noder.pushSource(noder.getNoders(proto), value)
return value
end
if proto.type == 'doc.type.array' then
@@ -221,7 +221,7 @@ function m.createClosure(proto, call)
return nil
end
- noder.compileNode(closure)
+ noder.compileNode(noder.getNoders(proto), closure)
return closure
end
diff --git a/script/core/noder.lua b/script/core/noder.lua
index a55e93da..fa9d7d72 100644
--- a/script/core/noder.lua
+++ b/script/core/noder.lua
@@ -1,7 +1,6 @@
local util = require 'utility'
local guide = require 'parser.guide'
-local Noders
local LastIDCache = {}
local FirstIDCache = {}
local SPLIT_CHAR = '\x1F'
@@ -13,16 +12,31 @@ local PARAM_INDEX = SPLIT_CHAR .. '@'
local TABLE_KEY = SPLIT_CHAR .. '<'
local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR
+---@class node
+-- 当前节点的id
+---@field id string
+-- 使用该ID的单元
+---@field sources parser.guide.object[]
+-- 前进的关联ID
+---@field forward string[]
+-- 后退的关联ID
+---@field backward string[]
+-- 函数调用参数信息(用于泛型)
+---@field call parser.guide.object
+
+---@alias noders table<string, node[]>
+
---创建source的链接信息
+---@param noders noders
---@param id string
---@return node
-local function getNode(id)
- if not Noders[id] then
- Noders[id] = {
+local function getNode(noders, id)
+ if not noders[id] then
+ noders[id] = {
id = id,
}
end
- return Noders[id]
+ return noders[id]
end
---是否是全局变量(包括 _G.XXX 形式)
@@ -133,10 +147,10 @@ local function getKey(source)
local name = source[1]
return name, nil
elseif source.type == 'doc.type.name' then
+ local name = source[1]
if source.typeGeneric then
- return source.start, nil
+ return source.typeGeneric[name][1].start, nil
else
- local name = source[1]
return name, nil
end
elseif source.type == 'doc.class'
@@ -296,51 +310,49 @@ local function getID(source)
end
---添加关联的前进ID
+---@param noders noders
---@param id string
---@param forwardID string
-local function pushForward(id, forwardID)
+local function pushForward(noders, id, forwardID)
if not id
or not forwardID
or forwardID == ''
or id == forwardID then
return
end
- local node = getNode(id)
+ local node = getNode(noders, id)
if not node.forward then
node.forward = {}
end
+ if node.forward[forwardID] then
+ return
+ end
+ node.forward[forwardID] = true
node.forward[#node.forward+1] = forwardID
end
---添加关联的后退ID
+---@param noders noders
---@param id string
---@param backwardID string
-local function pushBackward(id, backwardID)
+local function pushBackward(noders, id, backwardID)
if not id
or not backwardID
or backwardID == ''
or id == backwardID then
return
end
- local node = getNode(id)
+ local node = getNode(noders, id)
if not node.backward then
node.backward = {}
end
+ if node.backward[backwardID] then
+ return
+ end
+ node.backward[backwardID] = true
node.backward[#node.backward+1] = backwardID
end
----@class node
--- 当前节点的id
----@field id string
--- 使用该ID的单元
----@field sources parser.guide.object[]
--- 前进的关联ID
----@field forward string[]
--- 后退的关联ID
----@field backward string[]
--- 函数调用参数信息(用于泛型)
----@field call parser.guide.object
-
local m = {}
m.SPLIT_CHAR = SPLIT_CHAR
@@ -370,27 +382,29 @@ local function getDocStateWithoutCrossFunction(obj)
end
---添加关联单元
+---@param noders noders
---@param source parser.guide.object
-function m.pushSource(source)
+function m.pushSource(noders, source)
local id = m.getID(source)
if not id then
return
end
- local node = getNode(id)
+ local node = getNode(noders, id)
if not node.sources then
node.sources = {}
end
node.sources[#node.sources+1] = source
end
+---@param noders noders
---@param source parser.guide.object
---@return parser.guide.object[]
-function m.compileNode(source)
+function m.compileNode(noders, source)
local id = getID(source)
if source.value then
-- x = y : x -> y
- pushForward(id, getID(source.value))
- pushBackward(getID(source.value), id)
+ pushForward(noders, id, getID(source.value))
+ pushBackward(noders, getID(source.value), id)
end
-- self -> mt:xx
if source.type == 'local' and source[1] == 'self' then
@@ -403,40 +417,40 @@ function m.compileNode(source)
if setmethod and ( setmethod.type == 'setmethod'
or setmethod.type == 'setfield'
or setmethod.type == 'setindex') then
- pushForward(id, getID(setmethod.node))
- pushBackward(getID(setmethod.node), id)
+ pushForward(noders, id, getID(setmethod.node))
+ pushBackward(noders, getID(setmethod.node), id)
end
end
-- 分解 @type
if source.type == 'doc.type' then
if source.bindSources then
for _, src in ipairs(source.bindSources) do
- pushForward(getID(src), id)
- pushForward(id, getID(src))
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
end
end
for _, typeUnit in ipairs(source.types) do
- pushForward(id, getID(typeUnit))
- pushBackward(getID(typeUnit), id)
+ pushForward(noders, id, getID(typeUnit))
+ pushBackward(noders, getID(typeUnit), id)
end
for _, enumUnit in ipairs(source.enums) do
- pushForward(id, getID(enumUnit))
+ pushForward(noders, id, getID(enumUnit))
end
end
-- 分解 @class
if source.type == 'doc.class' then
- pushForward(id, getID(source.class))
- pushForward(getID(source.class), id)
+ pushForward(noders, id, getID(source.class))
+ pushForward(noders, getID(source.class), id)
if source.extends then
for _, ext in ipairs(source.extends) do
- pushBackward(id, getID(ext))
- pushBackward(getID(ext), id)
+ pushBackward(noders, id, getID(ext))
+ pushBackward(noders, getID(ext), id)
end
end
if source.bindSources then
for _, src in ipairs(source.bindSources) do
- pushForward(getID(src), id)
- pushForward(id, getID(src))
+ pushForward(noders, getID(src), id)
+ pushForward(noders, id, getID(src))
end
end
do
@@ -453,29 +467,29 @@ function m.compileNode(source)
SPLIT_CHAR,
key
)
- pushForward(keyID, getID(doc.field))
- pushBackward(getID(doc.field), keyID)
- pushForward(keyID, getID(doc.extends))
- pushBackward(getID(doc.extends), keyID)
+ pushForward(noders, keyID, getID(doc.field))
+ pushBackward(noders, getID(doc.field), keyID)
+ pushForward(noders, keyID, getID(doc.extends))
+ pushBackward(noders, getID(doc.extends), keyID)
end
end
end
end
end
if source.type == 'doc.param' then
- pushForward(getID(source), getID(source.extends))
+ pushForward(noders, getID(source), getID(source.extends))
end
if source.type == 'doc.vararg' then
- pushForward(getID(source), getID(source.vararg))
+ pushForward(noders, getID(source), getID(source.vararg))
end
if source.type == 'doc.see' then
local nameID = getID(source.name)
local classID = nameID:gsub('^dsn:', 'dn:')
- pushForward(nameID, classID)
+ pushForward(noders, nameID, classID)
if source.field then
local fieldID = getID(source.field)
local fieldClassID = fieldID:gsub('^dsn:', 'dn:')
- pushForward(fieldID, fieldClassID)
+ pushForward(noders, fieldID, fieldClassID)
end
end
if source.type == 'call' then
@@ -484,14 +498,14 @@ function m.compileNode(source)
if not nodeID then
return
end
- getNode(id).call = source
+ getNode(noders, id).call = source
-- 将 call 映射到 node#1 上
local callID = ('%s%s%s'):format(
nodeID,
RETURN_INDEX,
1
)
- pushForward(id, callID)
+ pushForward(noders, id, callID)
-- 将setmetatable映射到 param1 以及 param2.__index 上
if node.special == 'setmetatable' then
local tblID = getID(source.args and source.args[1])
@@ -504,17 +518,17 @@ function m.compileNode(source)
'__index'
)
end
- pushForward(id, callID)
- pushBackward(callID, id)
- pushForward(callID, tblID)
- pushForward(callID, indexID)
- pushBackward(tblID, callID)
- --pushBackward(indexID, callID)
+ pushForward(noders, id, callID)
+ pushBackward(noders, callID, id)
+ pushForward(noders, callID, tblID)
+ pushForward(noders, callID, indexID)
+ pushBackward(noders, tblID, callID)
+ --pushBackward(noders, indexID, callID)
end
if node.special == 'require' then
local arg1 = source.args and source.args[1]
if arg1 and arg1.type == 'string' then
- getNode(callID).require = arg1[1]
+ getNode(noders, callID).require = arg1[1]
end
end
end
@@ -532,9 +546,9 @@ function m.compileNode(source)
RETURN_INDEX,
source.sindex
)
- pushForward(id, callXID)
- pushBackward(callXID, id)
- getNode(id).call = call
+ pushForward(noders, id, callXID)
+ pushBackward(noders, callXID, id)
+ getNode(noders, id).call = call
if node.special == 'pcall'
or node.special == 'xpcall' then
local index = source.sindex - 1
@@ -550,8 +564,8 @@ function m.compileNode(source)
RETURN_INDEX,
index
)
- pushForward(id, funcXID)
- pushBackward(funcXID, id)
+ pushForward(noders, id, funcXID)
+ pushBackward(noders, funcXID, id)
end
end
end
@@ -563,7 +577,7 @@ function m.compileNode(source)
RETURN_INDEX,
index
)
- pushForward(returnID, getID(rtn))
+ pushForward(noders, returnID, getID(rtn))
end
end
-- @type fun(x: T):T 的情况
@@ -583,14 +597,14 @@ function m.compileNode(source)
id,
TABLE_KEY
)
- pushForward(keyID, getID(source.tkey))
+ pushForward(noders, keyID, getID(source.tkey))
end
if source.tvalue then
local valueID = ('%s%s'):format(
id,
ANY_FIELD
)
- pushForward(valueID, getID(source.tvalue))
+ pushForward(noders, valueID, getID(source.tvalue))
end
end
if source.type == 'doc.type.array' then
@@ -599,13 +613,13 @@ function m.compileNode(source)
id,
ANY_FIELD
)
- pushForward(nodeID, getID(source.node))
+ pushForward(noders, nodeID, getID(source.node))
end
local keyID = ('%s%s'):format(
id,
TABLE_KEY
)
- pushForward(keyID, 'dn:integer')
+ pushForward(noders, keyID, 'dn:integer')
end
-- 将函数的返回值映射到具体的返回值上
if source.type == 'function' then
@@ -627,10 +641,10 @@ function m.compileNode(source)
index
)
for _, rtnObj in ipairs(rtnObjs) do
- pushForward(returnID, getID(rtnObj))
+ pushForward(noders, returnID, getID(rtnObj))
if rtnObj.type == 'function'
or rtnObj.type == 'call' then
- pushBackward(getID(rtnObj), returnID)
+ pushBackward(noders, getID(rtnObj), returnID)
end
end
end
@@ -645,8 +659,8 @@ function m.compileNode(source)
RETURN_INDEX,
rtn.returnIndex
)
- pushForward(fullID, getID(rtn))
- pushBackward(getID(rtn), fullID)
+ pushForward(noders, fullID, getID(rtn))
+ pushBackward(noders, getID(rtn), fullID)
end
end
if doc.type == 'doc.param' then
@@ -655,7 +669,7 @@ function m.compileNode(source)
local paramIndex = source.docParamMap[paramName]
local param = source.args[paramIndex]
if param then
- pushForward(getID(param), getID(doc))
+ pushForward(noders, getID(param), getID(doc))
param.docParam = doc
end
end
@@ -663,7 +677,7 @@ function m.compileNode(source)
if doc.type == 'doc.vararg' then
for _, param in ipairs(source.args) do
if param.type == '...' then
- pushForward(getID(param), getID(doc))
+ pushForward(noders, getID(param), getID(doc))
end
end
end
@@ -678,8 +692,8 @@ function m.compileNode(source)
for _, rtn in ipairs(source.returns) do
local rtnObj = rtn[1]
if rtnObj then
- pushForward('mainreturn', getID(rtnObj))
- pushBackward(getID(rtnObj), 'mainreturn')
+ pushForward(noders, 'mainreturn', getID(rtnObj))
+ pushBackward(noders, getID(rtnObj), 'mainreturn')
end
end
end
@@ -692,8 +706,8 @@ function m.compileNode(source)
i
)
local returnID = getID(rtn)
- pushForward(closureID, returnID)
- pushBackward(returnID, closureID)
+ pushForward(noders, closureID, returnID)
+ pushBackward(noders, returnID, closureID)
end
end
if source.type == 'generic.value' then
@@ -704,15 +718,15 @@ function m.compileNode(source)
local key = proto[1]
if upvalues[key] then
for _, paramID in ipairs(upvalues[key]) do
- pushForward(id, paramID)
- pushBackward(paramID, id)
+ pushForward(noders, id, paramID)
+ pushBackward(noders, paramID, id)
end
end
end
if proto.type == 'doc.type' then
for _, tp in ipairs(source.types) do
- pushForward(id, getID(tp))
- pushBackward(getID(tp), id)
+ pushForward(noders, id, getID(tp))
+ pushBackward(noders, getID(tp), id)
end
end
if proto.type == 'doc.type.array' then
@@ -720,12 +734,12 @@ function m.compileNode(source)
id,
ANY_FIELD
)
- pushForward(nodeID, getID(source.node))
+ pushForward(noders, nodeID, getID(source.node))
local keyID = ('%s%s'):format(
id,
TABLE_KEY
)
- pushForward(keyID, 'dn:integer')
+ pushForward(noders, keyID, 'dn:integer')
end
if proto.type == 'doc.type.table' then
if source.tkey then
@@ -733,14 +747,14 @@ function m.compileNode(source)
id,
TABLE_KEY
)
- pushForward(keyID, getID(source.tkey))
+ pushForward(noders, keyID, getID(source.tkey))
end
if source.tvalue then
local valueID = ('%s%s'):format(
id,
ANY_FIELD
)
- pushForward(valueID, getID(source.tvalue))
+ pushForward(noders, valueID, getID(source.tvalue))
end
end
end
@@ -814,23 +828,33 @@ function m.removeID(root, id)
noders[id] = nil
end
+---获取对象的noders
+---@param source parser.guide.object
+---@return noders
+function m.getNoders(source)
+ local root = guide.getRoot(source)
+ if not root._noders then
+ root._noders = {}
+ end
+ return root._noders
+end
+
---编译整个文件的node
---@param source parser.guide.object
---@return table
function m.compileNodes(source)
local root = guide.getRoot(source)
- if root._noders then
- return root._noders
+ local noders = m.getNoders(source)
+ if next(noders) then
+ return
end
- Noders = {}
- root._noders = Noders
guide.eachSource(root, function (src)
- m.pushSource(src)
- m.compileNode(src)
+ m.pushSource(noders, src)
+ m.compileNode(noders, src)
end)
-- Special rule: ('').XX -> stringlib.XX
- pushForward('str:', 'dn:stringlib')
- return Noders
+ pushForward(noders, 'str:', 'dn:stringlib')
+ return noders
end
return m
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index c97237d0..ccb5814b 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -3,6 +3,7 @@ local guide = require 'parser.guide'
local files = require 'files'
local generic = require 'core.generic'
local ws = require 'workspace'
+local vm = require 'vm.vm'
local NONE = {'NONE'}
local LAST = {'LAST'}
@@ -152,7 +153,7 @@ end
-- TODO
function m.findGlobals(root)
- noder.compileNode(root)
+ noder.compileNode(noder.getNoders(root), root)
-- TODO
return {}
end
@@ -291,14 +292,24 @@ function m.searchRefsByID(status, uri, expect, mode)
return
end
local func = guide.getParentFunction(obj)
- if not func or func.type ~= 'function' then
+ if not func then
return
end
- local parentID = noder.getID(func)
- if not parentID then
- return
+ if func.type == 'function' then
+ local parentID = noder.getID(func)
+ if not parentID then
+ return
+ end
+ search(parentID, noder.RETURN_INDEX .. returnIndex)
+ end
+ if func.type == 'main' then
+ local calls = vm.getLinksTo(uri)
+ for _, call in ipairs(calls) do
+ local turi = guide.getUri(call)
+ local tid = noder.getID(call)
+ crossSearch(status, turi, tid, mode)
+ end
end
- search(parentID, noder.RETURN_INDEX .. returnIndex)
end
local function isCallID(field)
@@ -477,7 +488,7 @@ function m.searchRefsByID(status, uri, expect, mode)
end
search(expect)
- searchFunction(expect)
+ --searchFunction(expect)
--清除来自泛型的临时对象
for _, closure in pairs(closureCache) do
diff --git a/test/definition/init.lua b/test/definition/init.lua
index 78170e0e..85bcd5d5 100644
--- a/test/definition/init.lua
+++ b/test/definition/init.lua
@@ -36,6 +36,7 @@ end
function TEST(script)
files.removeAll()
+ script = script:gsub('\n', '\r\n')
local target = catch_target(script)
local start = script:find('<?', 1, true)
local finish = script:find('?>', 1, true)