summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/getDef.lua26
-rw-r--r--script/vm/local-id.lua42
-rw-r--r--script/vm/node/compiler.lua102
-rw-r--r--script/vm/node/union.lua3
4 files changed, 130 insertions, 43 deletions
diff --git a/script/vm/getDef.lua b/script/vm/getDef.lua
index 72a8ae08..c423ba1a 100644
--- a/script/vm/getDef.lua
+++ b/script/vm/getDef.lua
@@ -94,14 +94,25 @@ local searchFieldMap = util.switch()
pushResult(set)
end
end)
+ : case 'local'
+ : call(function (node, key, pushResult)
+ local sources = localID.getSources(node, key)
+ if sources then
+ for _, src in ipairs(sources) do
+ if guide.isSet(src) then
+ pushResult(src)
+ end
+ end
+ end
+ end)
: getMap()
-local searchByNode
+local searchByParentNode
local nodeMap = util.switch()
: case 'field'
: case 'method'
: call(function (source, pushResult)
- searchByNode(source.parent, pushResult)
+ searchByParentNode(source.parent, pushResult)
end)
: case 'getfield'
: case 'setfield'
@@ -114,8 +125,11 @@ local nodeMap = util.switch()
if not node then
return
end
- if searchFieldMap[node.type] then
- searchFieldMap[node.type](node, guide.getKeyName(source), pushResult)
+ local key = guide.getKeyName(source)
+ for n in compiler.eachNode(node) do
+ if searchFieldMap[n.type] then
+ searchFieldMap[n.type](n, key, pushResult)
+ end
end
end)
: getMap()
@@ -157,7 +171,7 @@ end
---@param source parser.object
---@param pushResult fun(src: parser.object)
-function searchByNode(source, pushResult)
+function searchByParentNode(source, pushResult)
local node = nodeMap[source.type]
if node then
node(source, pushResult)
@@ -180,7 +194,7 @@ function vm.getDefs(source)
searchBySimple(source, pushResult)
searchByGlobal(source, pushResult)
searchByID(source, pushResult)
- searchByNode(source, pushResult)
+ searchByParentNode(source, pushResult)
return results
end
diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua
index 8487d96a..ddcb9e97 100644
--- a/script/vm/local-id.lua
+++ b/script/vm/local-id.lua
@@ -13,6 +13,7 @@ m.ID_SPLITE = '\x1F'
local compileMap = util.switch()
: case 'local'
: call(function (source)
+ source._localID = ('%d'):format(source.start)
if not source.ref then
return
end
@@ -38,15 +39,46 @@ local compileMap = util.switch()
m.compileLocalID(source.next)
end
end)
+ : case 'getmethod'
+ : case 'setmethod'
+ : call(function (source)
+ local parentID = source.node._localID
+ if not parentID then
+ return
+ end
+ source._localID = parentID .. m.ID_SPLITE .. guide.getKeyName(source)
+ source.method._localID = source._localID
+ if source.type == 'getmethod' then
+ m.compileLocalID(source.next)
+ end
+ end)
+ : case 'getindex'
+ : case 'setindex'
+ : call(function (source)
+ local parentID = source.node._localID
+ if not parentID then
+ return
+ end
+ source._localID = parentID .. m.ID_SPLITE .. guide.getKeyName(source)
+ source.index._localID = source._localID
+ if source.type == 'setindex' then
+ m.compileLocalID(source.next)
+ end
+ end)
: getMap()
local leftMap = util.switch()
: case 'field'
+ : case 'method'
: call(function (source)
return m.getLocal(source.parent)
end)
: case 'getfield'
: case 'setfield'
+ : case 'getmethod'
+ : case 'setmethod'
+ : case 'getindex'
+ : case 'setindex'
: call(function (source)
return m.getLocal(source.node)
end)
@@ -54,6 +86,10 @@ local leftMap = util.switch()
: call(function (source)
return source.node
end)
+ : case 'local'
+ : call(function (source)
+ return source
+ end)
: getMap()
---@param source parser.object
@@ -100,8 +136,9 @@ function m.getID(source)
end
---@param source parser.object
+---@param key? string
---@return parser.object[]?
-function m.getSources(source)
+function m.getSources(source, key)
local id = m.getID(source)
if not id then
return nil
@@ -110,6 +147,9 @@ function m.getSources(source)
if not root._localIDs then
return nil
end
+ if key then
+ id = id .. m.ID_SPLITE .. key
+ end
return root._localIDs[id]
end
diff --git a/script/vm/node/compiler.lua b/script/vm/node/compiler.lua
index 9d86768f..6c51db86 100644
--- a/script/vm/node/compiler.lua
+++ b/script/vm/node/compiler.lua
@@ -50,6 +50,37 @@ function m.eachNode(node)
end
end
+local searchFieldMap = util.switch()
+ : case 'table'
+ : call(function (node, key, pushResult)
+ for _, field in ipairs(node) do
+ if field.type == 'tablefield'
+ or field.type == 'tableindex' then
+ if guide.getKeyName(field) == key then
+ pushResult(m.compileNode(field))
+ end
+ end
+ end
+ end)
+ : case 'global'
+ ---@param node vm.node.global
+ : call(function (node, key, pushResult)
+ local global = globalMgr.getGlobal(node.name, key)
+ if global then
+ pushResult(global)
+ end
+ end)
+ : case 'local'
+ : call(function (node, key, pushResult)
+ local sources = localID.getSources(node, key)
+ if sources then
+ for _, src in ipairs(sources) do
+ pushResult(m.compileNode(src))
+ end
+ end
+ end)
+ : getMap()
+
local function getReturnOfFunction(func, index)
if not func._returns then
func._returns = util.defaultTable(function ()
@@ -63,14 +94,32 @@ local function getReturnOfFunction(func, index)
return m.compileNode(func._returns[index])
end
-local function getReturn(func, index)
+local function getReturnOfSetMetaTable(source, args)
+ local tbl = args and args[1]
+ local mt = args and args[2]
+ if tbl then
+ m.setNode(source, m.compileNode(tbl))
+ end
+ if mt then
+ m.compileByParentNode(mt, '__index', function (node)
+ m.setNode(source, node)
+ end)
+ end
+ return source._node
+end
+
+local function getReturn(func, index, source, args)
local node = m.compileNode(func)
if not node then
return
end
for cnode in m.eachNode(node) do
- if cnode.type == 'function' then
+ if cnode.type == 'function' then
return getReturnOfFunction(cnode, index)
+ elseif cnode.type == 'global' then
+ if cnode.name == 'setmetatable' and index == 1 then
+ return getReturnOfSetMetaTable(source, args)
+ end
end
end
end
@@ -87,39 +136,19 @@ local function compileByLocalID(source)
end
end
-local searchFieldMap = util.switch()
- : case 'table'
- : call(function (node, key, pushResult)
- for _, field in ipairs(node) do
- if field.type == 'tablefield'
- or field.type == 'tableindex' then
- if guide.getKeyName(field) == key then
- pushResult(m.compileNode(field))
- end
- end
- end
- end)
- : case 'global'
- ---@param node vm.node.global
- : call(function (node, key, pushResult)
- local global = globalMgr.getGlobal(node.name, key)
- if global then
- pushResult(global)
- end
- end)
- : getMap()
-
-local function compileByParentNode(source)
- local parentNode = m.compileNode(source.node)
+---@param source vm.node
+---@param key any
+---@param pushResult fun(node:vm.node)
+function m.compileByParentNode(source, key, pushResult)
+ local parentNode = m.compileNode(source)
if not parentNode then
return
end
- local key = guide.getKeyName(source)
- local f = searchFieldMap[parentNode.type]
- if f then
- f(parentNode, key, function (fieldNode)
- m.setNode(source, fieldNode)
- end)
+ for node in m.eachNode(parentNode) do
+ local f = searchFieldMap[node.type]
+ if f then
+ f(node, key, pushResult)
+ end
end
end
@@ -136,6 +165,7 @@ local compilerMap = util.switch()
end)
: case 'local'
: call(function (source)
+ m.setNode(source, source)
if source.value then
m.setNode(source, m.compileNode(source.value))
end
@@ -165,7 +195,9 @@ local compilerMap = util.switch()
: case 'getindex'
: call(function (source)
compileByLocalID(source)
- compileByParentNode(source)
+ m.compileByParentNode(source.node, guide.getKeyName(source), function (node)
+ m.setNode(source, node)
+ end)
end)
: case 'tablefield'
: case 'tableindex'
@@ -190,7 +222,7 @@ local compilerMap = util.switch()
: call(function (source)
local vararg = source.vararg
if vararg.type == 'call' then
- m.setNode(source, getReturn(vararg.node, source.sindex))
+ m.setNode(source, getReturn(vararg.node, source.sindex, source, vararg.args))
end
end)
: getMap()
@@ -198,7 +230,7 @@ local compilerMap = util.switch()
---@param source parser.object
---@return vm.node
function m.compileNode(source)
- if source._node then
+ if source._node ~= nil then
return source._node
end
source._node = false
diff --git a/script/vm/node/union.lua b/script/vm/node/union.lua
index b944b92d..538e7586 100644
--- a/script/vm/node/union.lua
+++ b/script/vm/node/union.lua
@@ -48,7 +48,8 @@ end
---@return vm.node.union
return function (me, node)
local union = setmetatable({
- [1] = me,
+ [1] = me,
+ [me] = true,
}, mt)
union:merge(node)
return union