summaryrefslogtreecommitdiff
path: root/server-beta
diff options
context:
space:
mode:
Diffstat (limited to 'server-beta')
-rw-r--r--server-beta/src/core/rename.lua202
-rw-r--r--server-beta/src/parser/ast.lua8
-rw-r--r--server-beta/test/rename/init.lua131
3 files changed, 254 insertions, 87 deletions
diff --git a/server-beta/src/core/rename.lua b/server-beta/src/core/rename.lua
index a2b1b3fc..65010e0d 100644
--- a/server-beta/src/core/rename.lua
+++ b/server-beta/src/core/rename.lua
@@ -2,52 +2,151 @@ local files = require 'files'
local searcher = require 'searcher'
local guide = require 'parser.guide'
-local function checkSource(source)
- if source.type == 'field'
- or source.type == 'method'
- or source.type == 'tablefield'
- or source.type == 'string'
- or source.type == 'local'
- or source.type == 'setlocal'
- or source.type == 'getlocal'
- or source.type == 'setglobal'
- or source.type == 'getglobal'
- or source.type == 'label'
- or source.type == 'goto' then
- return true
+local function isValidName(str)
+ return str:match '^[%a_][%w_]*$'
+end
+
+local function forceReplace(name)
+ return true
+end
+
+local function ofLocal(source, newname, callback)
+ if not isValidName(newname) and not forceReplace(newname) then
+ return false
+ end
+ callback(source, source.start, source.finish, newname)
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback(ref, ref.start, ref.finish, newname)
+ end
end
- return false
end
-local function rename(source)
- if source.type == 'field'
- or source.type == 'method'
- or source.type == 'tablefield'
- or source.type == 'string'
- or source.type == 'local'
- or source.type == 'setlocal'
- or source.type == 'getlocal'
- or source.type == 'setglobal'
- or source.type == 'getglobal'
- or source.type == 'label'
- or source.type == 'goto' then
- return source
+local esc = {
+ ["'"] = [[\']],
+ ['"'] = [[\"]],
+ ['\r'] = [[\r]],
+ ['\n'] = [[\n]],
+}
+
+local function toString(quo, newstr)
+ if quo == "'" then
+ return quo .. newstr:gsub([=[['\r\n]]=], esc) .. quo
+ elseif quo == '"' then
+ return quo .. newstr:gsub([=[["\r\n]]=], esc) .. quo
+ else
+ if newstr:find([[\r]], 1, true) then
+ return toString('"', newstr)
+ end
+ local eqnum = #quo - 2
+ local fsymb = ']' .. ('='):rep(eqnum) .. ']'
+ if not newstr:find(fsymb, 1, true) then
+ return quo .. newstr .. fsymb
+ end
+ for i = 0, 100 do
+ local fsymb = ']' .. ('='):rep(i) .. ']'
+ if not newstr:find(fsymb, 1, true) then
+ local ssymb = '[' .. ('='):rep(i) .. '['
+ return ssymb .. newstr .. fsymb
+ end
+ end
+ return toString('"', newstr)
end
- if source.type == 'setfield'
- or source.type == 'getfield'
- or source.type == 'tablefield' then
- return source.field
+end
+
+local function renameField(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return true
+ end
+ local parent = source.parent
+ if parent.type == 'setfield'
+ or parent.type == 'getfield' then
+ local dot = parent.dot
+ local newstr = '[' .. toString('"', newname) .. ']'
+ callback(source, dot.start, source.finish, newstr)
+ elseif parent.type == 'tablefield' then
+ local newstr = '[' .. toString('"', newname) .. ']'
+ callback(source, source.start, source.finish, newstr)
+ else
+ if not forceReplace(newname) then
+ return false
+ end
+ callback(source, source.start, source.finish, newname)
+ end
+ return true
+end
+
+local function renameGlobal(source, newname, callback)
+ if isValidName(newname) then
+ callback(source, source.start, source.finish, newname)
+ return false
end
- if source.type == 'setindex'
- or source.type == 'getindex'
- or source.type == 'tableindex' then
- return source.index
+ local newstr = '_ENV[' .. toString('"', newname) .. ']'
+ callback(source, source.start, source.finish, newstr)
+ return true
+end
+
+local function ofField(source, newname, callback)
+ return searcher.eachRef(source, function (info)
+ local src = info.source
+ if src.type == 'tablefield'
+ or src.type == 'getfield'
+ or src.type == 'setfield' then
+ src = src.field
+ elseif src.type == 'tableindex'
+ or src.type == 'getindex'
+ or src.type == 'setindex' then
+ src = src.index
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ src = src.method
+ end
+ if src.type == 'string' then
+ local quo = src[2]
+ local text = toString(quo, newname)
+ callback(src, src.start, src.finish, text)
+ return
+ elseif src.type == 'field'
+ or src.type == 'method' then
+ local suc = renameField(src, newname, callback)
+ if not suc then
+ return false
+ end
+ elseif src.type == 'setglobal'
+ or src.type == 'getglobal' then
+ local suc = renameGlobal(src, newname, callback)
+ if not suc then
+ return false
+ end
+ end
+ end)
+end
+
+local function rename(source, newname, callback)
+ if source.type == 'label'
+ or source.type == 'goto' then
+ if not isValidName(newname) and not forceReplace(newname)then
+ return false
+ end
+ searcher.eachRef(source, function (info)
+ callback(info.source, info.source.start, info.source.finish, newname)
+ end)
end
- if source.type == 'setmethod'
- or source.type == 'getmethod' then
- return source.method
+ if source.type == 'local' then
+ return ofLocal(source, newname, callback)
+ elseif source.type == 'setlocal'
+ or source.type == 'getlocal' then
+ return ofLocal(source.node, newname, callback)
+ elseif source.type == 'field'
+ or source.type == 'method'
+ or source.type == 'tablefield'
+ or source.type == 'string'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal' then
+ return ofField(source, newname, callback)
end
- return nil
+ return true
end
return function (uri, pos, newname)
@@ -56,23 +155,24 @@ return function (uri, pos, newname)
return nil
end
local results = {}
+
+ local ok = true
guide.eachSourceContain(ast.ast, pos, function(source)
- if not checkSource(source) then
- return
- end
- searcher.eachRef(source, function (info)
- local src = rename(info.source)
- if not src then
- return
- end
+ local suc = rename(source, newname, function (target, start, finish, text)
results[#results+1] = {
- start = src.start,
- finish = src.finish,
- text = newname,
- uri = guide.getRoot(src).uri,
+ start = start,
+ finish = finish,
+ text = text,
+ uri = guide.getRoot(target).uri,
}
end)
+ if suc == false then
+ ok = false
+ end
end)
+ if not ok then
+ return nil
+ end
if #results == 0 then
return nil
end
diff --git a/server-beta/src/parser/ast.lua b/server-beta/src/parser/ast.lua
index c5303796..b87be927 100644
--- a/server-beta/src/parser/ast.lua
+++ b/server-beta/src/parser/ast.lua
@@ -1732,15 +1732,15 @@ local Defs = {
local function init(state)
State = state
- PushError = state.pushError
- PushDiag = state.pushDiag
+ PushError = state.pushError or function () end
+ PushDiag = state.pushDiag or function () end
emmy.init(State)
end
local function close()
State = nil
- PushError = nil
- PushDiag = nil
+ PushError = function () end
+ PushDiag = function () end
end
return {
diff --git a/server-beta/test/rename/init.lua b/server-beta/test/rename/init.lua
index fdb47ebe..92612f0f 100644
--- a/server-beta/test/rename/init.lua
+++ b/server-beta/test/rename/init.lua
@@ -34,52 +34,119 @@ local function founded(targets, results)
return true
end
-function TEST(newName)
- return function (script)
- files.removeAll()
- local target = catch_target(script)
- local start = script:find('<?', 1, true)
- local finish = script:find('?>', 1, true)
- local pos = (start + finish) // 2 + 1
- local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ')
- files.setText('', new_script)
+local function replace(text, positions)
+ local buf = {}
+ table.sort(positions, function (a, b)
+ return a.start < b.start
+ end)
+ local lastPos = 1
+ for _, info in ipairs(positions) do
+ buf[#buf+1] = text:sub(lastPos, info.start - 1)
+ buf[#buf+1] = info.text
+ lastPos = info.finish + 1
+ end
+ buf[#buf+1] = text:sub(lastPos)
+ return table.concat(buf)
+end
- local positions = core('', pos, newName)
- if positions then
- assert(founded(target, positions))
- else
- assert(#target == 0)
+function TEST(oldName, newName)
+ return function (oldScript)
+ return function (newScript)
+ files.removeAll()
+ files.setText('', oldScript)
+ local pos = oldScript:find('[^%w_]'..oldName..'[^%w_]')
+ assert(pos)
+
+ local positions = core('', pos+1, newName)
+ local script = oldScript
+ if positions then
+ script = replace(script, positions)
+ end
+ assert(script == newScript)
end
end
end
-TEST 'b' [[
-local <?a?> = 1
+TEST ('a', 'b') [[
+local a = 1
+]] [[
+local b = 1
]]
-TEST 'b' [[
-local <?a?> = 1
-<!a!> = 2
-<!a!> = <!a!>
+TEST ('a', 'b') [[
+local a = 1
+a = 2
+a = a
+]] [[
+local b = 1
+b = 2
+b = b
]]
-TEST 'b' [[
-t.<?a?> = 1
-a = t.<!a!>
+TEST ('a', 'b') [[
+t.a = 1
+a = t.a
+a = t['a']
+a = t["a"]
+a = t[ [=[a]=] ]
+]] [[
+t.b = 1
+a = t.b
+a = t['b']
+a = t["b"]
+a = t[ [=[b]=] ]
]]
-TEST 'b' [[
-t[<!'a'!>] = 1
-a = t.<?a?>
+TEST ('a', 'b') [[
+:: a ::
+goto a
+]] [[
+:: b ::
+goto b
]]
-TEST 'b' [[
-:: <?a?> ::
-goto <!a!>
+TEST ('a', 'b') [[
+local function f(a)
+ return a
+end
+]] [[
+local function f(b)
+ return b
+end
]]
-TEST 'b' [[
-local function f(<!a!>)
- return <?a?>
+TEST ('a', '!!!') [[
+t = {
+ a = 0
+}
+t.a = 1
+a = t.a
+]] [[
+t = {
+ ["!!!"] = 0
+}
+t["!!!"] = 1
+a = t["!!!"]
+]]
+
+TEST ('a', '"') [[
+print(t[ "a" ])
+]] [[
+print(t[ "\"" ])
+]]
+
+TEST ('a', '!!!') [[
+function mt:a()
end
+mt:a()
+]] [[
+function mt:!!!()
+end
+mt:!!!()
+]]
+
+TEST ('a', '!!!') [[
+a = a
+]] [[
+_ENV["!!!"] = _ENV["!!!"]
]]