diff options
Diffstat (limited to 'script/vm/compiler.lua')
-rw-r--r-- | script/vm/compiler.lua | 281 |
1 files changed, 281 insertions, 0 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua new file mode 100644 index 00000000..3af72110 --- /dev/null +++ b/script/vm/compiler.lua @@ -0,0 +1,281 @@ +local guide = require 'parser.guide' +local util = require 'utility' +local union = require 'vm.union' +local localID = require 'vm.local-id' +local localMgr = require 'vm.local-manager' +local globalMgr = require 'vm.global-manager' + +---@class parser.object +---@field _compiledNodes boolean +---@field _node vm.node + +---@class vm.node.compiler +local m = {} + +---@class vm.node.cross + +---@alias vm.node parser.object | vm.node.union | vm.node.cross | vm.node.global + +function m.setNode(source, node) + if not node then + return + end + local me = source._node + if not me then + source._node = node + return + end + if me == node then + return + end + if me.type == 'union' + or me.type == 'cross' then + me:merge(node) + return + end + source._node = union(me, node) +end + +function m.eachNode(node) + if node.type == 'union' then + return node:eachNode() + end + local first = true + return function () + if first then + first = false + return node + end + return nil + end +end + +local searchFieldMap = util.switch() + : case 'table' + : call(function (node, key, pushResult) + for _, field in ipairs(node) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if guide.getKeyName(field) == key then + pushResult(m.compileNode(field)) + end + end + end + end) + : case 'global' + ---@param node vm.node.global + : call(function (node, key, pushResult) + local global = globalMgr.getGlobal('variable', node.name, key) + if global then + pushResult(global) + end + end) + : case 'local' + : call(function (node, key, pushResult) + local sources = localID.getSources(node, key) + if sources then + for _, src in ipairs(sources) do + pushResult(m.compileNode(src)) + end + end + end) + : getMap() + +local function getReturnOfFunction(func, index) + if not func._returns then + func._returns = util.defaultTable(function () + return { + type = 'function.return', + parent = func, + index = index, + } + end) + end + return m.compileNode(func._returns[index]) +end + +local function getReturnOfSetMetaTable(source, args) + local tbl = args and args[1] + local mt = args and args[2] + if tbl then + m.setNode(source, m.compileNode(tbl)) + end + if mt then + m.compileByParentNode(mt, '__index', function (node) + m.setNode(source, node) + end) + end + return source._node +end + +local function getReturn(func, index, source, args) + if func.special == 'setmetatable' then + return getReturnOfSetMetaTable(source, args) + end + local node = m.compileNode(func) + if node then + for cnode in m.eachNode(node) do + if cnode.type == 'function' then + return getReturnOfFunction(cnode, index) + end + end + end +end + +local function compileByLocalID(source) + local sources = localID.getSources(source) + if not sources then + return + end + for _, src in ipairs(sources) do + if src.value then + m.setNode(source, m.compileNode(src.value)) + end + end +end + +---@param source vm.node +---@param key any +---@param pushResult fun(node:vm.node) +function m.compileByParentNode(source, key, pushResult) + local parentNode = m.compileNode(source) + if not parentNode then + return + end + for node in m.eachNode(parentNode) do + local f = searchFieldMap[node.type] + if f then + f(node, key, pushResult) + end + end +end + +local compilerMap = util.switch() + : case 'boolean' + : case 'table' + : case 'integer' + : case 'number' + : case 'string' + : case 'function' + : call(function (source) + localMgr.declareLocal(source) + m.setNode(source, source) + end) + : case 'local' + : call(function (source) + m.setNode(source, source) + if source.value then + m.setNode(source, m.compileNode(source.value)) + end + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + m.setNode(source, m.compileNode(ref.value)) + end + end + end + if source.dummy then + m.setNode(source, m.compileNode(source.method.node)) + end + -- function x.y(self, ...) --> function x:y(...) + if source[1] == 'self' + and source.parent.type == 'funcargs' + and source.parent[1] == source then + local setfield = source.parent.parent.parent + if setfield.type == 'setfield' then + m.setNode(source, m.compileNode(setfield.node)) + end + end + end) + : case 'getlocal' + : call(function (source) + m.setNode(source, m.compileNode(source.node)) + end) + : case 'setfield' + : case 'setmethod' + : case 'setindex' + : call(function (source) + compileByLocalID(source) + end) + : case 'getfield' + : case 'getmethod' + : case 'getindex' + : call(function (source) + compileByLocalID(source) + m.compileByParentNode(source.node, guide.getKeyName(source), function (node) + m.setNode(source, node) + end) + end) + : case 'tablefield' + : case 'tableindex' + : call(function (source) + if source.value then + m.setNode(source, m.compileNode(source.value)) + end + end) + : case 'field' + : case 'method' + : call(function (source) + m.setNode(source, m.compileNode(source.parent)) + end) + : case 'function.return' + : call(function (source) + local func = source.parent + local index = source.index + if func.returns then + for _, rtn in ipairs(func.returns) do + if rtn[index] then + m.setNode(source, m.compileNode(rtn[index])) + end + end + end + end) + : case 'select' + : call(function (source) + local vararg = source.vararg + if vararg.type == 'call' then + m.setNode(source, getReturn(vararg.node, source.sindex, source, vararg.args)) + end + end) + : case 'call' + : call(function (source) + m.setNode(source, getReturn(source.node, 1, source, source.args)) + end) + : getMap() + +---@param source parser.object +local function compileByNode(source) + local compiler = compilerMap[source.type] + if compiler then + compiler(source) + end +end + +---@param source parser.object +local function compileByGlobal(source) + if source._globalNode then + m.setNode(source, source._globalNode) + for _, set in ipairs(source._globalNode:getSets()) do + if set.value then + m.setNode(source, m.compileNode(set.value)) + end + end + end +end + +---@param source parser.object +---@return vm.node +function m.compileNode(source) + if source._node ~= nil then + return source._node + end + source._node = false + compileByNode(source) + compileByGlobal(source) + + localMgr.subscribeLocal(source, source._node) + + return source._node +end + +return m |