summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server-beta/src/core/definition.lua23
-rw-r--r--server-beta/src/core/engineer.lua207
-rw-r--r--server-beta/src/parser/compile.lua17
-rw-r--r--server-beta/src/parser/guide.lua11
-rw-r--r--server-beta/test/definition/set.lua2
5 files changed, 216 insertions, 44 deletions
diff --git a/server-beta/src/core/definition.lua b/server-beta/src/core/definition.lua
index de041ef2..b15519bd 100644
--- a/server-beta/src/core/definition.lua
+++ b/server-beta/src/core/definition.lua
@@ -19,8 +19,8 @@ function m.search(state, ast, source)
end
function m.aslocal(state, ast, source)
- engineer(ast):eachLocalRef(source, function (src)
- if src.type == 'local' or src.type == 'setlocal' then
+ engineer(ast):eachLocalRef(source, function (src, mode)
+ if mode == 'local' or mode == 'set' then
state.callback(src)
end
end)
@@ -30,24 +30,15 @@ m.asgetlocal = m.aslocal
m.assetlocal = m.aslocal
function m.globals(state, ast, source)
- local name = source[1]
- guide.eachGloabl(ast.root, function (src, gname)
- if name ~= gname then
- return
- end
- if src.type == 'setglobal' or src.type == 'setfield' then
- state.callback(src, ast.uri)
+ engineer(ast):eachGloablOfName(source[1], function (src, mode)
+ if mode == 'set' then
+ state.callback(src)
end
end)
end
-function m.assetglobal(state, ast, source)
- m.globals(state, ast, source)
-end
-
-function m.asgetglobal(state, ast, source)
- m.globals(state, ast, source)
-end
+m.assetglobal = m.globals
+m.asgetglobal = m.globals
return function (ast, text, offset)
local results = {}
diff --git a/server-beta/src/core/engineer.lua b/server-beta/src/core/engineer.lua
index e4630670..c82efa6d 100644
--- a/server-beta/src/core/engineer.lua
+++ b/server-beta/src/core/engineer.lua
@@ -1,35 +1,197 @@
-local guide = require 'parser.guide'
+local guide = require 'parser.guide'
+local config = require 'config'
+
+local type = type
+local setmetatable = setmetatable
+
+_ENV = nil
---@class engineer
local mt = {}
mt.__index = mt
mt.type = 'engineer'
---- 遍历全局变量
-function mt:eachGloabl(root, callback)
- guide.eachSourceOf(root, {'setglobal', 'getglobal', 'setfield', 'getfield'}, function (src)
- if src.type == 'setglobal' or src.type == 'getglobal' then
- callback(src, src[1])
- elseif src.type == 'setfield' or src.type == 'getfield' then
- local node = root[src.node]
- if self.isGlobal(root, node) then
- callback(src, src.field[1])
+function mt:call(method, obj, ...)
+ self.step = self.step + 1
+ if self.step > 100 then
+ return nil
+ end
+ if not obj then
+ return nil
+ end
+ if ... == nil and obj['_'..method] ~= nil then
+ return obj['_'..method]
+ end
+ local res = self[method](self, obj, ...)
+ self.step = self.step - 1
+ if ... == nil then
+ obj['_'..method] = res
+ end
+ return res
+end
+
+--- 根据变量名,遍历全局变量
+function mt:eachGloablOfName(name, callback)
+ if type(name) ~= 'string' then
+ return
+ end
+ guide.eachSourceOf(self.ast.root, {
+ ['setglobal'] = function (source)
+ if source[1] == name then
+ callback(source, 'set')
end
- end
- end)
+ end,
+ ['getglobal'] = function (source)
+ if source[1] == name then
+ callback(source, 'get')
+ end
+ end,
+ ['setfield'] = function (source)
+ if source.field[1] ~= name then
+ return
+ end
+ if self:call('isGlobalField', source) then
+ callback(source, 'set')
+ end
+ end,
+ ['getfield'] = function (source)
+ if source.field[1] ~= name then
+ return
+ end
+ if self:call('isGlobalField', source) then
+ callback(source, 'get')
+ end
+ end,
+ ['call'] = function (source)
+ local d = self:call('asRawSet', source)
+ if d then
+ if self:call('getLiteral', d.k) == name then
+ callback(source, 'set')
+ end
+ end
+ local d = self:call('asRawGet', source)
+ if d then
+ if self:call('getLiteral', d.k) == name then
+ callback(source, 'get')
+ end
+ end
+ end,
+ })
end
---- 判断全局变量
-function mt:isGlobal(root, obj)
+--- 是否是全局变量
+function mt:isGlobal(obj)
if obj.type == 'getglobal' then
- if obj[1] == '_G' or obj[1] == '_ENV' then
- return true
- end
+ return true
+ end
+ if obj.type == 'getfield' then
+ return self:call('isGlobalField', obj)
+ end
+ return false
+end
+
+--- 是否是指定名称的全局变量
+function mt:isGlobalOfName(obj, name)
+ if not self:call('isGlobal', obj) then
+ return false
+ end
+ return self:call('getName', obj) == name
+end
+
+--- 获取名称
+function mt:getName(obj)
+ if obj.type == 'setglobal' or obj.type == 'getglobal' then
+ return obj[1]
+ elseif obj.type == 'setfield' or obj.type == 'getfield' then
+ return obj.field[1]
+ elseif obj.type == 'local' or obj.type == 'setlocal' or obj.type == 'getlocal' then
+ return obj[1]
end
return false
end
---- 遍历局部变量引用
+--- 获取字面量值
+function mt:getLiteral(obj)
+ if obj.type == 'number' then
+ return obj[1]
+ elseif obj.type == 'boolean' then
+ return obj[1]
+ elseif obj.type == 'string' then
+ return obj[1]
+ end
+ return nil
+end
+
+--- 是否是全局field
+---|_G.xxx
+---|_ENV.xxx
+---|_ENV._G.xxx
+function mt:isGlobalField(obj)
+ local node = self.ast.root[obj.node]
+ if self:call('isG', node) then
+ return true
+ end
+ if self:call('isENV', node) then
+ return true
+ end
+ return false
+end
+
+--- 是否是_ENV
+function mt:isENV(obj)
+ local version = config.config.runtime.version
+ if version == 'Lua 5.1' or version == 'LuaJIT' then
+ return false
+ end
+ if self:isGlobalOfName(obj, '_ENV') then
+ return true
+ end
+ return false
+end
+
+--- 是否是_G
+function mt:isG(obj)
+ if self:isGlobalOfName(obj, '_G') then
+ return true
+ end
+ return false
+end
+
+--- 获取call的参数
+function mt:getCallArg(obj, i)
+ local args = self.ast.root[obj.args]
+ if not args then
+ return nil
+ end
+ return self.ast.root[args[i]]
+end
+
+--- 获取rawset信息
+function mt:asRawSet(obj)
+ local node = self.ast.root[obj.node]
+ if not self:isGlobalOfName(node, 'rawset') then
+ return false
+ end
+ return {
+ t = self:getCallArg(obj, 1),
+ k = self:getCallArg(obj, 2),
+ v = self:getCallArg(obj, 3),
+ }
+end
+
+--- 获取rawget信息
+function mt:asRawGet(obj)
+ local node = self.ast.root[obj.node]
+ if not self:isGlobalOfName(node, 'rawget') then
+ return false
+ end
+ return {
+ t = self:getCallArg(obj, 1),
+ k = self:getCallArg(obj, 2),
+ }
+end
+
+--- 根据指定的局部变量,遍历局部变量引用
function mt:eachLocalRef(obj, callback)
if not obj then
return
@@ -42,11 +204,16 @@ function mt:eachLocalRef(obj, callback)
else
return
end
- callback(src)
+ callback(src, 'local')
if src.ref then
for i = 1, #src.ref do
local ref = src.ref[i]
- callback(self.ast.root[ref])
+ local refObj = self.ast.root[ref]
+ if refObj.type == 'setlocal' then
+ callback(refObj, 'set')
+ elseif refObj.type == 'getlocal' then
+ callback(refObj, 'get')
+ end
end
end
end
diff --git a/server-beta/src/parser/compile.lua b/server-beta/src/parser/compile.lua
index 457af42c..45fd808b 100644
--- a/server-beta/src/parser/compile.lua
+++ b/server-beta/src/parser/compile.lua
@@ -3,7 +3,7 @@ local type = type
_ENV = nil
-local pushError, Root, Compile, CompileBlock, Cache, Block, GoToTag
+local pushError, Root, Compile, CompileBlock, Cache, Block, GoToTag, Version, ENVMode
local vmMap = {
['nil'] = function (obj)
@@ -475,6 +475,15 @@ local vmMap = {
Block = obj
Root[#Root+1] = obj
local id = #Root
+ if ENVMode == '_ENV' then
+ Compile({
+ type = 'local',
+ start = 0,
+ finish = 0,
+ effect = 0,
+ [1] = '_ENV',
+ }, id)
+ end
CompileBlock(obj, id)
Block = nil
return id
@@ -584,6 +593,12 @@ return function (self, lua, mode, version)
end
pushError = state.pushError
Root = state.root
+ Version = version
+ if version == 'Lua 5.1' or version == 'LuaJIT' then
+ ENVMode = 'fenv'
+ else
+ ENVMode = '_ENV'
+ end
Cache = {}
GoToTag = {}
if type(state.ast) == 'table' then
diff --git a/server-beta/src/parser/guide.lua b/server-beta/src/parser/guide.lua
index bd7f9150..7925abe2 100644
--- a/server-beta/src/parser/guide.lua
+++ b/server-beta/src/parser/guide.lua
@@ -1,6 +1,4 @@
local error = error
-local utf8Len = utf8.len
-local utf8Offset = utf8.offset
local type = type
_ENV = nil
@@ -199,18 +197,19 @@ end
--- 遍历所有某种类型的source
function m.eachSourceOf(root, types, callback)
if type(types) == 'string' then
- types = {[types] = true}
+ types = {[types] = callback}
elseif type(types) == 'table' then
for i = 1, #types do
- types[types[i]] = true
+ types[types[i]] = callback
end
else
return
end
for i = 1, #root do
local source = root[i]
- if types[source.type] then
- callback(source)
+ local f = types[source.type]
+ if f then
+ f(source)
end
end
end
diff --git a/server-beta/test/definition/set.lua b/server-beta/test/definition/set.lua
index ec582d38..b7e7e5d4 100644
--- a/server-beta/test/definition/set.lua
+++ b/server-beta/test/definition/set.lua
@@ -45,6 +45,6 @@ print(<?x?>)
]]
TEST [[
-rawset<!(_G, 'x', 1)!>
+<!rawset(_G, 'x', 1)!>
print(<?x?>)
]]