diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/src/core/rename.lua | 69 | ||||
-rw-r--r-- | server/src/vm/value.lua | 3 | ||||
-rw-r--r-- | server/src/vm/vm.lua | 18 | ||||
-rw-r--r-- | server/test/main.lua | 1 | ||||
-rw-r--r-- | server/test/rename/init.lua | 85 |
5 files changed, 151 insertions, 25 deletions
diff --git a/server/src/core/rename.lua b/server/src/core/rename.lua index f3132f77..349ec96b 100644 --- a/server/src/core/rename.lua +++ b/server/src/core/rename.lua @@ -1,15 +1,23 @@ local findSource = require 'core.find_source' local parser = require 'parser' -local function parseResult(result, source, newName) +local function parseResult(source, newName) local positions = {} - local tp = result.type - if tp == 'local' or tp == 'field' then - local key = source[1] - if result.hide then - return positions + if source:bindLabel() then + if not parser.grammar(newName, 'Name') then + return nil + end + source:bindLabel():eachInfo(function (info) + positions[#positions+1] = {info.source.start, info.source.finish} + end) + return positions + end + if source:bindLocal() then + local loc = source:bindLocal() + if loc:get 'hide' then + return nil end - if source.index then + if source:get 'in index' then if not parser.grammar(newName, 'Exp') then return positions end @@ -19,31 +27,46 @@ local function parseResult(result, source, newName) end end local mark = {} - for _, info in ipairs(result) do + loc:eachInfo(function (info) if not mark[info.source] then mark[info.source] = info - if info.source[1] == key then - positions[#positions+1] = {info.source.start, info.source.finish} - end + positions[#positions+1] = {info.source.start, info.source.finish} + end + end) + return positions + end + if source:bindValue() then + if source:get 'in index' then + if not parser.grammar(newName, 'Exp') then + return positions + end + else + if not parser.grammar(newName, 'Name') then + return positions end end - elseif tp == 'label' then - if not parser.grammar(newName, 'Name') then - return positions - end - local label = result.label - for _, info in ipairs(label) do - positions[#positions+1] = {info.source.start, info.source.finish} - end + local parent = source:get 'parent' + local mark = {} + parent:eachInfo(function (info) + if not mark[info.source] then + mark[info.source] = info + if info.type == 'get child' or info.type == 'set child' then + if info[1] == source[1] then + positions[#positions+1] = {info.source.start, info.source.finish} + end + end + end + end) + return positions end - return positions + return nil end return function (vm, pos, newName) - local result, source = findSource(vm, pos) - if not result then + local source = findSource(vm, pos) + if not source then return nil end - local positions = parseResult(result, source, newName) + local positions = parseResult(source, newName) return positions end diff --git a/server/src/vm/value.lua b/server/src/vm/value.lua index 695734e7..8124b484 100644 --- a/server/src/vm/value.lua +++ b/server/src/vm/value.lua @@ -237,13 +237,14 @@ function mt:mergeValue(value) end end -function mt:addInfo(tp, source) +function mt:addInfo(tp, source, ...) if source and not source.start then error('Miss start: ' .. table.dump(source)) end self[#self+1] = { type = tp, source = source or getDefaultSource(), + ... } end diff --git a/server/src/vm/vm.lua b/server/src/vm/vm.lua index 07e6917f..f965b429 100644 --- a/server/src/vm/vm.lua +++ b/server/src/vm/vm.lua @@ -54,10 +54,12 @@ function mt:buildTable(source) key:bindValue(value, 'set') if key.index then local index = self:getIndex(key) + tbl:addInfo('set child', key, index) tbl:setChild(index, value) else if key.type == 'name' then key:set('table index', true) + tbl:addInfo('set child', key, key[1]) tbl:setChild(key[1], value) end end @@ -67,14 +69,17 @@ function mt:buildTable(source) if index == #source then value:eachValue(function (_, v) n = n + 1 + tbl:addInfo('set child', obj, n) tbl:setChild(n, v) end) else n = n + 1 + tbl:addInfo('set child', obj, n) tbl:setChild(n, self:getFirstInMulti(value)) end else n = n + 1 + tbl:addInfo('set child', obj, n) tbl:setChild(n, value) end -- 处理写了一半的 key = value,把name类的数组元素视为哈希键 @@ -349,6 +354,7 @@ function mt:getName(name, source) end local ENV = self:loadLocal('_ENV') local ENVValue = ENV:getValue() + ENVValue:addInfo('get child', source, name) global = ENVValue:getChild(name, source) source:bindValue(global, 'get') source:set('global', true) @@ -371,6 +377,7 @@ function mt:setName(name, source, value) local ENV = self:loadLocal('_ENV') local ENVValue = ENV:getValue() source:bindValue(value, 'set') + ENVValue:addInfo('set child', source, name) ENVValue:setChild(name, value) source:set('global', true) source:set('parentValue', ENVValue) @@ -379,6 +386,7 @@ end function mt:getIndex(source) if source.type == 'name' then local value = self:getName(source[1], source) + source:set('in index', true) return value elseif source.type == 'string' or source.type == 'number' or source.type == 'boolean' then return source[1] @@ -477,11 +485,13 @@ function mt:getSimple(simple, max) source:set('parent', value) local child = source[1] local index = self:getIndex(child) + value:addInfo('get child', source, index) value = value:getChild(index, source) source:bindValue(value, 'get') elseif source.type == 'name' then source:set('parent', value) source:set('object', object) + value:addInfo('get child', source, source[1]) value = value:getChild(source[1], source) source:bindValue(value, 'get') elseif source.type == ':' then @@ -746,12 +756,14 @@ function mt:setOne(var, value) local key = var[#var] self:instantSource(key) key:set('simple', var) - key:set('parent', value) + key:set('parent', parent) if key.type == 'index' then local index = self:getIndex(key[1]) + parent:addInfo('set child', key[1], index) parent:setChild(index, value) elseif key.type == 'name' then local index = key[1] + parent:addInfo('set child', key, index) parent:setChild(index, value) end key:bindValue(value, 'set') @@ -878,9 +890,11 @@ function mt:doFunction(action) source:set('object', parent) if source.type == 'index' then local index = self:getIndex(source[1]) + parent:addInfo('set child', source[1], index) parent:setChild(index, value) elseif source.type == 'name' then local index = source[1] + parent:addInfo('set child', source, index) parent:setChild(index, value) end source:bindValue(value, 'set') @@ -903,9 +917,11 @@ function mt:doFunction(action) self:instantSource(source) if source.type == 'index' then local index = self:getIndex(source[1]) + parent:addInfo('set child', source[1], index) parent:setChild(index, value) elseif source.type == 'name' then local index = source[1] + parent:addInfo('set child', source, index) parent:setChild(index, value) end source:bindValue(value, 'set') diff --git a/server/test/main.lua b/server/test/main.lua index 7a498929..e6fdcbe2 100644 --- a/server/test/main.lua +++ b/server/test/main.lua @@ -27,6 +27,7 @@ local function main() test 'core' test 'full' test 'definition' + test 'rename' test 'diagnostics' test 'type_inference' test 'find_lib' diff --git a/server/test/rename/init.lua b/server/test/rename/init.lua new file mode 100644 index 00000000..09900e3d --- /dev/null +++ b/server/test/rename/init.lua @@ -0,0 +1,85 @@ +local core = require 'core' +local parser = require 'parser' +local buildVM = require 'vm' + +local function catch_target(script) + local list = {} + local cur = 1 + while true do + local start, finish = script:find('<[!?].-[!?]>', cur) + if not start then + break + end + list[#list+1] = { start + 2, finish - 2 } + cur = finish + 1 + end + return list +end + +local function founded(targets, results) + if #targets ~= #results then + return false + end + for _, target in ipairs(targets) do + for _, result in ipairs(results) do + if target[1] == result[1] and target[2] == result[2] then + goto NEXT + end + end + do return false end + ::NEXT:: + end + return true +end + +function TEST(newName) + return function (script) + local target = catch_target(script) + local start = script:find('<?', 1, true) + local finish = script:find('?>', 1, true) + local pos = (start + finish) // 2 + 1 + local new_script = script:gsub('<[!?]', ' '):gsub('[!?]>', ' ') + local ast = parser:ast(new_script) + assert(ast) + local vm = buildVM(ast) + assert(vm) + + local positions = core.rename(vm, pos, newName) + if positions then + assert(founded(target, positions)) + else + assert(#target == 0) + end + end +end + +TEST 'b' [[ +local <?a?> = 1 +]] + +TEST 'b' [[ +local <?a?> = 1 +<!a!> = 2 +<!a!> = <!a!> +]] + +TEST 'b' [[ +t.<?a?> = 1 +a = t.<!a!> +]] + +TEST 'b' [[ +t[<!'a'!>] = 1 +a = t.<?a?> +]] + +TEST 'b' [[ +:: <?a?> :: +goto <!a!> +]] + +TEST 'b' [[ +local function f(<!a!>) + return <?a?> +end +]] |