diff options
Diffstat (limited to 'script/vm')
-rw-r--r-- | script/vm/getDef.lua | 26 | ||||
-rw-r--r-- | script/vm/local-id.lua | 42 | ||||
-rw-r--r-- | script/vm/node/compiler.lua | 102 | ||||
-rw-r--r-- | script/vm/node/union.lua | 3 |
4 files changed, 130 insertions, 43 deletions
diff --git a/script/vm/getDef.lua b/script/vm/getDef.lua index 72a8ae08..c423ba1a 100644 --- a/script/vm/getDef.lua +++ b/script/vm/getDef.lua @@ -94,14 +94,25 @@ local searchFieldMap = util.switch() pushResult(set) end end) + : case 'local' + : call(function (node, key, pushResult) + local sources = localID.getSources(node, key) + if sources then + for _, src in ipairs(sources) do + if guide.isSet(src) then + pushResult(src) + end + end + end + end) : getMap() -local searchByNode +local searchByParentNode local nodeMap = util.switch() : case 'field' : case 'method' : call(function (source, pushResult) - searchByNode(source.parent, pushResult) + searchByParentNode(source.parent, pushResult) end) : case 'getfield' : case 'setfield' @@ -114,8 +125,11 @@ local nodeMap = util.switch() if not node then return end - if searchFieldMap[node.type] then - searchFieldMap[node.type](node, guide.getKeyName(source), pushResult) + local key = guide.getKeyName(source) + for n in compiler.eachNode(node) do + if searchFieldMap[n.type] then + searchFieldMap[n.type](n, key, pushResult) + end end end) : getMap() @@ -157,7 +171,7 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) -function searchByNode(source, pushResult) +function searchByParentNode(source, pushResult) local node = nodeMap[source.type] if node then node(source, pushResult) @@ -180,7 +194,7 @@ function vm.getDefs(source) searchBySimple(source, pushResult) searchByGlobal(source, pushResult) searchByID(source, pushResult) - searchByNode(source, pushResult) + searchByParentNode(source, pushResult) return results end diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua index 8487d96a..ddcb9e97 100644 --- a/script/vm/local-id.lua +++ b/script/vm/local-id.lua @@ -13,6 +13,7 @@ m.ID_SPLITE = '\x1F' local compileMap = util.switch() : case 'local' : call(function (source) + source._localID = ('%d'):format(source.start) if not source.ref then return end @@ -38,15 +39,46 @@ local compileMap = util.switch() m.compileLocalID(source.next) end end) + : case 'getmethod' + : case 'setmethod' + : call(function (source) + local parentID = source.node._localID + if not parentID then + return + end + source._localID = parentID .. m.ID_SPLITE .. guide.getKeyName(source) + source.method._localID = source._localID + if source.type == 'getmethod' then + m.compileLocalID(source.next) + end + end) + : case 'getindex' + : case 'setindex' + : call(function (source) + local parentID = source.node._localID + if not parentID then + return + end + source._localID = parentID .. m.ID_SPLITE .. guide.getKeyName(source) + source.index._localID = source._localID + if source.type == 'setindex' then + m.compileLocalID(source.next) + end + end) : getMap() local leftMap = util.switch() : case 'field' + : case 'method' : call(function (source) return m.getLocal(source.parent) end) : case 'getfield' : case 'setfield' + : case 'getmethod' + : case 'setmethod' + : case 'getindex' + : case 'setindex' : call(function (source) return m.getLocal(source.node) end) @@ -54,6 +86,10 @@ local leftMap = util.switch() : call(function (source) return source.node end) + : case 'local' + : call(function (source) + return source + end) : getMap() ---@param source parser.object @@ -100,8 +136,9 @@ function m.getID(source) end ---@param source parser.object +---@param key? string ---@return parser.object[]? -function m.getSources(source) +function m.getSources(source, key) local id = m.getID(source) if not id then return nil @@ -110,6 +147,9 @@ function m.getSources(source) if not root._localIDs then return nil end + if key then + id = id .. m.ID_SPLITE .. key + end return root._localIDs[id] end diff --git a/script/vm/node/compiler.lua b/script/vm/node/compiler.lua index 9d86768f..6c51db86 100644 --- a/script/vm/node/compiler.lua +++ b/script/vm/node/compiler.lua @@ -50,6 +50,37 @@ function m.eachNode(node) end end +local searchFieldMap = util.switch() + : case 'table' + : call(function (node, key, pushResult) + for _, field in ipairs(node) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if guide.getKeyName(field) == key then + pushResult(m.compileNode(field)) + end + end + end + end) + : case 'global' + ---@param node vm.node.global + : call(function (node, key, pushResult) + local global = globalMgr.getGlobal(node.name, key) + if global then + pushResult(global) + end + end) + : case 'local' + : call(function (node, key, pushResult) + local sources = localID.getSources(node, key) + if sources then + for _, src in ipairs(sources) do + pushResult(m.compileNode(src)) + end + end + end) + : getMap() + local function getReturnOfFunction(func, index) if not func._returns then func._returns = util.defaultTable(function () @@ -63,14 +94,32 @@ local function getReturnOfFunction(func, index) return m.compileNode(func._returns[index]) end -local function getReturn(func, index) +local function getReturnOfSetMetaTable(source, args) + local tbl = args and args[1] + local mt = args and args[2] + if tbl then + m.setNode(source, m.compileNode(tbl)) + end + if mt then + m.compileByParentNode(mt, '__index', function (node) + m.setNode(source, node) + end) + end + return source._node +end + +local function getReturn(func, index, source, args) local node = m.compileNode(func) if not node then return end for cnode in m.eachNode(node) do - if cnode.type == 'function' then + if cnode.type == 'function' then return getReturnOfFunction(cnode, index) + elseif cnode.type == 'global' then + if cnode.name == 'setmetatable' and index == 1 then + return getReturnOfSetMetaTable(source, args) + end end end end @@ -87,39 +136,19 @@ local function compileByLocalID(source) end end -local searchFieldMap = util.switch() - : case 'table' - : call(function (node, key, pushResult) - for _, field in ipairs(node) do - if field.type == 'tablefield' - or field.type == 'tableindex' then - if guide.getKeyName(field) == key then - pushResult(m.compileNode(field)) - end - end - end - end) - : case 'global' - ---@param node vm.node.global - : call(function (node, key, pushResult) - local global = globalMgr.getGlobal(node.name, key) - if global then - pushResult(global) - end - end) - : getMap() - -local function compileByParentNode(source) - local parentNode = m.compileNode(source.node) +---@param source vm.node +---@param key any +---@param pushResult fun(node:vm.node) +function m.compileByParentNode(source, key, pushResult) + local parentNode = m.compileNode(source) if not parentNode then return end - local key = guide.getKeyName(source) - local f = searchFieldMap[parentNode.type] - if f then - f(parentNode, key, function (fieldNode) - m.setNode(source, fieldNode) - end) + for node in m.eachNode(parentNode) do + local f = searchFieldMap[node.type] + if f then + f(node, key, pushResult) + end end end @@ -136,6 +165,7 @@ local compilerMap = util.switch() end) : case 'local' : call(function (source) + m.setNode(source, source) if source.value then m.setNode(source, m.compileNode(source.value)) end @@ -165,7 +195,9 @@ local compilerMap = util.switch() : case 'getindex' : call(function (source) compileByLocalID(source) - compileByParentNode(source) + m.compileByParentNode(source.node, guide.getKeyName(source), function (node) + m.setNode(source, node) + end) end) : case 'tablefield' : case 'tableindex' @@ -190,7 +222,7 @@ local compilerMap = util.switch() : call(function (source) local vararg = source.vararg if vararg.type == 'call' then - m.setNode(source, getReturn(vararg.node, source.sindex)) + m.setNode(source, getReturn(vararg.node, source.sindex, source, vararg.args)) end end) : getMap() @@ -198,7 +230,7 @@ local compilerMap = util.switch() ---@param source parser.object ---@return vm.node function m.compileNode(source) - if source._node then + if source._node ~= nil then return source._node end source._node = false diff --git a/script/vm/node/union.lua b/script/vm/node/union.lua index b944b92d..538e7586 100644 --- a/script/vm/node/union.lua +++ b/script/vm/node/union.lua @@ -48,7 +48,8 @@ end ---@return vm.node.union return function (me, node) local union = setmetatable({ - [1] = me, + [1] = me, + [me] = true, }, mt) union:merge(node) return union |