diff options
Diffstat (limited to 'script/core/linker.lua')
-rw-r--r-- | script/core/linker.lua | 592 |
1 files changed, 592 insertions, 0 deletions
diff --git a/script/core/linker.lua b/script/core/linker.lua new file mode 100644 index 00000000..9741d7b4 --- /dev/null +++ b/script/core/linker.lua @@ -0,0 +1,592 @@ +local util = require 'utility' +local guide = require 'parser.guide' + +local Linkers, GetLink +local LastIDCache = {} +local SPLIT_CHAR = '\x1F' +local SPLIT_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']+$' +local RETURN_INDEX_CHAR = '#' +local PARAM_INDEX_CHAR = '@' + +---是否是全局变量(包括 _G.XXX 形式) +---@param source parser.guide.object +---@return boolean +local function isGlobal(source) + if source.type == 'setglobal' + or source.type == 'getglobal' then + if source.node and source.node.tag == '_ENV' then + return true + end + end + if source.type == 'field' then + source = source.parent + end + if source.special == '_G' then + return true + end + return false +end + +---获取语法树单元的key +---@param source parser.guide.object +---@return string? key +---@return parser.guide.object? node +local function getKey(source) + if source.type == 'local' then + return tostring(source.start), nil + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return tostring(source.node.start), nil + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + return ('%q'):format(source[1] or ''), nil + elseif source.type == 'getfield' + or source.type == 'setfield' then + return ('%q'):format(source.field and source.field[1] or ''), source.node + elseif source.type == 'tablefield' then + return ('%q'):format(source.field and source.field[1] or ''), source.parent + elseif source.type == 'getmethod' + or source.type == 'setmethod' then + return ('%q'):format(source.method and source.method[1] or ''), source.node + elseif source.type == 'setindex' + or source.type == 'getindex' then + local index = source.index + if not index then + return '', source.node + end + if index.type == 'string' then + return ('%q'):format(index[1] or ''), source.node + else + return '', source.node + end + elseif source.type == 'tableindex' then + local index = source.index + if not index then + return '', source.parent + end + if index.type == 'string' then + return ('%q'):format(index[1] or ''), source.parent + else + return '', source.parent + end + elseif source.type == 'table' then + return source.start, nil + elseif source.type == 'label' then + return source.start, nil + elseif source.type == 'goto' then + if source.node then + return source.node.start, nil + end + return nil, nil + elseif source.type == 'function' then + return source.start, nil + elseif source.type == '...' then + return source.start, nil + elseif source.type == 'select' then + return ('%d%s%s%d'):format(source.start, SPLIT_CHAR, RETURN_INDEX_CHAR, source.index) + elseif source.type == 'call' then + local node = source.node + if node.special == 'rawget' + or node.special == 'rawset' then + if not source.args then + return nil, nil + end + local tbl, key = source.args[1], source.args[2] + if not tbl or not key then + return nil, nil + end + if key.type == 'string' then + return ('%q'):format(key[1] or ''), tbl + else + return '', tbl + end + end + return source.start, nil + elseif source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.see.name' then + return source[1], nil + elseif source.type == 'doc.class' + or source.type == 'doc.type' + or source.type == 'doc.alias' + or source.type == 'doc.param' + or source.type == 'doc.vararg' + or source.type == 'doc.field.name' + or source.type == 'doc.type.function' then + return source.start, nil + elseif source.type == 'doc.see.field' then + return ('%q'):format(source[1]), source.parent.name + end + return nil, nil +end + +local function checkMode(source) + if source.type == 'table' then + return 't:' + end + if source.type == 'select' then + return 's:' + end + if source.type == 'function' then + return 'f:' + end + if source.type == 'call' then + return 'c:' + end + if source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' then + return 'dn:' + end + if source.type == 'doc.field.name' then + return 'dfn:' + end + if source.type == 'doc.see.name' then + return 'dsn:' + end + if source.type == 'doc.class' then + return 'dc:' + end + if source.type == 'doc.type' then + return 'dt:' + end + if source.type == 'doc.param' then + return 'dp:' + end + if source.type == 'doc.alias' then + return 'da:' + end + if source.type == 'doc.type.function' then + return 'df:' + end + if source.type == 'doc.vararg' then + return 'dv:' + end + if isGlobal(source) then + return 'g:' + end + return 'l:' +end + +local IDList = {} +---获取语法树单元的字符串ID +---@param source parser.guide.object +---@return string? id +local function getID(source) + if not source then + return nil + end + if source._id ~= nil then + return source._id or nil + end + if source.type == 'field' + or source.type == 'method' then + source._id = false + return nil + end + local current = source + local index = 0 + while true do + local id, node = getKey(current) + if not id then + break + end + index = index + 1 + IDList[index] = id + if not node then + break + end + current = node + if current.special == '_G' then + break + end + end + if index == 0 then + source._id = false + return nil + end + for i = index + 1, #IDList do + IDList[i] = nil + end + local mode = checkMode(current) + if not mode then + source._id = false + return nil + end + util.revertTable(IDList) + local id = mode .. table.concat(IDList, SPLIT_CHAR) + source._id = id + return id +end + +---添加关联单元 +---@param id string +---@param source parser.guide.object +local function pushSource(id, source) + local link = GetLink(id) + if not link.sources then + link.sources = {} + end + link.sources[#link.sources+1] = source +end + +---添加关联的前进ID +---@param id string +---@param forwardID string +local function pushForward(id, forwardID) + if not id + or not forwardID + or forwardID == '' + or id == forwardID then + return + end + local link = GetLink(id) + if not link.forward then + link.forward = {} + end + link.forward[#link.forward+1] = forwardID +end + +---添加关联的后退ID +---@param id string +---@param backwardID string +local function pushBackward(id, backwardID) + if not id + or not backwardID + or backwardID == '' + or id == backwardID then + return + end + local link = GetLink(id) + if not link.backward then + link.backward = {} + end + link.backward[#link.backward+1] = backwardID +end + +local function eachParentSelect(call, callback) + if call.type ~= 'call' then + return + end + if call.parent.type == 'select' then + callback(call.parent, call.parent.index) + end + if not call.extParent then + return + end + for _, sel in ipairs(call.extParent) do + if sel.type == 'select' then + callback(sel, sel.index) + end + end +end + +---前进 +---@param source parser.guide.object +---@return parser.guide.object[] +local function compileLink(source) + local id = getID(source) + local parent = source.parent + if not parent then + return + end + if source.value then + -- x = y : x -> y + pushForward(id, getID(source.value)) + pushBackward(getID(source.value), id) + end + -- self -> mt:xx + if source.type == 'local' and source[1] == 'self' then + local func = guide.getParentFunction(source) + local setmethod = func.parent + -- guess `self` + if setmethod and ( setmethod.type == 'setmethod' + or setmethod.type == 'setfield' + or setmethod.type == 'setindex') then + pushForward(id, getID(setmethod.node)) + pushBackward(getID(setmethod.node), id) + end + end + -- 分解 @type + if source.type == 'doc.type' then + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(getID(src), id) + pushForward(id, getID(src)) + end + end + for _, typeUnit in ipairs(source.types) do + pushForward(id, getID(typeUnit)) + pushBackward(getID(typeUnit), id) + end + end + -- 分解 @class + if source.type == 'doc.class' then + pushForward(id, getID(source.class)) + pushForward(getID(source.class), id) + if source.extends then + for _, ext in ipairs(source.extends) do + pushForward(id, getID(ext)) + pushBackward(getID(ext), id) + end + end + if source.bindSources then + for _, src in ipairs(source.bindSources) do + pushForward(getID(src), id) + pushForward(id, getID(src)) + end + end + do + local start + for _, doc in ipairs(source.bindGroup) do + if doc.type == 'doc.class' then + start = doc == source + end + if start and doc.type == 'doc.field' then + local key = doc.field[1] + if key then + local keyID = ('%s%s%q'):format( + id, + SPLIT_CHAR, + key + ) + pushForward(keyID, getID(doc.field)) + pushBackward(getID(doc.field), keyID) + pushForward(keyID, getID(doc.extends)) + pushBackward(getID(doc.extends), keyID) + end + end + end + end + end + if source.type == 'doc.param' then + pushForward(getID(source), getID(source.extends)) + end + if source.type == 'doc.vararg' then + pushForward(getID(source), getID(source.vararg)) + end + if source.type == 'doc.see' then + local nameID = getID(source.name) + local classID = nameID:gsub('^dsn:', 'dn:') + pushForward(nameID, classID) + if source.field then + local fieldID = getID(source.field) + local fieldClassID = fieldID:gsub('^dsn:', 'dn:') + pushForward(fieldID, fieldClassID) + end + end + if source.type == 'call' then + local node = source.node + local nodeID = getID(node) + -- 将call的返回值接收映射到函数返回值上 + eachParentSelect(source, function (sel) + local selectID = getID(sel) + local callID = ('%s%s%s%s'):format( + nodeID, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + sel.index + ) + pushForward(selectID, callID) + pushBackward(callID, selectID) + if sel.index == 1 then + pushForward(id, callID) + pushBackward(callID, id) + end + end) + -- 将setmetatable映射到 param1 以及 param2.__index 上 + if node.special == 'setmetatable' then + local callID = ('%s%s%s%s'):format( + nodeID, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + 1 + ) + local tblID = getID(source.args and source.args[1]) + local metaID = getID(source.args and source.args[2]) + local indexID + if metaID then + indexID = ('%s%s%q'):format( + metaID, + SPLIT_CHAR, + '__index' + ) + end + pushForward(id, callID) + pushBackward(callID, id) + pushForward(callID, tblID) + pushForward(callID, indexID) + pushBackward(tblID, callID) + --pushBackward(indexID, callID) + end + end + -- 将函数的返回值映射到具体的返回值上 + if source.type == 'function' then + -- 检查实体返回值 + if source.returns then + local returns = {} + for _, rtn in ipairs(source.returns) do + for index, rtnObj in ipairs(rtn) do + if not returns[index] then + returns[index] = {} + end + returns[index][#returns[index]+1] = rtnObj + end + end + for index, rtnObjs in ipairs(returns) do + local returnID = ('%s%s%s%s'):format( + getID(source), + SPLIT_CHAR, + RETURN_INDEX_CHAR, + index + ) + for _, rtnObj in ipairs(rtnObjs) do + pushForward(returnID, getID(rtnObj)) + if rtnObj.type == 'function' + or rtnObj.type == 'call' then + pushBackward(getID(rtnObj), returnID) + end + end + end + end + -- 检查 luadoc + if source.bindDocs then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + local fullID = ('%s%s%s%s'):format( + id, + SPLIT_CHAR, + RETURN_INDEX_CHAR, + rtn.returnIndex + ) + pushForward(getID(rtn), fullID) + pushBackward(fullID, getID(rtn)) + end + end + if doc.type == 'doc.param' then + local paramName = doc.param[1] + for _, param in ipairs(source.args) do + if param[1] == paramName then + pushForward(getID(param), getID(doc)) + end + end + end + if doc.type == 'doc.vararg' then + for _, param in ipairs(source.args) do + if param.type == '...' then + pushForward(getID(param), getID(doc)) + end + end + end + end + end + end +end + +---@class link +-- 当前节点的id +---@field id string +-- 使用该ID的单元 +---@field sources parser.guide.object[] +-- 前进的关联ID +---@field forward string[] +-- 后退的关联ID +---@field backward string[] + +---创建source的链接信息 +---@param id string +---@return link +function GetLink(id) + if not Linkers[id] then + Linkers[id] = { + id = id, + } + end + return Linkers[id] +end + +local m = {} + +m.SPLIT_CHAR = SPLIT_CHAR +m.RETURN_INDEX_CHAR = RETURN_INDEX_CHAR +m.PARAM_INDEX_CHAR = PARAM_INDEX_CHAR + +---根据ID来获取所有的link +---@param root parser.guide.object +---@param id string +---@return link? +function m.getLinkByID(root, id) + root = guide.getRoot(root) + local linkers = root._linkers + if not linkers then + return nil + end + return linkers[id] +end + +---根据ID来获取上个节点的ID +---@param id string +---@return string +function m.getLastID(id) + if LastIDCache[id] then + return LastIDCache[id] or nil + end + local lastID, count = id:gsub(SPLIT_REGEX, '') + if count == 0 then + LastIDCache[id] = false + return nil + end + LastIDCache[id] = lastID + return lastID +end + +---获取source的ID +---@param source parser.guide.object +---@return string +function m.getID(source) + return getID(source) +end + +---获取source的special +---@param source parser.guide.object +---@return table +function m.getSpecial(source, key) + if not source then + return nil + end + local link = m.getLink(source) + if not link then + return nil + end + local special = link.special + if not special then + return nil + end + return special[key] +end + +---编译整个文件的link +---@param source parser.guide.object +---@return table +function m.compileLinks(source) + local root = guide.getRoot(source) + if root._linkers then + return root._linkers + end + Linkers = {} + root._linkers = Linkers + guide.eachSource(root, function (src) + local id = getID(src) + if id then + pushSource(id, src) + end + compileLink(src) + end) + return Linkers +end + +return m |