diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2019-11-22 23:26:32 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2019-11-22 23:26:32 +0800 |
commit | d0ff66c9abe9d6abbca12fd811e0c3cb69c1033a (patch) | |
tree | bb34518d70b85de7656dbdbe958dfa221a3ff3b3 /script-beta/src/core | |
parent | 0a2c2ad15e1ec359171fb0dd4c72e57c5b66e9ba (diff) | |
download | lua-language-server-d0ff66c9abe9d6abbca12fd811e0c3cb69c1033a.zip |
整理一下目录结构
Diffstat (limited to 'script-beta/src/core')
29 files changed, 1961 insertions, 0 deletions
diff --git a/script-beta/src/core/definition.lua b/script-beta/src/core/definition.lua new file mode 100644 index 00000000..865fc7cb --- /dev/null +++ b/script-beta/src/core/definition.lua @@ -0,0 +1,105 @@ +local guide = require 'parser.guide' +local workspace = require 'workspace' +local files = require 'files' +local vm = require 'vm' + +local function findDef(source, callback) + if source.type ~= 'local' + and source.type ~= 'getlocal' + and source.type ~= 'setlocal' + and source.type ~= 'setglobal' + and source.type ~= 'getglobal' + and source.type ~= 'field' + and source.type ~= 'method' + and source.type ~= 'string' + and source.type ~= 'number' + and source.type ~= 'boolean' + and source.type ~= 'goto' then + return + end + vm.eachDef(source, function (info) + if info.mode == 'declare' + or info.mode == 'set' + or info.mode == 'return' then + local src = info.source + local root = guide.getRoot(src) + local uri = root.uri + if src.type == 'setfield' + or src.type == 'getfield' + or src.type == 'tablefield' then + callback(src.field, uri) + elseif src.type == 'setindex' + or src.type == 'getindex' + or src.type == 'tableindex' then + callback(src.index, uri) + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + callback(src.method, uri) + else + callback(src, uri) + end + end + end) +end + +local function checkRequire(source, offset, callback) + if source.type ~= 'call' then + return + end + local func = source.node + local pathSource = source.args and source.args[1] + if not pathSource then + return + end + if not guide.isContain(pathSource, offset) then + return + end + local literal = guide.getLiteral(pathSource) + if type(literal) ~= 'string' then + return + end + local name = func.special + if name == 'require' then + local result = workspace.findUrisByRequirePath(literal, true) + for _, uri in ipairs(result) do + callback(uri) + end + elseif name == 'dofile' + or name == 'loadfile' then + local result = workspace.findUrisByFilePath(literal, true) + for _, uri in ipairs(result) do + callback(uri) + end + end +end + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local results = {} + guide.eachSourceContain(ast.ast, offset, function (source) + checkRequire(source, offset, function (uri) + results[#results+1] = { + uri = files.getOriginUri(uri), + source = source, + target = { + start = 0, + finish = 0, + } + } + end) + findDef(source, function (target, uri) + results[#results+1] = { + target = target, + uri = files.getOriginUri(uri), + source = source, + } + end) + end) + if #results == 0 then + return nil + end + return results +end diff --git a/script-beta/src/core/diagnostics/ambiguity-1.lua b/script-beta/src/core/diagnostics/ambiguity-1.lua new file mode 100644 index 00000000..37815fb5 --- /dev/null +++ b/script-beta/src/core/diagnostics/ambiguity-1.lua @@ -0,0 +1,69 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +local opMap = { + ['+'] = true, + ['-'] = true, + ['*'] = true, + ['/'] = true, + ['//'] = true, + ['^'] = true, + ['<<'] = true, + ['>>'] = true, + ['&'] = true, + ['|'] = true, + ['~'] = true, + ['..'] = true, +} + +local literalMap = { + ['number'] = true, + ['boolean'] = true, + ['string'] = true, + ['table'] = true, +} + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local text = files.getText(uri) + guide.eachSourceType(ast.ast, 'binary', function (source) + if source.op.type ~= 'or' then + return + end + local first = source[1] + local second = source[2] + -- a + (b or 0) --> (a + b) or 0 + do + if opMap[first.op and first.op.type] + and first.type ~= 'unary' + and not second.op + and literalMap[second.type] + and not literalMap[first[2].type] + then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_AMBIGUITY_1', text:sub(first.start, first.finish)) + } + end + end + -- (a or 0) + c --> a or (0 + c) + do + if opMap[second.op and second.op.type] + and second.type ~= 'unary' + and not first.op + and literalMap[second[1].type] + then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_AMBIGUITY_1', text:sub(second.start, second.finish)) + } + end + end + end) +end diff --git a/script-beta/src/core/diagnostics/duplicate-index.lua b/script-beta/src/core/diagnostics/duplicate-index.lua new file mode 100644 index 00000000..76b1c958 --- /dev/null +++ b/script-beta/src/core/diagnostics/duplicate-index.lua @@ -0,0 +1,62 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'table', function (source) + local mark = {} + for _, obj in ipairs(source) do + if obj.type == 'tablefield' + or obj.type == 'tableindex' then + local name = guide.getKeyName(obj) + if name then + if not mark[name] then + mark[name] = {} + end + mark[name][#mark[name]+1] = obj.field or obj.index + end + end + end + + for name, defs in pairs(mark) do + local sname = name:match '^.|(.+)$' + if #defs > 1 and sname then + local related = {} + for i = 1, #defs do + local def = defs[i] + related[i] = { + start = def.start, + finish = def.finish, + uri = uri, + } + end + for i = 1, #defs - 1 do + local def = defs[i] + callback { + start = def.start, + finish = def.finish, + related = related, + message = lang.script('DIAG_DUPLICATE_INDEX', sname), + level = define.DiagnosticSeverity.Hint, + tags = { define.DiagnosticTag.Unnecessary }, + } + end + for i = #defs, #defs do + local def = defs[i] + callback { + start = def.start, + finish = def.finish, + related = related, + message = lang.script('DIAG_DUPLICATE_INDEX', sname), + } + end + end + end + end) +end diff --git a/script-beta/src/core/diagnostics/emmy-lua.lua b/script-beta/src/core/diagnostics/emmy-lua.lua new file mode 100644 index 00000000..b3d19c21 --- /dev/null +++ b/script-beta/src/core/diagnostics/emmy-lua.lua @@ -0,0 +1,3 @@ +return function () + +end diff --git a/script-beta/src/core/diagnostics/empty-block.lua b/script-beta/src/core/diagnostics/empty-block.lua new file mode 100644 index 00000000..2024f4e3 --- /dev/null +++ b/script-beta/src/core/diagnostics/empty-block.lua @@ -0,0 +1,49 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' + +-- 检查空代码块 +-- 但是排除忙等待(repeat/while) +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'if', function (source) + for _, block in ipairs(source) do + if #block > 0 then + return + end + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) + guide.eachSourceType(ast.ast, 'loop', function (source) + if #source > 0 then + return + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) + guide.eachSourceType(ast.ast, 'in', function (source) + if #source > 0 then + return + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) +end diff --git a/script-beta/src/core/diagnostics/global-in-nil-env.lua b/script-beta/src/core/diagnostics/global-in-nil-env.lua new file mode 100644 index 00000000..9a0d4f35 --- /dev/null +++ b/script-beta/src/core/diagnostics/global-in-nil-env.lua @@ -0,0 +1,66 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +-- TODO: 检查路径是否可达 +local function mayRun(path) + return true +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local root = guide.getRoot(ast.ast) + local env = guide.getENV(root) + + local nilDefs = {} + if not env.ref then + return + end + for _, ref in ipairs(env.ref) do + if ref.type == 'setlocal' then + if ref.value and ref.value.type == 'nil' then + nilDefs[#nilDefs+1] = ref + end + end + end + + if #nilDefs == 0 then + return + end + + local function check(source) + local node = source.node + if node.tag == '_ENV' then + local ok + for _, nilDef in ipairs(nilDefs) do + local mode, pathA = guide.getPath(nilDef, source) + if mode == 'before' + and mayRun(pathA) then + ok = nilDef + break + end + end + if ok then + callback { + start = source.start, + finish = source.finish, + uri = uri, + message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, + related = { + { + start = ok.start, + finish = ok.finish, + uri = uri, + } + } + } + end + end + end + + guide.eachSourceType(ast.ast, 'getglobal', check) + guide.eachSourceType(ast.ast, 'setglobal', check) +end diff --git a/script-beta/src/core/diagnostics/init.lua b/script-beta/src/core/diagnostics/init.lua new file mode 100644 index 00000000..0d523f26 --- /dev/null +++ b/script-beta/src/core/diagnostics/init.lua @@ -0,0 +1,41 @@ +local files = require 'files' +local define = require 'proto.define' +local config = require 'config' +local await = require 'await' + +local function check(uri, name, level, results) + if config.config.diagnostics.disable[name] then + return + end + level = config.config.diagnostics.severity[name] or level + local severity = define.DiagnosticSeverity[level] + local clock = os.clock() + require('core.diagnostics.' .. name)(uri, function (result) + result.level = severity or result.level + result.code = name + results[#results+1] = result + end, name) + local passed = os.clock() - clock + if passed >= 0.5 then + log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed)) + await.delay() + end +end + +return function (uri) + local ast = files.getAst(uri) + if not ast then + return nil + end + local results = {} + + for name, level in pairs(define.DiagnosticDefaultSeverity) do + check(uri, name, level, results) + end + + if #results == 0 then + return nil + end + + return results +end diff --git a/script-beta/src/core/diagnostics/lowercase-global.lua b/script-beta/src/core/diagnostics/lowercase-global.lua new file mode 100644 index 00000000..bc48e1e6 --- /dev/null +++ b/script-beta/src/core/diagnostics/lowercase-global.lua @@ -0,0 +1,39 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local config = require 'config' +local library = require 'library' + +-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删) +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + local definedGlobal = {} + for name in pairs(config.config.diagnostics.globals) do + definedGlobal[name] = true + end + for name in pairs(library.global) do + definedGlobal[name] = true + end + + guide.eachSourceType(ast.ast, 'setglobal', function (source) + local name = guide.getName(source) + if definedGlobal[name] then + return + end + local first = name:match '%w' + if not first then + return + end + if first:match '%l' then + callback { + start = source.start, + finish = source.finish, + message = lang.script.DIAG_LOWERCASE_GLOBAL, + } + end + end) +end diff --git a/script-beta/src/core/diagnostics/newfield-call.lua b/script-beta/src/core/diagnostics/newfield-call.lua new file mode 100644 index 00000000..75681cbc --- /dev/null +++ b/script-beta/src/core/diagnostics/newfield-call.lua @@ -0,0 +1,37 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + local lines = files.getLines(uri) + local text = files.getText(uri) + + guide.eachSourceType(ast.ast, 'table', function (source) + for i = 1, #source do + local field = source[i] + if field.type == 'call' then + local func = field.node + local args = field.args + if args then + local funcLine = guide.positionOf(lines, func.finish) + local argsLine = guide.positionOf(lines, args.start) + if argsLine > funcLine then + callback { + start = field.start, + finish = field.finish, + message = lang.script('DIAG_PREFIELD_CALL' + , text:sub(func.start, func.finish) + , text:sub(args.start, args.finish) + ) + } + end + end + end + end + end) +end diff --git a/script-beta/src/core/diagnostics/newline-call.lua b/script-beta/src/core/diagnostics/newline-call.lua new file mode 100644 index 00000000..cb318380 --- /dev/null +++ b/script-beta/src/core/diagnostics/newline-call.lua @@ -0,0 +1,38 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local lines = files.getLines(uri) + + guide.eachSourceType(ast.ast, 'call', function (source) + local node = source.node + local args = source.args + if not args then + return + end + + -- 必须有其他人在继续使用当前对象 + if not source.next then + return + end + + local nodeRow = guide.positionOf(lines, node.finish) + local argRow = guide.positionOf(lines, args.start) + if nodeRow == argRow then + return + end + + if #args == 1 then + callback { + start = args.start, + finish = args.finish, + message = lang.script.DIAG_PREVIOUS_CALL, + } + end + end) +end diff --git a/script-beta/src/core/diagnostics/redefined-local.lua b/script-beta/src/core/diagnostics/redefined-local.lua new file mode 100644 index 00000000..f6176794 --- /dev/null +++ b/script-beta/src/core/diagnostics/redefined-local.lua @@ -0,0 +1,32 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'local', function (source) + local name = source[1] + if name == '_' + or name == '_ENV' then + return + end + local exist = guide.getLocal(source, name, source.start-1) + if exist then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_REDEFINED_LOCAL', name), + related = { + { + start = exist.start, + finish = exist.finish, + uri = uri, + } + }, + } + end + end) +end diff --git a/script-beta/src/core/diagnostics/redundant-parameter.lua b/script-beta/src/core/diagnostics/redundant-parameter.lua new file mode 100644 index 00000000..ec14188e --- /dev/null +++ b/script-beta/src/core/diagnostics/redundant-parameter.lua @@ -0,0 +1,102 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local define = require 'proto.define' +local await = require 'await' + +local function countLibraryArgs(source) + local func = vm.getLibrary(source) + if not func then + return nil + end + local result = 0 + if not func.args then + return result + end + if func.args[#func.args].type == '...' then + return math.maxinteger + end + result = result + #func.args + return result +end + +local function countCallArgs(source) + local result = 0 + if not source.args then + return 0 + end + if source.node and source.node.type == 'getmethod' then + result = result + 1 + end + result = result + #source.args + return result +end + +local function countFuncArgs(source) + local result = 0 + if not source.args then + return result + end + if source.args[#source.args].type == '...' then + return math.maxinteger + end + if source.parent and source.parent.type == 'setmethod' then + result = result + 1 + end + result = result + #source.args + return result +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'call', function (source) + local callArgs = countCallArgs(source) + if callArgs == 0 then + return + end + + await.delay(function () + return files.globalVersion + end) + + local func = source.node + local funcArgs + vm.eachDef(func, function (info) + if info.mode == 'value' then + local src = info.source + if src.type == 'function' then + local args = countFuncArgs(src) + if not funcArgs or args > funcArgs then + funcArgs = args + end + end + end + end) + + funcArgs = funcArgs or countLibraryArgs(func) + if not funcArgs then + return + end + + local delta = callArgs - funcArgs + if delta <= 0 then + return + end + for i = #source.args - delta + 1, #source.args do + local arg = source.args[i] + if arg then + callback { + start = arg.start, + finish = arg.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs) + } + end + end + end) +end diff --git a/script-beta/src/core/diagnostics/redundant-value.lua b/script-beta/src/core/diagnostics/redundant-value.lua new file mode 100644 index 00000000..be483448 --- /dev/null +++ b/script-beta/src/core/diagnostics/redundant-value.lua @@ -0,0 +1,24 @@ +local files = require 'files' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback, code) + local ast = files.getAst(uri) + if not ast then + return + end + + local diags = ast.diags[code] + if not diags then + return + end + + for _, info in ipairs(diags) do + callback { + start = info.start, + finish = info.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_OVER_MAX_VALUES', info.max, info.passed) + } + end +end diff --git a/script-beta/src/core/diagnostics/trailing-space.lua b/script-beta/src/core/diagnostics/trailing-space.lua new file mode 100644 index 00000000..e54a6e60 --- /dev/null +++ b/script-beta/src/core/diagnostics/trailing-space.lua @@ -0,0 +1,55 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' + +local function isInString(ast, offset) + local result = false + guide.eachSourceType(ast, 'string', function (source) + if offset >= source.start and offset <= source.finish then + result = true + end + end) + return result +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local text = files.getText(uri) + local lines = files.getLines(uri) + for i = 1, #lines do + local start = lines[i].start + local range = lines[i].range + local lastChar = text:sub(range, range) + if lastChar ~= ' ' and lastChar ~= '\t' then + goto NEXT_LINE + end + if isInString(ast.ast, range) then + goto NEXT_LINE + end + local first = start + for n = range - 1, start, -1 do + local char = text:sub(n, n) + if char ~= ' ' and char ~= '\t' then + first = n + 1 + break + end + end + if first == start then + callback { + start = first, + finish = range, + message = lang.script.DIAG_LINE_ONLY_SPACE, + } + else + callback { + start = first, + finish = range, + message = lang.script.DIAG_LINE_POST_SPACE, + } + end + ::NEXT_LINE:: + end +end diff --git a/script-beta/src/core/diagnostics/undefined-env-child.lua b/script-beta/src/core/diagnostics/undefined-env-child.lua new file mode 100644 index 00000000..df096cb8 --- /dev/null +++ b/script-beta/src/core/diagnostics/undefined-env-child.lua @@ -0,0 +1,32 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + -- 再遍历一次 getglobal ,找出 _ENV 被重载的情况 + guide.eachSourceType(ast.ast, 'getglobal', function (source) + -- 单独验证自己是否在重载过的 _ENV 中有定义 + if source.node.tag == '_ENV' then + return + end + local setInENV = vm.eachRef(source, function (info) + if info.mode == 'set' then + return true + end + end) + if setInENV then + return + end + local key = source[1] + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_UNDEF_ENV_CHILD', key), + } + end) +end diff --git a/script-beta/src/core/diagnostics/undefined-global.lua b/script-beta/src/core/diagnostics/undefined-global.lua new file mode 100644 index 00000000..ed81ced3 --- /dev/null +++ b/script-beta/src/core/diagnostics/undefined-global.lua @@ -0,0 +1,63 @@ +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local library = require 'library' +local config = require 'config' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + local globalCache = {} + + -- 遍历全局变量,检查所有没有 mode['set'] 的全局变量 + local globals = vm.getGlobals(ast.ast) + for key, infos in pairs(globals) do + if infos.mode['set'] == true then + goto CONTINUE + end + if globalCache[key] then + goto CONTINUE + end + local skey = key and key:match '^s|(.+)$' + if not skey then + goto CONTINUE + end + if library.global[skey] then + goto CONTINUE + end + if config.config.diagnostics.globals[skey] then + goto CONTINUE + end + if globalCache[key] == nil then + local uris = files.findGlobals(key) + for i = 1, #uris do + local destAst = files.getAst(uris[i]) + local destGlobals = vm.getGlobals(destAst.ast) + if destGlobals[key] and destGlobals[key].mode['set'] then + globalCache[key] = true + goto CONTINUE + end + end + end + globalCache[key] = false + local message = lang.script('DIAG_UNDEF_GLOBAL', skey) + local otherVersion = library.other[skey] + local customVersion = library.custom[skey] + if otherVersion then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version)) + elseif customVersion then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customVersion, '/'))) + end + for _, info in ipairs(infos) do + callback { + start = info.source.start, + finish = info.source.finish, + message = message, + } + end + ::CONTINUE:: + end +end diff --git a/script-beta/src/core/diagnostics/unused-function.lua b/script-beta/src/core/diagnostics/unused-function.lua new file mode 100644 index 00000000..6c53cdf7 --- /dev/null +++ b/script-beta/src/core/diagnostics/unused-function.lua @@ -0,0 +1,45 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local define = require 'proto.define' +local lang = require 'language' +local await = require 'await' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + -- 只检查局部函数与全局函数 + guide.eachSourceType(ast.ast, 'function', function (source) + local parent = source.parent + if not parent then + return + end + if parent.type ~= 'local' + and parent.type ~= 'setlocal' + and parent.type ~= 'setglobal' then + return + end + local hasSet + local hasGet = vm.eachRef(source, function (info) + if info.mode == 'get' then + return true + elseif info.mode == 'set' + or info.mode == 'declare' then + hasSet = true + end + end) + if not hasGet and hasSet then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_FUNCTION, + } + end + await.delay(function () + return files.globalVersion + end) + end) +end diff --git a/script-beta/src/core/diagnostics/unused-label.lua b/script-beta/src/core/diagnostics/unused-label.lua new file mode 100644 index 00000000..e6d998ba --- /dev/null +++ b/script-beta/src/core/diagnostics/unused-label.lua @@ -0,0 +1,22 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'label', function (source) + if not source.ref then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LABEL', source[1]), + } + end + end) +end diff --git a/script-beta/src/core/diagnostics/unused-local.lua b/script-beta/src/core/diagnostics/unused-local.lua new file mode 100644 index 00000000..22b2e16b --- /dev/null +++ b/script-beta/src/core/diagnostics/unused-local.lua @@ -0,0 +1,46 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +local function hasGet(loc) + if not loc.ref then + return false + end + for _, ref in ipairs(loc.ref) do + if ref.type == 'getlocal' then + if not ref.next then + return true + end + local nextType = ref.next.type + if nextType ~= 'setmethod' + and nextType ~= 'setfield' + and nextType ~= 'setindex' then + return true + end + end + end + return false +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'local', function (source) + local name = source[1] + if name == '_' + or name == '_ENV' then + return + end + if not hasGet(source) then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LOCAL', name), + } + end + end) +end diff --git a/script-beta/src/core/diagnostics/unused-vararg.lua b/script-beta/src/core/diagnostics/unused-vararg.lua new file mode 100644 index 00000000..74cc08e7 --- /dev/null +++ b/script-beta/src/core/diagnostics/unused-vararg.lua @@ -0,0 +1,31 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'function', function (source) + local args = source.args + if not args then + return + end + + for _, arg in ipairs(args) do + if arg.type == '...' then + if not arg.ref then + callback { + start = arg.start, + finish = arg.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_VARARG, + } + end + end + end + end) +end diff --git a/script-beta/src/core/highlight.lua b/script-beta/src/core/highlight.lua new file mode 100644 index 00000000..61e3f91a --- /dev/null +++ b/script-beta/src/core/highlight.lua @@ -0,0 +1,230 @@ +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm' +local define = require 'proto.define' + +local function ofLocal(source, callback) + callback(source) + if source.ref then + for _, ref in ipairs(source.ref) do + callback(ref) + end + end +end + +local function ofField(source, uri, callback) + local parent = source.parent + if not parent then + return + end + local myKey = guide.getKeyName(source) + if parent.type == 'tableindex' + or parent.type == 'tablefield' then + local tbl = parent.parent + vm.eachField(tbl, function (info) + if info.key ~= myKey then + return + end + local destUri = guide.getRoot(info.source).uri + if destUri ~= uri then + return + end + callback(info.source) + end) + else + vm.eachField(parent.node, function (info) + if info.key ~= myKey then + return + end + local destUri = guide.getRoot(info.source).uri + if destUri ~= uri then + return + end + callback(info.source) + end) + end +end + +local function ofIndex(source, uri, callback) + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + ofField(source, uri, callback) + end +end + +local function ofLabel(source, callback) + vm.eachRef(source, function (info) + callback(info.source) + end) +end + +local function find(source, uri, callback) + if source.type == 'local' then + ofLocal(source, callback) + elseif source.type == 'getlocal' + or source.type == 'setlocal' then + ofLocal(source.node, callback) + elseif source.type == 'field' + or source.type == 'method' then + ofField(source, uri, callback) + elseif source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' then + ofIndex(source, uri, callback) + callback(source) + elseif source.type == 'nil' then + callback(source) + elseif source.type == 'goto' + or source.type == 'label' then + ofLabel(source, callback) + end +end + +local function checkInIf(source, text, offset) + -- 检查 end + local endA = source.finish - #'end' + 1 + local endB = source.finish + if offset >= endA + and offset <= endB + and text:sub(endA, endB) == 'end' then + return true + end + -- 检查每个子模块 + for _, block in ipairs(source) do + for i = 1, #block.keyword, 2 do + local start = block.keyword[i] + local finish = block.keyword[i+1] + if offset >= start and offset <= finish then + return true + end + end + end + return false +end + +local function makeIf(source, text, callback) + -- end + local endA = source.finish - #'end' + 1 + local endB = source.finish + if text:sub(endA, endB) == 'end' then + callback(endA, endB) + end + -- 每个子模块 + for _, block in ipairs(source) do + for i = 1, #block.keyword, 2 do + local start = block.keyword[i] + local finish = block.keyword[i+1] + callback(start, finish) + end + end + return false +end + +local function findKeyword(source, text, offset, callback) + if source.type == 'do' + or source.type == 'function' + or source.type == 'loop' + or source.type == 'in' + or source.type == 'while' + or source.type == 'repeat' then + local ok + for i = 1, #source.keyword, 2 do + local start = source.keyword[i] + local finish = source.keyword[i+1] + if offset >= start and offset <= finish then + ok = true + break + end + end + if ok then + for i = 1, #source.keyword, 2 do + local start = source.keyword[i] + local finish = source.keyword[i+1] + callback(start, finish) + end + end + elseif source.type == 'if' then + local ok = checkInIf(source, text, offset) + if ok then + makeIf(source, text, callback) + end + end +end + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local text = files.getText(uri) + local results = {} + local mark = {} + guide.eachSourceContain(ast.ast, offset, function (source) + find(source, uri, function (target) + local kind + if target.type == 'getfield' then + target = target.field + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setfield' + or target.type == 'tablefield' then + target = target.field + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getmethod' then + target = target.method + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setmethod' then + target = target.method + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getindex' then + target = target.index + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setindex' + or target.type == 'tableindex' then + target = target.index + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getlocal' + or target.type == 'getglobal' + or target.type == 'goto' then + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setlocal' + or target.type == 'local' + or target.type == 'setglobal' + or target.type == 'label' then + kind = define.DocumentHighlightKind.Write + elseif target.type == 'string' + or target.type == 'boolean' + or target.type == 'number' + or target.type == 'nil' then + kind = define.DocumentHighlightKind.Text + else + log.warn('Unknow target.type:', target.type) + return + end + if mark[target] then + return + end + mark[target] = true + results[#results+1] = { + start = target.start, + finish = target.finish, + kind = kind, + } + end) + findKeyword(source, text, offset, function (start, finish) + results[#results+1] = { + start = start, + finish = finish, + kind = define.DocumentHighlightKind.Write + } + end) + end) + if #results == 0 then + return nil + end + return results +end diff --git a/script-beta/src/core/hover/arg.lua b/script-beta/src/core/hover/arg.lua new file mode 100644 index 00000000..be344488 --- /dev/null +++ b/script-beta/src/core/hover/arg.lua @@ -0,0 +1,20 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local function asFunction(source) + if not source.args then + return '' + end + local args = {} + for i = 1, #source.args do + local arg = source.args[i] + args[i] = ('%s: %s'):format(guide.getName(arg), vm.getType(arg)) + end + return table.concat(args, ', ') +end + +return function (source) + if source.type == 'function' then + return asFunction(source) + end +end diff --git a/script-beta/src/core/hover/init.lua b/script-beta/src/core/hover/init.lua new file mode 100644 index 00000000..b99c14b2 --- /dev/null +++ b/script-beta/src/core/hover/init.lua @@ -0,0 +1,56 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local getLabel = require 'core.hover.label' + +local function getHoverAsFunction(source) + local values = vm.getValue(source) + local labels = {} + for _, value in ipairs(values) do + if value.type == 'function' then + labels[#labels+1] = getLabel(value.source) + end + end + + local label = table.concat(labels, '\n') + return { + label = label, + source = source, + } +end + +local function getHoverAsValue(source) + local label = getLabel(source) + return { + label = label, + source = source, + } +end + +local function getHover(source) + local isFunction = vm.hasType(source, 'function') + if isFunction then + return getHoverAsFunction(source) + else + return getHoverAsValue(source) + end +end + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local hover = guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'field' + or source.type == 'method' then + return getHover(source) + end + end) + return hover +end diff --git a/script-beta/src/core/hover/label.lua b/script-beta/src/core/hover/label.lua new file mode 100644 index 00000000..72ce60f4 --- /dev/null +++ b/script-beta/src/core/hover/label.lua @@ -0,0 +1,103 @@ +local buildName = require 'core.hover.name' +local buildArg = require 'core.hover.arg' +local buildReturn = require 'core.hover.return' +local buildTable = require 'core.hover.table' +local vm = require 'vm' +local util = require 'utility' + +local function asFunction(source) + local name = buildName(source) + local arg = buildArg(source) + local rtn = buildReturn(source) + local lines = {} + lines[1] = ('function %s(%s)'):format(name, arg) + lines[2] = rtn + return table.concat(lines, '\n') +end + +local function asLocal(source) + local name = buildName(source) + local type = vm.getType(source) + local literal = vm.getLiteral(source) + if type == 'table' then + type = buildTable(source) + end + if literal == nil then + return ('local %s: %s'):format(name, type) + else + return ('local %s: %s = %s'):format(name, type, util.viewLiteral(literal)) + end +end + +local function asGlobal(source) + local name = buildName(source) + local type = vm.getType(source) + local literal = vm.getLiteral(source) + if type == 'table' then + type = buildTable(source) + end + if literal == nil then + return ('global %s: %s'):format(name, type) + else + return ('global %s: %s = %s'):format(name, type, util.viewLiteral(literal)) + end +end + +local function isGlobalField(source) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + if source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' + or source.type == 'tablefield' then + local node = source.node + if node.type == 'setglobal' + or node.type == 'getglobal' then + return true + end + return isGlobalField(node) + else + return false + end +end + +local function asField(source) + if isGlobalField(source) then + return asGlobal(source) + end + local name = buildName(source) + local type = vm.getType(source) + local literal = vm.getLiteral(source) + if type == 'table' then + type = buildTable(source) + end + if literal == nil then + return ('field %s: %s'):format(name, type) + else + return ('field %s: %s = %s'):format(name, type, util.viewLiteral(literal)) + end +end + +return function (source) + if source.type == 'function' then + return asFunction(source) + elseif source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' then + return asLocal(source) + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + return asGlobal(source) + elseif source.type == 'getfield' + or source.type == 'setfield' + or source.type == 'getmethod' + or source.type == 'setmethod' + or source.type == 'tablefield' + or source.type == 'field' + or source.type == 'method' then + return asField(source) + end +end diff --git a/script-beta/src/core/hover/name.lua b/script-beta/src/core/hover/name.lua new file mode 100644 index 00000000..a22a8b5a --- /dev/null +++ b/script-beta/src/core/hover/name.lua @@ -0,0 +1,64 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local function asLocal(source) + return guide.getName(source) +end + +local function asMethod(source) + local class = vm.eachField(source.node, function (info) + if info.key == 's|type' or info.key == 's|__name' or info.key == 's|name' then + if info.value and info.value.type == 'string' then + return info.value[1] + end + end + end) + local node = class or guide.getName(source.node) or '?' + local method = guide.getName(source) + return ('%s:%s'):format(node, method) +end + +local function asField(source) + local class = vm.eachField(source.node, function (info) + if info.key == 's|type' or info.key == 's|__name' or info.key == 's|name' then + if info.value and info.value.type == 'string' then + return info.value[1] + end + end + end) + local node = class or guide.getName(source.node) or '?' + local method = guide.getName(source) + return ('%s.%s'):format(node, method) +end + +local function asGlobal(source) + return guide.getName(source) +end + +local function buildName(source) + if source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' then + return asLocal(source) or '' + end + if source.type == 'setglobal' + or source.type == 'getglobal' then + return asGlobal(source) or '' + end + if source.type == 'setmethod' + or source.type == 'getmethod' then + return asMethod(source) or '' + end + if source.type == 'setfield' + or source.tyoe == 'getfield' + or source.type == 'tablefield' then + return asField(source) or '' + end + local parent = source.parent + if parent then + return buildName(parent) + end + return '' +end + +return buildName diff --git a/script-beta/src/core/hover/return.lua b/script-beta/src/core/hover/return.lua new file mode 100644 index 00000000..c22626a6 --- /dev/null +++ b/script-beta/src/core/hover/return.lua @@ -0,0 +1,34 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local function asFunction(source) + if not source.returns then + return nil + end + local returns = {} + for _, rtn in ipairs(source.returns) do + for i = 1, #rtn do + local values = vm.getValue(rtn[i]) + returns[#returns+1] = values + end + break + end + if #returns == 0 then + return nil + end + local lines = {} + for i = 1, #returns do + if i == 1 then + lines[i] = (' -> %s'):format(vm.viewType(returns[i])) + else + lines[i] = ('% 3d. %s'):format(i, returns[i]) + end + end + return table.concat(lines, '\n') +end + +return function (source) + if source.type == 'function' then + return asFunction(source) + end +end diff --git a/script-beta/src/core/hover/table.lua b/script-beta/src/core/hover/table.lua new file mode 100644 index 00000000..9ed86692 --- /dev/null +++ b/script-beta/src/core/hover/table.lua @@ -0,0 +1,35 @@ +local vm = require 'vm' + +local function checkClass(source) +end + +return function (source) + local fields = {} + local class + vm.eachField(source, function (info) + if info.key == 's|type' or info.key == 's|__name' or info.key == 's|name' then + if info.value and info.value.type == 'string' then + class = info.value[1] + end + end + local type = vm.getType(info.source) + fields[#fields+1] = ('%s'):format(type) + end) + local fieldsBuf + if #fields == 0 then + fieldsBuf = '{}' + else + local lines = {} + lines[#lines+1] = '{' + for _, field in ipairs(fields) do + lines[#lines+1] = ' ' .. field + end + lines[#lines+1] = '}' + fieldsBuf = table.concat(lines, '\n') + end + if class then + return ('%s %s'):format(class, fieldsBuf) + else + return fieldsBuf + end +end diff --git a/script-beta/src/core/reference.lua b/script-beta/src/core/reference.lua new file mode 100644 index 00000000..7e265e97 --- /dev/null +++ b/script-beta/src/core/reference.lua @@ -0,0 +1,84 @@ +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm' + +local function isFunction(source, offset) + if source.type ~= 'function' then + return false + end + -- 必须点在 `function` 这个单词上才能查找函数引用 + return offset >= source.start and offset < source.start + #'function' +end + +local function findRef(source, offset, callback) + if source.type ~= 'local' + and source.type ~= 'getlocal' + and source.type ~= 'setlocal' + and source.type ~= 'setglobal' + and source.type ~= 'getglobal' + and source.type ~= 'field' + and source.type ~= 'tablefield' + and source.type ~= 'method' + and source.type ~= 'string' + and source.type ~= 'number' + and source.type ~= 'boolean' + and source.type ~= 'goto' + and source.type ~= 'label' + and not isFunction(source, offset) then + return + end + vm.eachRef(source, function (info) + if info.mode == 'declare' + or info.mode == 'set' + or info.mode == 'get' + or info.mode == 'return' then + local src = info.source + local root = guide.getRoot(src) + local uri = root.uri + if src.type == 'setfield' + or src.type == 'getfield' + or src.type == 'tablefield' then + callback(src.field, uri) + elseif src.type == 'setindex' + or src.type == 'getindex' + or src.type == 'tableindex' then + callback(src.index, uri) + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + callback(src.method, uri) + else + callback(src, uri) + end + end + if info.mode == 'value' then + local src = info.source + local root = guide.getRoot(src) + local uri = root.uri + if src.type == 'function' then + if src.parent.type == 'return' then + callback(src, uri) + end + end + end + end) +end + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local results = {} + guide.eachSourceContain(ast.ast, offset, function (source) + findRef(source, offset, function (target, uri) + results[#results+1] = { + target = target, + uri = files.getOriginUri(uri), + } + end) + end) + if #results == 0 then + return nil + end + return results +end diff --git a/script-beta/src/core/rename.lua b/script-beta/src/core/rename.lua new file mode 100644 index 00000000..3e4512da --- /dev/null +++ b/script-beta/src/core/rename.lua @@ -0,0 +1,374 @@ +local files = require 'files' +local vm = require 'vm' +local guide = require 'parser.guide' +local proto = require 'proto' +local define = require 'proto.define' +local util = require 'utility' + +local Forcing + +local function askForcing(str) + if TEST then + return true + end + if Forcing == false then + return false + end + local version = files.globalVersion + -- TODO + local item = proto.awaitRequest('window/showMessageRequest', { + type = define.MessageType.Warning, + message = ('[%s]不是有效的标识符,是否强制替换?'):format(str), + actions = { + { + title = '强制替换', + }, + { + title = '取消', + }, + } + }) + if version ~= files.globalVersion then + Forcing = false + proto.notify('window/showMessage', { + type = define.MessageType.Warning, + message = '文件发生了变化,替换取消。' + }) + return false + end + if not item then + Forcing = false + return false + end + if item.title == '强制替换' then + Forcing = true + return true + else + Forcing = false + return false + end +end + +local function askForMultiChange(results, newname) + if TEST then + return true + end + local uris = {} + for _, result in ipairs(results) do + local uri = result.uri + if not uris[uri] then + uris[uri] = 0 + uris[#uris+1] = uri + end + uris[uri] = uris[uri] + 1 + end + if #uris <= 1 then + return true + end + + local version = files.globalVersion + -- TODO + local item = proto.awaitRequest('window/showMessageRequest', { + type = define.MessageType.Warning, + message = ('将修改 %d 个文件,共 %d 处。'):format( + #uris, + #results + ), + actions = { + { + title = '继续', + }, + { + title = '放弃', + }, + } + }) + if version ~= files.globalVersion then + proto.notify('window/showMessage', { + type = define.MessageType.Warning, + message = '文件发生了变化,替换取消。' + }) + return false + end + if item and item.title == '继续' then + local fileList = {} + for _, uri in ipairs(uris) do + fileList[#fileList+1] = ('%s (%d)'):format(uri, uris[uri]) + end + + log.debug(('Renamed [%s]\r\n%s'):format(newname, table.concat(fileList, '\r\n'))) + return true + end + return false +end + +local function trim(str) + return str:match '^%s*(%S+)%s*$' +end + +local function isValidName(str) + return str:match '^[%a_][%w_]*$' +end + +local function isValidGlobal(str) + for s in str:gmatch '[^%.]*' do + if not isValidName(trim(s)) then + return false + end + end + return true +end + +local function isValidFunctionName(str) + if isValidGlobal(str) then + return true + end + local pos = str:find(':', 1, true) + if not pos then + return false + end + return isValidGlobal(trim(str:sub(1, pos-1))) + and isValidName(trim(str:sub(pos+1))) +end + +local function isFunctionGlobalName(source) + local parent = source.parent + if parent.type ~= 'setglobal' then + return false + end + local value = parent.value + if not value.type ~= 'function' then + return false + end + return value.start <= parent.start +end + +local function renameLocal(source, newname, callback) + if isValidName(newname) then + callback(source, source.start, source.finish, newname) + return + end + if askForcing(newname) then + callback(source, source.start, source.finish, newname) + end +end + +local function renameField(source, newname, callback) + if isValidName(newname) then + callback(source, source.start, source.finish, newname) + return true + end + local parent = source.parent + if parent.type == 'setfield' + or parent.type == 'getfield' then + local dot = parent.dot + local newstr = '[' .. util.viewString(newname) .. ']' + callback(source, dot.start, source.finish, newstr) + elseif parent.type == 'tablefield' then + local newstr = '[' .. util.viewString(newname) .. ']' + callback(source, source.start, source.finish, newstr) + elseif parent.type == 'getmethod' then + if not askForcing(newname) then + return false + end + callback(source, source.start, source.finish, newname) + elseif parent.type == 'setmethod' then + local uri = guide.getRoot(source).uri + local text = files.getText(uri) + local func = parent.value + -- function mt:name () end --> mt['newname'] = function (self) end + local newstr = string.format('%s[%s] = function ' + , text:sub(parent.start, parent.node.finish) + , util.viewString(newname) + ) + callback(source, func.start, parent.finish, newstr) + local pl = text:find('(', parent.finish, true) + if pl then + if func.args then + callback(source, pl + 1, pl, 'self, ') + else + callback(source, pl + 1, pl, 'self') + end + end + end + return true +end + +local function renameGlobal(source, newname, callback) + if isValidGlobal(newname) then + callback(source, source.start, source.finish, newname) + return true + end + if isValidFunctionName(newname) then + if not isFunctionGlobalName(source) then + askForcing(newname) + end + callback(source, source.start, source.finish, newname) + return true + end + local newstr = '_ENV[' .. util.viewString(newname) .. ']' + -- function name () end --> _ENV['newname'] = function () end + if source.value and source.value.type == 'function' + and source.value.start < source.start then + callback(source, source.value.start, source.finish, newstr .. ' = function ') + return true + end + callback(source, source.start, source.finish, newstr) + return true +end + +local function ofLocal(source, newname, callback) + renameLocal(source, newname, callback) + if source.ref then + for _, ref in ipairs(source.ref) do + renameLocal(ref, newname, callback) + end + end +end + +local function ofField(source, newname, callback) + return vm.eachRef(source, function (info) + local src = info.source + if src.type == 'tablefield' + or src.type == 'getfield' + or src.type == 'setfield' then + src = src.field + elseif src.type == 'tableindex' + or src.type == 'getindex' + or src.type == 'setindex' then + src = src.index + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + src = src.method + end + if src.type == 'string' then + local quo = src[2] + local text = util.viewString(newname, quo) + callback(src, src.start, src.finish, text) + return + elseif src.type == 'field' + or src.type == 'method' then + local suc = renameField(src, newname, callback) + if not suc then + return false + end + elseif src.type == 'setglobal' + or src.type == 'getglobal' then + local suc = renameGlobal(src, newname, callback) + if not suc then + return false + end + end + end) +end + +local function rename(source, newname, callback) + if source.type == 'label' + or source.type == 'goto' then + if not isValidName(newname) and not askForcing(newname)then + return false + end + vm.eachRef(source, function (info) + callback(info.source, info.source.start, info.source.finish, newname) + end) + elseif source.type == 'local' then + return ofLocal(source, newname, callback) + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return ofLocal(source.node, newname, callback) + elseif source.type == 'field' + or source.type == 'method' + or source.type == 'tablefield' + or source.type == 'string' + or source.type == 'setglobal' + or source.type == 'getglobal' then + return ofField(source, newname, callback) + end + return true +end + +local function prepareRename(source) + if source.type == 'label' + or source.type == 'goto' + or source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'field' + or source.type == 'method' + or source.type == 'tablefield' + or source.type == 'setglobal' + or source.type == 'getglobal' then + return source, source[1] + elseif source.type == 'string' then + local parent = source.parent + if not parent then + return nil + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + return source, source[1] + end + return nil + end + return nil +end + +local m = {} + +function m.rename(uri, pos, newname) + local ast = files.getAst(uri) + if not ast then + return nil + end + local results = {} + + guide.eachSourceContain(ast.ast, pos, function(source) + rename(source, newname, function (target, start, finish, text) + results[#results+1] = { + start = start, + finish = finish, + text = text, + uri = guide.getRoot(target).uri, + } + end) + end) + + if Forcing == false then + Forcing = nil + return nil + end + + if #results == 0 then + return nil + end + + if not askForMultiChange(results, newname) then + return nil + end + + return results +end + +function m.prepareRename(uri, pos) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local result + guide.eachSourceContain(ast.ast, pos, function(source) + local res, text = prepareRename(source) + if res then + result = { + start = source.start, + finish = source.finish, + text = text, + } + end + end) + + return result +end + +return m |