summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/src/core/rename.lua69
-rw-r--r--server/src/vm/value.lua3
-rw-r--r--server/src/vm/vm.lua18
-rw-r--r--server/test/main.lua1
-rw-r--r--server/test/rename/init.lua85
5 files changed, 151 insertions, 25 deletions
diff --git a/server/src/core/rename.lua b/server/src/core/rename.lua
index f3132f77..349ec96b 100644
--- a/server/src/core/rename.lua
+++ b/server/src/core/rename.lua
@@ -1,15 +1,23 @@
local findSource = require 'core.find_source'
local parser = require 'parser'
-local function parseResult(result, source, newName)
+local function parseResult(source, newName)
local positions = {}
- local tp = result.type
- if tp == 'local' or tp == 'field' then
- local key = source[1]
- if result.hide then
- return positions
+ if source:bindLabel() then
+ if not parser.grammar(newName, 'Name') then
+ return nil
+ end
+ source:bindLabel():eachInfo(function (info)
+ positions[#positions+1] = {info.source.start, info.source.finish}
+ end)
+ return positions
+ end
+ if source:bindLocal() then
+ local loc = source:bindLocal()
+ if loc:get 'hide' then
+ return nil
end
- if source.index then
+ if source:get 'in index' then
if not parser.grammar(newName, 'Exp') then
return positions
end
@@ -19,31 +27,46 @@ local function parseResult(result, source, newName)
end
end
local mark = {}
- for _, info in ipairs(result) do
+ loc:eachInfo(function (info)
if not mark[info.source] then
mark[info.source] = info
- if info.source[1] == key then
- positions[#positions+1] = {info.source.start, info.source.finish}
- end
+ positions[#positions+1] = {info.source.start, info.source.finish}
+ end
+ end)
+ return positions
+ end
+ if source:bindValue() then
+ if source:get 'in index' then
+ if not parser.grammar(newName, 'Exp') then
+ return positions
+ end
+ else
+ if not parser.grammar(newName, 'Name') then
+ return positions
end
end
- elseif tp == 'label' then
- if not parser.grammar(newName, 'Name') then
- return positions
- end
- local label = result.label
- for _, info in ipairs(label) do
- positions[#positions+1] = {info.source.start, info.source.finish}
- end
+ local parent = source:get 'parent'
+ local mark = {}
+ parent:eachInfo(function (info)
+ if not mark[info.source] then
+ mark[info.source] = info
+ if info.type == 'get child' or info.type == 'set child' then
+ if info[1] == source[1] then
+ positions[#positions+1] = {info.source.start, info.source.finish}
+ end
+ end
+ end
+ end)
+ return positions
end
- return positions
+ return nil
end
return function (vm, pos, newName)
- local result, source = findSource(vm, pos)
- if not result then
+ local source = findSource(vm, pos)
+ if not source then
return nil
end
- local positions = parseResult(result, source, newName)
+ local positions = parseResult(source, newName)
return positions
end
diff --git a/server/src/vm/value.lua b/server/src/vm/value.lua
index 695734e7..8124b484 100644
--- a/server/src/vm/value.lua
+++ b/server/src/vm/value.lua
@@ -237,13 +237,14 @@ function mt:mergeValue(value)
end
end
-function mt:addInfo(tp, source)
+function mt:addInfo(tp, source, ...)
if source and not source.start then
error('Miss start: ' .. table.dump(source))
end
self[#self+1] = {
type = tp,
source = source or getDefaultSource(),
+ ...
}
end
diff --git a/server/src/vm/vm.lua b/server/src/vm/vm.lua
index 07e6917f..f965b429 100644
--- a/server/src/vm/vm.lua
+++ b/server/src/vm/vm.lua
@@ -54,10 +54,12 @@ function mt:buildTable(source)
key:bindValue(value, 'set')
if key.index then
local index = self:getIndex(key)
+ tbl:addInfo('set child', key, index)
tbl:setChild(index, value)
else
if key.type == 'name' then
key:set('table index', true)
+ tbl:addInfo('set child', key, key[1])
tbl:setChild(key[1], value)
end
end
@@ -67,14 +69,17 @@ function mt:buildTable(source)
if index == #source then
value:eachValue(function (_, v)
n = n + 1
+ tbl:addInfo('set child', obj, n)
tbl:setChild(n, v)
end)
else
n = n + 1
+ tbl:addInfo('set child', obj, n)
tbl:setChild(n, self:getFirstInMulti(value))
end
else
n = n + 1
+ tbl:addInfo('set child', obj, n)
tbl:setChild(n, value)
end
-- 处理写了一半的 key = value,把name类的数组元素视为哈希键
@@ -349,6 +354,7 @@ function mt:getName(name, source)
end
local ENV = self:loadLocal('_ENV')
local ENVValue = ENV:getValue()
+ ENVValue:addInfo('get child', source, name)
global = ENVValue:getChild(name, source)
source:bindValue(global, 'get')
source:set('global', true)
@@ -371,6 +377,7 @@ function mt:setName(name, source, value)
local ENV = self:loadLocal('_ENV')
local ENVValue = ENV:getValue()
source:bindValue(value, 'set')
+ ENVValue:addInfo('set child', source, name)
ENVValue:setChild(name, value)
source:set('global', true)
source:set('parentValue', ENVValue)
@@ -379,6 +386,7 @@ end
function mt:getIndex(source)
if source.type == 'name' then
local value = self:getName(source[1], source)
+ source:set('in index', true)
return value
elseif source.type == 'string' or source.type == 'number' or source.type == 'boolean' then
return source[1]
@@ -477,11 +485,13 @@ function mt:getSimple(simple, max)
source:set('parent', value)
local child = source[1]
local index = self:getIndex(child)
+ value:addInfo('get child', source, index)
value = value:getChild(index, source)
source:bindValue(value, 'get')
elseif source.type == 'name' then
source:set('parent', value)
source:set('object', object)
+ value:addInfo('get child', source, source[1])
value = value:getChild(source[1], source)
source:bindValue(value, 'get')
elseif source.type == ':' then
@@ -746,12 +756,14 @@ function mt:setOne(var, value)
local key = var[#var]
self:instantSource(key)
key:set('simple', var)
- key:set('parent', value)
+ key:set('parent', parent)
if key.type == 'index' then
local index = self:getIndex(key[1])
+ parent:addInfo('set child', key[1], index)
parent:setChild(index, value)
elseif key.type == 'name' then
local index = key[1]
+ parent:addInfo('set child', key, index)
parent:setChild(index, value)
end
key:bindValue(value, 'set')
@@ -878,9 +890,11 @@ function mt:doFunction(action)
source:set('object', parent)
if source.type == 'index' then
local index = self:getIndex(source[1])
+ parent:addInfo('set child', source[1], index)
parent:setChild(index, value)
elseif source.type == 'name' then
local index = source[1]
+ parent:addInfo('set child', source, index)
parent:setChild(index, value)
end
source:bindValue(value, 'set')
@@ -903,9 +917,11 @@ function mt:doFunction(action)
self:instantSource(source)
if source.type == 'index' then
local index = self:getIndex(source[1])
+ parent:addInfo('set child', source[1], index)
parent:setChild(index, value)
elseif source.type == 'name' then
local index = source[1]
+ parent:addInfo('set child', source, index)
parent:setChild(index, value)
end
source:bindValue(value, 'set')
diff --git a/server/test/main.lua b/server/test/main.lua
index 7a498929..e6fdcbe2 100644
--- a/server/test/main.lua
+++ b/server/test/main.lua
@@ -27,6 +27,7 @@ local function main()
test 'core'
test 'full'
test 'definition'
+ test 'rename'
test 'diagnostics'
test 'type_inference'
test 'find_lib'
diff --git a/server/test/rename/init.lua b/server/test/rename/init.lua
new file mode 100644
index 00000000..09900e3d
--- /dev/null
+++ b/server/test/rename/init.lua
@@ -0,0 +1,85 @@
+local core = require 'core'
+local parser = require 'parser'
+local buildVM = require 'vm'
+
+local function catch_target(script)
+ local list = {}
+ local cur = 1
+ while true do
+ local start, finish = script:find('<[!?].-[!?]>', cur)
+ if not start then
+ break
+ end
+ list[#list+1] = { start + 2, finish - 2 }
+ cur = finish + 1
+ end
+ return list
+end
+
+local function founded(targets, results)
+ if #targets ~= #results then
+ return false
+ end
+ for _, target in ipairs(targets) do
+ for _, result in ipairs(results) do
+ if target[1] == result[1] and target[2] == result[2] then
+ goto NEXT
+ end
+ end
+ do return false end
+ ::NEXT::
+ end
+ return true
+end
+
+function TEST(newName)
+ return function (script)
+ local target = catch_target(script)
+ local start = script:find('<?', 1, true)
+ local finish = script:find('?>', 1, true)
+ local pos = (start + finish) // 2 + 1
+ local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ')
+ local ast = parser:ast(new_script)
+ assert(ast)
+ local vm = buildVM(ast)
+ assert(vm)
+
+ local positions = core.rename(vm, pos, newName)
+ if positions then
+ assert(founded(target, positions))
+ else
+ assert(#target == 0)
+ end
+ end
+end
+
+TEST 'b' [[
+local <?a?> = 1
+]]
+
+TEST 'b' [[
+local <?a?> = 1
+<!a!> = 2
+<!a!> = <!a!>
+]]
+
+TEST 'b' [[
+t.<?a?> = 1
+a = t.<!a!>
+]]
+
+TEST 'b' [[
+t[<!'a'!>] = 1
+a = t.<?a?>
+]]
+
+TEST 'b' [[
+:: <?a?> ::
+goto <!a!>
+]]
+
+TEST 'b' [[
+local function f(<!a!>)
+ return <?a?>
+end
+]]