diff options
-rw-r--r-- | script/core/generic.lua | 8 | ||||
-rw-r--r-- | script/core/noder.lua | 212 | ||||
-rw-r--r-- | script/core/searcher.lua | 25 | ||||
-rw-r--r-- | test/definition/init.lua | 1 |
4 files changed, 141 insertions, 105 deletions
diff --git a/script/core/generic.lua b/script/core/generic.lua index c8dcf66d..78c3eff9 100644 --- a/script/core/generic.lua +++ b/script/core/generic.lua @@ -52,7 +52,7 @@ local function createValue(closure, proto, callback, road) end local value = instantValue(closure, proto) value.types = types - noder.compileNode(value) + noder.compileNode(noder.getNoders(proto), value) return value end if proto.type == 'doc.type.name' then @@ -64,7 +64,7 @@ local function createValue(closure, proto, callback, road) if callback then callback(road, key, proto) end - noder.compileNode(value) + noder.compileNode(noder.getNoders(proto), value) return value end if proto.type == 'doc.type.function' then @@ -92,7 +92,7 @@ local function createValue(closure, proto, callback, road) value.args = args value.returns = returns value.isGeneric = true - noder.pushSource(value) + noder.pushSource(noder.getNoders(proto), value) return value end if proto.type == 'doc.type.array' then @@ -221,7 +221,7 @@ function m.createClosure(proto, call) return nil end - noder.compileNode(closure) + noder.compileNode(noder.getNoders(proto), closure) return closure end diff --git a/script/core/noder.lua b/script/core/noder.lua index a55e93da..fa9d7d72 100644 --- a/script/core/noder.lua +++ b/script/core/noder.lua @@ -1,7 +1,6 @@ local util = require 'utility' local guide = require 'parser.guide' -local Noders local LastIDCache = {} local FirstIDCache = {} local SPLIT_CHAR = '\x1F' @@ -13,16 +12,31 @@ local PARAM_INDEX = SPLIT_CHAR .. '@' local TABLE_KEY = SPLIT_CHAR .. '<' local ANY_FIELD = SPLIT_CHAR .. ANY_FIELD_CHAR +---@class node +-- 当前节点的id +---@field id string +-- 使用该ID的单元 +---@field sources parser.guide.object[] +-- 前进的关联ID +---@field forward string[] +-- 后退的关联ID +---@field backward string[] +-- 函数调用参数信息(用于泛型) +---@field call parser.guide.object + +---@alias noders table<string, node[]> + ---创建source的链接信息 +---@param noders noders ---@param id string ---@return node -local function getNode(id) - if not Noders[id] then - Noders[id] = { +local function getNode(noders, id) + if not noders[id] then + noders[id] = { id = id, } end - return Noders[id] + return noders[id] end ---是否是全局变量(包括 _G.XXX 形式) @@ -133,10 +147,10 @@ local function getKey(source) local name = source[1] return name, nil elseif source.type == 'doc.type.name' then + local name = source[1] if source.typeGeneric then - return source.start, nil + return source.typeGeneric[name][1].start, nil else - local name = source[1] return name, nil end elseif source.type == 'doc.class' @@ -296,51 +310,49 @@ local function getID(source) end ---添加关联的前进ID +---@param noders noders ---@param id string ---@param forwardID string -local function pushForward(id, forwardID) +local function pushForward(noders, id, forwardID) if not id or not forwardID or forwardID == '' or id == forwardID then return end - local node = getNode(id) + local node = getNode(noders, id) if not node.forward then node.forward = {} end + if node.forward[forwardID] then + return + end + node.forward[forwardID] = true node.forward[#node.forward+1] = forwardID end ---添加关联的后退ID +---@param noders noders ---@param id string ---@param backwardID string -local function pushBackward(id, backwardID) +local function pushBackward(noders, id, backwardID) if not id or not backwardID or backwardID == '' or id == backwardID then return end - local node = getNode(id) + local node = getNode(noders, id) if not node.backward then node.backward = {} end + if node.backward[backwardID] then + return + end + node.backward[backwardID] = true node.backward[#node.backward+1] = backwardID end ----@class node --- 当前节点的id ----@field id string --- 使用该ID的单元 ----@field sources parser.guide.object[] --- 前进的关联ID ----@field forward string[] --- 后退的关联ID ----@field backward string[] --- 函数调用参数信息(用于泛型) ----@field call parser.guide.object - local m = {} m.SPLIT_CHAR = SPLIT_CHAR @@ -370,27 +382,29 @@ local function getDocStateWithoutCrossFunction(obj) end ---添加关联单元 +---@param noders noders ---@param source parser.guide.object -function m.pushSource(source) +function m.pushSource(noders, source) local id = m.getID(source) if not id then return end - local node = getNode(id) + local node = getNode(noders, id) if not node.sources then node.sources = {} end node.sources[#node.sources+1] = source end +---@param noders noders ---@param source parser.guide.object ---@return parser.guide.object[] -function m.compileNode(source) +function m.compileNode(noders, source) local id = getID(source) if source.value then -- x = y : x -> y - pushForward(id, getID(source.value)) - pushBackward(getID(source.value), id) + pushForward(noders, id, getID(source.value)) + pushBackward(noders, getID(source.value), id) end -- self -> mt:xx if source.type == 'local' and source[1] == 'self' then @@ -403,40 +417,40 @@ function m.compileNode(source) if setmethod and ( setmethod.type == 'setmethod' or setmethod.type == 'setfield' or setmethod.type == 'setindex') then - pushForward(id, getID(setmethod.node)) - pushBackward(getID(setmethod.node), id) + pushForward(noders, id, getID(setmethod.node)) + pushBackward(noders, getID(setmethod.node), id) end end -- 分解 @type if source.type == 'doc.type' then if source.bindSources then for _, src in ipairs(source.bindSources) do - pushForward(getID(src), id) - pushForward(id, getID(src)) + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) end end for _, typeUnit in ipairs(source.types) do - pushForward(id, getID(typeUnit)) - pushBackward(getID(typeUnit), id) + pushForward(noders, id, getID(typeUnit)) + pushBackward(noders, getID(typeUnit), id) end for _, enumUnit in ipairs(source.enums) do - pushForward(id, getID(enumUnit)) + pushForward(noders, id, getID(enumUnit)) end end -- 分解 @class if source.type == 'doc.class' then - pushForward(id, getID(source.class)) - pushForward(getID(source.class), id) + pushForward(noders, id, getID(source.class)) + pushForward(noders, getID(source.class), id) if source.extends then for _, ext in ipairs(source.extends) do - pushBackward(id, getID(ext)) - pushBackward(getID(ext), id) + pushBackward(noders, id, getID(ext)) + pushBackward(noders, getID(ext), id) end end if source.bindSources then for _, src in ipairs(source.bindSources) do - pushForward(getID(src), id) - pushForward(id, getID(src)) + pushForward(noders, getID(src), id) + pushForward(noders, id, getID(src)) end end do @@ -453,29 +467,29 @@ function m.compileNode(source) SPLIT_CHAR, key ) - pushForward(keyID, getID(doc.field)) - pushBackward(getID(doc.field), keyID) - pushForward(keyID, getID(doc.extends)) - pushBackward(getID(doc.extends), keyID) + pushForward(noders, keyID, getID(doc.field)) + pushBackward(noders, getID(doc.field), keyID) + pushForward(noders, keyID, getID(doc.extends)) + pushBackward(noders, getID(doc.extends), keyID) end end end end end if source.type == 'doc.param' then - pushForward(getID(source), getID(source.extends)) + pushForward(noders, getID(source), getID(source.extends)) end if source.type == 'doc.vararg' then - pushForward(getID(source), getID(source.vararg)) + pushForward(noders, getID(source), getID(source.vararg)) end if source.type == 'doc.see' then local nameID = getID(source.name) local classID = nameID:gsub('^dsn:', 'dn:') - pushForward(nameID, classID) + pushForward(noders, nameID, classID) if source.field then local fieldID = getID(source.field) local fieldClassID = fieldID:gsub('^dsn:', 'dn:') - pushForward(fieldID, fieldClassID) + pushForward(noders, fieldID, fieldClassID) end end if source.type == 'call' then @@ -484,14 +498,14 @@ function m.compileNode(source) if not nodeID then return end - getNode(id).call = source + getNode(noders, id).call = source -- 将 call 映射到 node#1 上 local callID = ('%s%s%s'):format( nodeID, RETURN_INDEX, 1 ) - pushForward(id, callID) + pushForward(noders, id, callID) -- 将setmetatable映射到 param1 以及 param2.__index 上 if node.special == 'setmetatable' then local tblID = getID(source.args and source.args[1]) @@ -504,17 +518,17 @@ function m.compileNode(source) '__index' ) end - pushForward(id, callID) - pushBackward(callID, id) - pushForward(callID, tblID) - pushForward(callID, indexID) - pushBackward(tblID, callID) - --pushBackward(indexID, callID) + pushForward(noders, id, callID) + pushBackward(noders, callID, id) + pushForward(noders, callID, tblID) + pushForward(noders, callID, indexID) + pushBackward(noders, tblID, callID) + --pushBackward(noders, indexID, callID) end if node.special == 'require' then local arg1 = source.args and source.args[1] if arg1 and arg1.type == 'string' then - getNode(callID).require = arg1[1] + getNode(noders, callID).require = arg1[1] end end end @@ -532,9 +546,9 @@ function m.compileNode(source) RETURN_INDEX, source.sindex ) - pushForward(id, callXID) - pushBackward(callXID, id) - getNode(id).call = call + pushForward(noders, id, callXID) + pushBackward(noders, callXID, id) + getNode(noders, id).call = call if node.special == 'pcall' or node.special == 'xpcall' then local index = source.sindex - 1 @@ -550,8 +564,8 @@ function m.compileNode(source) RETURN_INDEX, index ) - pushForward(id, funcXID) - pushBackward(funcXID, id) + pushForward(noders, id, funcXID) + pushBackward(noders, funcXID, id) end end end @@ -563,7 +577,7 @@ function m.compileNode(source) RETURN_INDEX, index ) - pushForward(returnID, getID(rtn)) + pushForward(noders, returnID, getID(rtn)) end end -- @type fun(x: T):T 的情况 @@ -583,14 +597,14 @@ function m.compileNode(source) id, TABLE_KEY ) - pushForward(keyID, getID(source.tkey)) + pushForward(noders, keyID, getID(source.tkey)) end if source.tvalue then local valueID = ('%s%s'):format( id, ANY_FIELD ) - pushForward(valueID, getID(source.tvalue)) + pushForward(noders, valueID, getID(source.tvalue)) end end if source.type == 'doc.type.array' then @@ -599,13 +613,13 @@ function m.compileNode(source) id, ANY_FIELD ) - pushForward(nodeID, getID(source.node)) + pushForward(noders, nodeID, getID(source.node)) end local keyID = ('%s%s'):format( id, TABLE_KEY ) - pushForward(keyID, 'dn:integer') + pushForward(noders, keyID, 'dn:integer') end -- 将函数的返回值映射到具体的返回值上 if source.type == 'function' then @@ -627,10 +641,10 @@ function m.compileNode(source) index ) for _, rtnObj in ipairs(rtnObjs) do - pushForward(returnID, getID(rtnObj)) + pushForward(noders, returnID, getID(rtnObj)) if rtnObj.type == 'function' or rtnObj.type == 'call' then - pushBackward(getID(rtnObj), returnID) + pushBackward(noders, getID(rtnObj), returnID) end end end @@ -645,8 +659,8 @@ function m.compileNode(source) RETURN_INDEX, rtn.returnIndex ) - pushForward(fullID, getID(rtn)) - pushBackward(getID(rtn), fullID) + pushForward(noders, fullID, getID(rtn)) + pushBackward(noders, getID(rtn), fullID) end end if doc.type == 'doc.param' then @@ -655,7 +669,7 @@ function m.compileNode(source) local paramIndex = source.docParamMap[paramName] local param = source.args[paramIndex] if param then - pushForward(getID(param), getID(doc)) + pushForward(noders, getID(param), getID(doc)) param.docParam = doc end end @@ -663,7 +677,7 @@ function m.compileNode(source) if doc.type == 'doc.vararg' then for _, param in ipairs(source.args) do if param.type == '...' then - pushForward(getID(param), getID(doc)) + pushForward(noders, getID(param), getID(doc)) end end end @@ -678,8 +692,8 @@ function m.compileNode(source) for _, rtn in ipairs(source.returns) do local rtnObj = rtn[1] if rtnObj then - pushForward('mainreturn', getID(rtnObj)) - pushBackward(getID(rtnObj), 'mainreturn') + pushForward(noders, 'mainreturn', getID(rtnObj)) + pushBackward(noders, getID(rtnObj), 'mainreturn') end end end @@ -692,8 +706,8 @@ function m.compileNode(source) i ) local returnID = getID(rtn) - pushForward(closureID, returnID) - pushBackward(returnID, closureID) + pushForward(noders, closureID, returnID) + pushBackward(noders, returnID, closureID) end end if source.type == 'generic.value' then @@ -704,15 +718,15 @@ function m.compileNode(source) local key = proto[1] if upvalues[key] then for _, paramID in ipairs(upvalues[key]) do - pushForward(id, paramID) - pushBackward(paramID, id) + pushForward(noders, id, paramID) + pushBackward(noders, paramID, id) end end end if proto.type == 'doc.type' then for _, tp in ipairs(source.types) do - pushForward(id, getID(tp)) - pushBackward(getID(tp), id) + pushForward(noders, id, getID(tp)) + pushBackward(noders, getID(tp), id) end end if proto.type == 'doc.type.array' then @@ -720,12 +734,12 @@ function m.compileNode(source) id, ANY_FIELD ) - pushForward(nodeID, getID(source.node)) + pushForward(noders, nodeID, getID(source.node)) local keyID = ('%s%s'):format( id, TABLE_KEY ) - pushForward(keyID, 'dn:integer') + pushForward(noders, keyID, 'dn:integer') end if proto.type == 'doc.type.table' then if source.tkey then @@ -733,14 +747,14 @@ function m.compileNode(source) id, TABLE_KEY ) - pushForward(keyID, getID(source.tkey)) + pushForward(noders, keyID, getID(source.tkey)) end if source.tvalue then local valueID = ('%s%s'):format( id, ANY_FIELD ) - pushForward(valueID, getID(source.tvalue)) + pushForward(noders, valueID, getID(source.tvalue)) end end end @@ -814,23 +828,33 @@ function m.removeID(root, id) noders[id] = nil end +---获取对象的noders +---@param source parser.guide.object +---@return noders +function m.getNoders(source) + local root = guide.getRoot(source) + if not root._noders then + root._noders = {} + end + return root._noders +end + ---编译整个文件的node ---@param source parser.guide.object ---@return table function m.compileNodes(source) local root = guide.getRoot(source) - if root._noders then - return root._noders + local noders = m.getNoders(source) + if next(noders) then + return end - Noders = {} - root._noders = Noders guide.eachSource(root, function (src) - m.pushSource(src) - m.compileNode(src) + m.pushSource(noders, src) + m.compileNode(noders, src) end) -- Special rule: ('').XX -> stringlib.XX - pushForward('str:', 'dn:stringlib') - return Noders + pushForward(noders, 'str:', 'dn:stringlib') + return noders end return m diff --git a/script/core/searcher.lua b/script/core/searcher.lua index c97237d0..ccb5814b 100644 --- a/script/core/searcher.lua +++ b/script/core/searcher.lua @@ -3,6 +3,7 @@ local guide = require 'parser.guide' local files = require 'files' local generic = require 'core.generic' local ws = require 'workspace' +local vm = require 'vm.vm' local NONE = {'NONE'} local LAST = {'LAST'} @@ -152,7 +153,7 @@ end -- TODO function m.findGlobals(root) - noder.compileNode(root) + noder.compileNode(noder.getNoders(root), root) -- TODO return {} end @@ -291,14 +292,24 @@ function m.searchRefsByID(status, uri, expect, mode) return end local func = guide.getParentFunction(obj) - if not func or func.type ~= 'function' then + if not func then return end - local parentID = noder.getID(func) - if not parentID then - return + if func.type == 'function' then + local parentID = noder.getID(func) + if not parentID then + return + end + search(parentID, noder.RETURN_INDEX .. returnIndex) + end + if func.type == 'main' then + local calls = vm.getLinksTo(uri) + for _, call in ipairs(calls) do + local turi = guide.getUri(call) + local tid = noder.getID(call) + crossSearch(status, turi, tid, mode) + end end - search(parentID, noder.RETURN_INDEX .. returnIndex) end local function isCallID(field) @@ -477,7 +488,7 @@ function m.searchRefsByID(status, uri, expect, mode) end search(expect) - searchFunction(expect) + --searchFunction(expect) --清除来自泛型的临时对象 for _, closure in pairs(closureCache) do diff --git a/test/definition/init.lua b/test/definition/init.lua index 78170e0e..85bcd5d5 100644 --- a/test/definition/init.lua +++ b/test/definition/init.lua @@ -36,6 +36,7 @@ end function TEST(script) files.removeAll() + script = script:gsub('\n', '\r\n') local target = catch_target(script) local start = script:find('<?', 1, true) local finish = script:find('?>', 1, true) |