diff options
-rw-r--r-- | server-beta/src/core/definition.lua | 23 | ||||
-rw-r--r-- | server-beta/src/core/engineer.lua | 207 | ||||
-rw-r--r-- | server-beta/src/parser/compile.lua | 17 | ||||
-rw-r--r-- | server-beta/src/parser/guide.lua | 11 | ||||
-rw-r--r-- | server-beta/test/definition/set.lua | 2 |
5 files changed, 216 insertions, 44 deletions
diff --git a/server-beta/src/core/definition.lua b/server-beta/src/core/definition.lua index de041ef2..b15519bd 100644 --- a/server-beta/src/core/definition.lua +++ b/server-beta/src/core/definition.lua @@ -19,8 +19,8 @@ function m.search(state, ast, source) end function m.aslocal(state, ast, source) - engineer(ast):eachLocalRef(source, function (src) - if src.type == 'local' or src.type == 'setlocal' then + engineer(ast):eachLocalRef(source, function (src, mode) + if mode == 'local' or mode == 'set' then state.callback(src) end end) @@ -30,24 +30,15 @@ m.asgetlocal = m.aslocal m.assetlocal = m.aslocal function m.globals(state, ast, source) - local name = source[1] - guide.eachGloabl(ast.root, function (src, gname) - if name ~= gname then - return - end - if src.type == 'setglobal' or src.type == 'setfield' then - state.callback(src, ast.uri) + engineer(ast):eachGloablOfName(source[1], function (src, mode) + if mode == 'set' then + state.callback(src) end end) end -function m.assetglobal(state, ast, source) - m.globals(state, ast, source) -end - -function m.asgetglobal(state, ast, source) - m.globals(state, ast, source) -end +m.assetglobal = m.globals +m.asgetglobal = m.globals return function (ast, text, offset) local results = {} diff --git a/server-beta/src/core/engineer.lua b/server-beta/src/core/engineer.lua index e4630670..c82efa6d 100644 --- a/server-beta/src/core/engineer.lua +++ b/server-beta/src/core/engineer.lua @@ -1,35 +1,197 @@ -local guide = require 'parser.guide' +local guide = require 'parser.guide' +local config = require 'config' + +local type = type +local setmetatable = setmetatable + +_ENV = nil ---@class engineer local mt = {} mt.__index = mt mt.type = 'engineer' ---- 遍历全局变量 -function mt:eachGloabl(root, callback) - guide.eachSourceOf(root, {'setglobal', 'getglobal', 'setfield', 'getfield'}, function (src) - if src.type == 'setglobal' or src.type == 'getglobal' then - callback(src, src[1]) - elseif src.type == 'setfield' or src.type == 'getfield' then - local node = root[src.node] - if self.isGlobal(root, node) then - callback(src, src.field[1]) +function mt:call(method, obj, ...) + self.step = self.step + 1 + if self.step > 100 then + return nil + end + if not obj then + return nil + end + if ... == nil and obj['_'..method] ~= nil then + return obj['_'..method] + end + local res = self[method](self, obj, ...) + self.step = self.step - 1 + if ... == nil then + obj['_'..method] = res + end + return res +end + +--- 根据变量名,遍历全局变量 +function mt:eachGloablOfName(name, callback) + if type(name) ~= 'string' then + return + end + guide.eachSourceOf(self.ast.root, { + ['setglobal'] = function (source) + if source[1] == name then + callback(source, 'set') end - end - end) + end, + ['getglobal'] = function (source) + if source[1] == name then + callback(source, 'get') + end + end, + ['setfield'] = function (source) + if source.field[1] ~= name then + return + end + if self:call('isGlobalField', source) then + callback(source, 'set') + end + end, + ['getfield'] = function (source) + if source.field[1] ~= name then + return + end + if self:call('isGlobalField', source) then + callback(source, 'get') + end + end, + ['call'] = function (source) + local d = self:call('asRawSet', source) + if d then + if self:call('getLiteral', d.k) == name then + callback(source, 'set') + end + end + local d = self:call('asRawGet', source) + if d then + if self:call('getLiteral', d.k) == name then + callback(source, 'get') + end + end + end, + }) end ---- 判断全局变量 -function mt:isGlobal(root, obj) +--- 是否是全局变量 +function mt:isGlobal(obj) if obj.type == 'getglobal' then - if obj[1] == '_G' or obj[1] == '_ENV' then - return true - end + return true + end + if obj.type == 'getfield' then + return self:call('isGlobalField', obj) + end + return false +end + +--- 是否是指定名称的全局变量 +function mt:isGlobalOfName(obj, name) + if not self:call('isGlobal', obj) then + return false + end + return self:call('getName', obj) == name +end + +--- 获取名称 +function mt:getName(obj) + if obj.type == 'setglobal' or obj.type == 'getglobal' then + return obj[1] + elseif obj.type == 'setfield' or obj.type == 'getfield' then + return obj.field[1] + elseif obj.type == 'local' or obj.type == 'setlocal' or obj.type == 'getlocal' then + return obj[1] end return false end ---- 遍历局部变量引用 +--- 获取字面量值 +function mt:getLiteral(obj) + if obj.type == 'number' then + return obj[1] + elseif obj.type == 'boolean' then + return obj[1] + elseif obj.type == 'string' then + return obj[1] + end + return nil +end + +--- 是否是全局field +---|_G.xxx +---|_ENV.xxx +---|_ENV._G.xxx +function mt:isGlobalField(obj) + local node = self.ast.root[obj.node] + if self:call('isG', node) then + return true + end + if self:call('isENV', node) then + return true + end + return false +end + +--- 是否是_ENV +function mt:isENV(obj) + local version = config.config.runtime.version + if version == 'Lua 5.1' or version == 'LuaJIT' then + return false + end + if self:isGlobalOfName(obj, '_ENV') then + return true + end + return false +end + +--- 是否是_G +function mt:isG(obj) + if self:isGlobalOfName(obj, '_G') then + return true + end + return false +end + +--- 获取call的参数 +function mt:getCallArg(obj, i) + local args = self.ast.root[obj.args] + if not args then + return nil + end + return self.ast.root[args[i]] +end + +--- 获取rawset信息 +function mt:asRawSet(obj) + local node = self.ast.root[obj.node] + if not self:isGlobalOfName(node, 'rawset') then + return false + end + return { + t = self:getCallArg(obj, 1), + k = self:getCallArg(obj, 2), + v = self:getCallArg(obj, 3), + } +end + +--- 获取rawget信息 +function mt:asRawGet(obj) + local node = self.ast.root[obj.node] + if not self:isGlobalOfName(node, 'rawget') then + return false + end + return { + t = self:getCallArg(obj, 1), + k = self:getCallArg(obj, 2), + } +end + +--- 根据指定的局部变量,遍历局部变量引用 function mt:eachLocalRef(obj, callback) if not obj then return @@ -42,11 +204,16 @@ function mt:eachLocalRef(obj, callback) else return end - callback(src) + callback(src, 'local') if src.ref then for i = 1, #src.ref do local ref = src.ref[i] - callback(self.ast.root[ref]) + local refObj = self.ast.root[ref] + if refObj.type == 'setlocal' then + callback(refObj, 'set') + elseif refObj.type == 'getlocal' then + callback(refObj, 'get') + end end end end diff --git a/server-beta/src/parser/compile.lua b/server-beta/src/parser/compile.lua index 457af42c..45fd808b 100644 --- a/server-beta/src/parser/compile.lua +++ b/server-beta/src/parser/compile.lua @@ -3,7 +3,7 @@ local type = type _ENV = nil -local pushError, Root, Compile, CompileBlock, Cache, Block, GoToTag +local pushError, Root, Compile, CompileBlock, Cache, Block, GoToTag, Version, ENVMode local vmMap = { ['nil'] = function (obj) @@ -475,6 +475,15 @@ local vmMap = { Block = obj Root[#Root+1] = obj local id = #Root + if ENVMode == '_ENV' then + Compile({ + type = 'local', + start = 0, + finish = 0, + effect = 0, + [1] = '_ENV', + }, id) + end CompileBlock(obj, id) Block = nil return id @@ -584,6 +593,12 @@ return function (self, lua, mode, version) end pushError = state.pushError Root = state.root + Version = version + if version == 'Lua 5.1' or version == 'LuaJIT' then + ENVMode = 'fenv' + else + ENVMode = '_ENV' + end Cache = {} GoToTag = {} if type(state.ast) == 'table' then diff --git a/server-beta/src/parser/guide.lua b/server-beta/src/parser/guide.lua index bd7f9150..7925abe2 100644 --- a/server-beta/src/parser/guide.lua +++ b/server-beta/src/parser/guide.lua @@ -1,6 +1,4 @@ local error = error -local utf8Len = utf8.len -local utf8Offset = utf8.offset local type = type _ENV = nil @@ -199,18 +197,19 @@ end --- 遍历所有某种类型的source function m.eachSourceOf(root, types, callback) if type(types) == 'string' then - types = {[types] = true} + types = {[types] = callback} elseif type(types) == 'table' then for i = 1, #types do - types[types[i]] = true + types[types[i]] = callback end else return end for i = 1, #root do local source = root[i] - if types[source.type] then - callback(source) + local f = types[source.type] + if f then + f(source) end end end diff --git a/server-beta/test/definition/set.lua b/server-beta/test/definition/set.lua index ec582d38..b7e7e5d4 100644 --- a/server-beta/test/definition/set.lua +++ b/server-beta/test/definition/set.lua @@ -45,6 +45,6 @@ print(<?x?>) ]] TEST [[ -rawset<!(_G, 'x', 1)!> +<!rawset(_G, 'x', 1)!> print(<?x?>) ]] |