diff options
Diffstat (limited to 'script/core')
49 files changed, 6048 insertions, 0 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua new file mode 100644 index 00000000..69304f98 --- /dev/null +++ b/script/core/code-action.lua @@ -0,0 +1,269 @@ +local files = require 'files' +local lang = require 'language' +local define = require 'proto.define' +local guide = require 'parser.guide' +local util = require 'utility' +local sp = require 'bee.subprocess' + +local function disableDiagnostic(uri, code, results) + results[#results+1] = { + title = lang.script('ACTION_DISABLE_DIAG', code), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_DISABLE_DIAG, + command = 'lua.config', + arguments = { + { + key = 'Lua.diagnostics.disable', + action = 'add', + value = code, + uri = uri, + } + } + } + } +end + +local function markGlobal(uri, name, results) + results[#results+1] = { + title = lang.script('ACTION_MARK_GLOBAL', name), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_MARK_GLOBAL, + command = 'lua.config', + arguments = { + { + key = 'Lua.diagnostics.globals', + action = 'add', + value = name, + uri = uri, + } + } + } + } +end + +local function changeVersion(uri, version, results) + results[#results+1] = { + title = lang.script('ACTION_RUNTIME_VERSION', version), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_RUNTIME_VERSION, + command = 'lua.config', + arguments = { + { + key = 'Lua.runtime.version', + action = 'set', + value = version, + uri = uri, + } + } + }, + } +end + +local function solveUndefinedGlobal(uri, diag, results) + local ast = files.getAst(uri) + local text = files.getText(uri) + local lines = files.getLines(uri) + local offset = define.offsetOfWord(lines, text, diag.range.start) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type ~= 'getglobal' then + return + end + + local name = guide.getName(source) + markGlobal(uri, name, results) + + -- TODO check other version + end) +end + +local function solveLowercaseGlobal(uri, diag, results) + local ast = files.getAst(uri) + local text = files.getText(uri) + local lines = files.getLines(uri) + local offset = define.offsetOfWord(lines, text, diag.range.start) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type ~= 'setglobal' then + return + end + + local name = guide.getName(source) + markGlobal(uri, name, results) + end) +end + +local function findSyntax(uri, diag) + local ast = files.getAst(uri) + local text = files.getText(uri) + local lines = files.getLines(uri) + for _, err in ipairs(ast.errs) do + if err.type:lower():gsub('_', '-') == diag.code then + local range = define.range(lines, text, err.start, err.finish) + if util.equal(range, diag.range) then + return err + end + end + end + return nil +end + +local function solveSyntaxByChangeVersion(uri, err, results) + if type(err.version) == 'table' then + for _, version in ipairs(err.version) do + changeVersion(uri, version, results) + end + else + changeVersion(uri, err.version, results) + end +end + +local function solveSyntaxByAddDoEnd(uri, err, results) + local text = files.getText(uri) + local lines = files.getLines(uri) + results[#results+1] = { + title = lang.script.ACTION_ADD_DO_END, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + range = define.range(lines, text, err.start, err.finish), + newText = ('do %s end'):format(text:sub(err.start, err.finish)), + }, + } + } + } + } +end + +local function solveSyntaxByFix(uri, err, results) + local text = files.getText(uri) + local lines = files.getLines(uri) + local changes = {} + for _, fix in ipairs(err.fix) do + changes[#changes+1] = { + range = define.range(lines, text, fix.start, fix.finish), + newText = fix.text, + } + end + results[#results+1] = { + title = lang.script['ACTION_' .. err.fix.title], + kind = 'quickfix', + edit = { + changes = { + [uri] = changes, + } + } + } +end + +local function solveSyntax(uri, diag, results) + local err = findSyntax(uri, diag) + if not err then + return + end + if err.version then + solveSyntaxByChangeVersion(uri, err, results) + end + if err.type == 'ACTION_AFTER_BREAK' or err.type == 'ACTION_AFTER_RETURN' then + solveSyntaxByAddDoEnd(uri, err, results) + end + if err.fix then + solveSyntaxByFix(uri, err, results) + end +end + +local function solveNewlineCall(uri, diag, results) + local text = files.getText(uri) + local lines = files.getLines(uri) + results[#results+1] = { + title = lang.script.ACTION_ADD_SEMICOLON, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + range = { + start = diag.range.start, + ['end'] = diag.range.start, + }, + newText = ';', + } + } + } + } + } +end + +local function solveAmbiguity1(uri, diag, results) + results[#results+1] = { + title = lang.script.ACTION_ADD_BRACKETS, + kind = 'quickfix', + command = { + title = lang.script.COMMAND_ADD_BRACKETS, + command = 'lua.solve:' .. sp:get_id(), + arguments = { + { + name = 'ambiguity-1', + uri = uri, + range = diag.range, + } + } + }, + } +end + +local function solveTrailingSpace(uri, diag, results) + results[#results+1] = { + title = lang.script.ACTION_REMOVE_SPACE, + kind = 'quickfix', + command = { + title = lang.script.COMMAND_REMOVE_SPACE, + command = 'lua.removeSpace:' .. sp:get_id(), + arguments = { + { + uri = uri, + } + } + }, + } +end + +local function solveDiagnostic(uri, diag, results) + if diag.source == lang.script.DIAG_SYNTAX_CHECK then + solveSyntax(uri, diag, results) + return + end + if not diag.code then + return + end + if diag.code == 'undefined-global' then + solveUndefinedGlobal(uri, diag, results) + elseif diag.code == 'lowercase-global' then + solveLowercaseGlobal(uri, diag, results) + elseif diag.code == 'newline-call' then + solveNewlineCall(uri, diag, results) + elseif diag.code == 'ambiguity-1' then + solveAmbiguity1(uri, diag, results) + elseif diag.code == 'trailing-space' then + solveTrailingSpace(uri, diag, results) + end + disableDiagnostic(uri, diag.code, results) +end + +return function (uri, range, diagnostics) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local results = {} + + for _, diag in ipairs(diagnostics) do + solveDiagnostic(uri, diag, results) + end + + return results +end diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua new file mode 100644 index 00000000..e8b09932 --- /dev/null +++ b/script/core/command/removeSpace.lua @@ -0,0 +1,56 @@ +local files = require 'files' +local define = require 'proto.define' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' + +local function isInString(ast, offset) + return guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'string' then + return true + end + end) or false +end + +return function (data) + local uri = data.uri + local lines = files.getLines(uri) + local text = files.getText(uri) + local ast = files.getAst(uri) + if not lines then + return + end + + local textEdit = {} + for i = 1, #lines do + local line = guide.lineContent(lines, text, i, true) + local pos = line:find '[ \t]+$' + if pos then + local start, finish = guide.lineRange(lines, i, true) + start = start + pos - 1 + if isInString(ast, start) then + goto NEXT_LINE + end + textEdit[#textEdit+1] = { + range = define.range(lines, text, start, finish), + newText = '', + } + goto NEXT_LINE + end + + ::NEXT_LINE:: + end + + if #textEdit == 0 then + return + end + + proto.awaitRequest('workspace/applyEdit', { + label = lang.script.COMMAND_REMOVE_SPACE, + edit = { + changes = { + [uri] = textEdit, + } + }, + }) +end diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua new file mode 100644 index 00000000..d3b8f94e --- /dev/null +++ b/script/core/command/solve.lua @@ -0,0 +1,96 @@ +local files = require 'files' +local define = require 'proto.define' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' + +local opMap = { + ['+'] = true, + ['-'] = true, + ['*'] = true, + ['/'] = true, + ['//'] = true, + ['^'] = true, + ['<<'] = true, + ['>>'] = true, + ['&'] = true, + ['|'] = true, + ['~'] = true, + ['..'] = true, +} + +local literalMap = { + ['number'] = true, + ['boolean'] = true, + ['string'] = true, + ['table'] = true, +} + +return function (data) + local uri = data.uri + local lines = files.getLines(uri) + local text = files.getText(uri) + local ast = files.getAst(uri) + if not ast then + return + end + + local start = define.offsetOfWord(lines, text, data.range.start) + local finish = define.offsetOfWord(lines, text, data.range['end']) + + local result = guide.eachSourceContain(ast.ast, start, function (source) + if source.start ~= start + or source.finish ~= finish then + return + end + if not source.op or source.op.type ~= 'or' then + return + end + local first = source[1] + local second = source[2] + -- a + b or 0 --> a + (b or 0) + do + if first.op + and opMap[first.op.type] + and first.type ~= 'unary' + and not second.op + and literalMap[second.type] then + return { + start = source[1][2].start, + finish = source[2].finish, + } + end + end + -- a or b + c --> (a or b) + c + do + if second.op + and opMap[second.op.type] + and second.type ~= 'unary' + and not first.op + and literalMap[second[1].type] then + return { + start = source[1].start, + finish = source[2][1].finish, + } + end + end + end) + + if not result then + return + end + + proto.awaitRequest('workspace/applyEdit', { + label = lang.script.COMMAND_REMOVE_SPACE, + edit = { + changes = { + [uri] = { + { + range = define.range(lines, text, result.start, result.finish), + newText = ('(%s)'):format(text:sub(result.start, result.finish)), + } + }, + } + }, + }) +end diff --git a/script/core/completion.lua b/script/core/completion.lua new file mode 100644 index 00000000..44874b39 --- /dev/null +++ b/script/core/completion.lua @@ -0,0 +1,1284 @@ +local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local matchKey = require 'core.matchkey' +local vm = require 'vm' +local getLabel = require 'core.hover.label' +local getName = require 'core.hover.name' +local getArg = require 'core.hover.arg' +local getDesc = require 'core.hover.description' +local getHover = require 'core.hover' +local config = require 'config' +local util = require 'utility' +local markdown = require 'provider.markdown' +local findSource = require 'core.find-source' +local await = require 'await' +local parser = require 'parser' +local keyWordMap = require 'core.keyword' +local workspace = require 'workspace' +local furi = require 'file-uri' +local rpath = require 'workspace.require-path' +local lang = require 'language' + +local stackID = 0 +local stacks = {} +local function stack(callback) + stackID = stackID + 1 + stacks[stackID] = callback + return stackID +end + +local function clearStack() + stacks = {} +end + +local function resolveStack(id) + local callback = stacks[id] + if not callback then + return nil + end + + -- 当进行新的 resolve 时,放弃当前的 resolve + await.close('completion.resove') + return await.await(callback, 'completion.resove') +end + +local function trim(str) + return str:match '^%s*(%S+)%s*$' +end + +local function isSpace(char) + if char == ' ' + or char == '\n' + or char == '\r' + or char == '\t' then + return true + end + return false +end + +local function skipSpace(text, offset) + for i = offset, 1, -1 do + local char = text:sub(i, i) + if not isSpace(char) then + return i + end + end + return 0 +end + +local function findWord(text, offset) + for i = offset, 1, -1 do + if not text:sub(i, i):match '[%w_]' then + if i == offset then + return nil + end + return text:sub(i+1, offset), i+1 + end + end + return text:sub(1, offset), 1 +end + +local function findSymbol(text, offset) + for i = offset, 1, -1 do + local char = text:sub(i, i) + if isSpace(char) then + goto CONTINUE + end + if char == '.' + or char == ':' + or char == '(' then + return char, i + else + return nil + end + ::CONTINUE:: + end + return nil +end + +local function findAnyPos(text, offset) + for i = offset, 1, -1 do + if not isSpace(text:sub(i, i)) then + return i + end + end + return nil +end + +local function findParent(ast, text, offset) + for i = offset, 1, -1 do + local char = text:sub(i, i) + if isSpace(char) then + goto CONTINUE + end + local oop + if char == '.' then + -- `..` 的情况 + if text:sub(i-1, i-1) == '.' then + return nil, nil + end + oop = false + elseif char == ':' then + oop = true + else + return nil, nil + end + local anyPos = findAnyPos(text, i-1) + if not anyPos then + return nil, nil + end + local parent = guide.eachSourceContain(ast.ast, anyPos, function (source) + if source.finish == anyPos then + return source + end + end) + if parent then + return parent, oop + end + ::CONTINUE:: + end + return nil, nil +end + +local function findParentInStringIndex(ast, text, offset) + local near, nearStart + guide.eachSourceContain(ast.ast, offset, function (source) + local start = guide.getStartFinish(source) + if not start then + return + end + if not nearStart or nearStart < start then + near = source + nearStart = start + end + end) + if not near or near.type ~= 'string' then + return + end + local parent = near.parent + if not parent or parent.index ~= near then + return + end + -- index不可能是oop模式 + return parent.node, false +end + +local function buildFunctionSnip(source, oop) + local name = getName(source):gsub('^.-[$.:]', '') + local defs = vm.getDefs(source, 'deep') + local args = '' + for _, def in ipairs(defs) do + local defArgs = getArg(def, oop) + if defArgs ~= '' then + args = defArgs + break + end + end + local id = 0 + args = args:gsub('[^,]+', function (arg) + id = id + 1 + return arg:gsub('^(%s*)(.+)', function (sp, word) + return ('%s${%d:%s}'):format(sp, id, word) + end) + end) + return ('%s(%s)'):format(name, args) +end + +local function buildDetail(source) + local types = vm.getInferType(source, 'deep') + local literals = vm.getInferLiteral(source, 'deep') + if literals then + return types .. ' = ' .. literals + else + return types + end +end + +local function getSnip(source) + local context = config.config.completion.displayContext + if context <= 0 then + return nil + end + local defs = vm.getRefs(source, 'deep') + for _, def in ipairs(defs) do + def = guide.getObjectValue(def) or def + if def ~= source and def.type == 'function' then + local uri = guide.getUri(def) + local text = files.getText(uri) + local lines = files.getLines(uri) + if not text then + goto CONTINUE + end + if vm.isMetaFile(uri) then + goto CONTINUE + end + local row = guide.positionOf(lines, def.start) + local firstRow = lines[row] + local lastRow = lines[math.min(row + context - 1, #lines)] + local snip = text:sub(firstRow.start, lastRow.finish) + return snip + end + ::CONTINUE:: + end +end + +local function buildDesc(source) + local hover = getHover.get(source) + local md = markdown() + md:add('lua', hover.label) + md:add('md', hover.description) + local snip = getSnip(source) + if snip then + md:add('md', '-------------') + md:add('lua', snip) + end + return md:string() +end + +local function buildFunction(results, source, oop, data) + local snipType = config.config.completion.callSnippet + if snipType == 'Disable' or snipType == 'Both' then + results[#results+1] = data + end + if snipType == 'Both' or snipType == 'Replace' then + local snipData = util.deepCopy(data) + snipData.kind = define.CompletionItemKind.Snippet + snipData.label = snipData.label .. '()' + snipData.insertText = buildFunctionSnip(source, oop) + snipData.insertTextFormat = 2 + snipData.id = stack(function () + return { + detail = buildDetail(source), + description = buildDesc(source), + } + end) + results[#results+1] = snipData + end +end + +local function isSameSource(ast, source, pos) + if not files.eq(guide.getUri(source), guide.getUri(ast.ast)) then + return false + end + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + return source.start <= pos and source.finish >= pos +end + +local function checkLocal(ast, word, offset, results) + local locals = guide.getVisibleLocals(ast.ast, offset) + for name, source in pairs(locals) do + if isSameSource(ast, source, offset) then + goto CONTINUE + end + if not matchKey(word, name) then + goto CONTINUE + end + if vm.hasType(source, 'function') then + buildFunction(results, source, false, { + label = name, + kind = define.CompletionItemKind.Function, + id = stack(function () + return { + detail = buildDetail(source), + description = buildDesc(source), + } + end), + }) + else + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Variable, + id = stack(function () + return { + detail = buildDetail(source), + description = buildDesc(source), + } + end), + } + end + ::CONTINUE:: + end +end + +local function checkFieldFromFieldToIndex(name, parent, word, start, offset) + if name:match '^[%a_][%w_]*$' then + return nil + end + local textEdit, additionalTextEdits + local uri = guide.getUri(parent) + local text = files.getText(uri) + local wordStart + if word == '' then + wordStart = text:match('()%S', start + 1) or (offset + 1) + else + wordStart = offset - #word + 1 + end + textEdit = { + start = wordStart, + finish = offset, + newText = ('[%q]'):format(name), + } + local nxt = parent.next + if nxt then + local dotStart + if nxt.type == 'setfield' + or nxt.type == 'getfield' + or nxt.type == 'tablefield' then + dotStart = nxt.dot.start + elseif nxt.type == 'setmethod' + or nxt.type == 'getmethod' then + dotStart = nxt.colon.start + end + if dotStart then + additionalTextEdits = { + { + start = dotStart, + finish = dotStart, + newText = '', + } + } + end + else + if config.config.runtime.version == 'Lua 5.1' + or config.config.runtime.version == 'LuaJIT' then + textEdit.newText = '_G' .. textEdit.newText + else + textEdit.newText = '_ENV' .. textEdit.newText + end + end + return textEdit, additionalTextEdits +end + +local function checkFieldThen(name, src, word, start, offset, parent, oop, results) + local value = guide.getObjectValue(src) or src + local kind = define.CompletionItemKind.Field + if value.type == 'function' then + if oop then + kind = define.CompletionItemKind.Method + else + kind = define.CompletionItemKind.Function + end + buildFunction(results, src, oop, { + label = name, + kind = kind, + deprecated = vm.isDeprecated(src) or nil, + id = stack(function () + return { + detail = buildDetail(src), + description = buildDesc(src), + } + end), + }) + return + end + if oop then + return + end + local literal = guide.getLiteral(value) + if literal ~= nil then + kind = define.CompletionItemKind.Enum + end + local textEdit, additionalTextEdits + if parent.next and parent.next.index then + local str = parent.next.index + textEdit = { + start = str.start + #str[2], + finish = offset, + newText = name, + } + else + textEdit, additionalTextEdits = checkFieldFromFieldToIndex(name, parent, word, start, offset) + end + results[#results+1] = { + label = name, + kind = kind, + textEdit = textEdit, + additionalTextEdits = additionalTextEdits, + id = stack(function () + return { + detail = buildDetail(src), + description = buildDesc(src), + } + end) + } +end + +local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, isGlobal) + local fields = {} + local count = 0 + for _, src in ipairs(refs) do + local key = vm.getKeyName(src) + if not key or key:sub(1, 1) ~= 's' then + goto CONTINUE + end + if isSameSource(ast, src, start) then + -- 由于fastGlobal的优化,全局变量只会找出一个值,有可能找出自己 + -- 所以遇到自己的时候重新找一下有没有其他定义 + if not isGlobal then + goto CONTINUE + end + if #vm.getGlobals(key) <= 1 then + goto CONTINUE + elseif not vm.isSet(src) then + src = vm.getGlobalSets(key)[1] or src + end + end + local name = key:sub(3) + if locals and locals[name] then + goto CONTINUE + end + if not matchKey(word, name, count >= 100) then + goto CONTINUE + end + local last = fields[name] + if not last then + fields[name] = src + count = count + 1 + goto CONTINUE + end + if src.type == 'tablefield' + or src.type == 'setfield' + or src.type == 'tableindex' + or src.type == 'setindex' + or src.type == 'setmethod' + or src.type == 'setglobal' then + fields[name] = src + goto CONTINUE + end + ::CONTINUE:: + end + for name, src in util.sortPairs(fields) do + checkFieldThen(name, src, word, start, offset, parent, oop, results) + end +end + +local function checkField(ast, word, start, offset, parent, oop, results) + local refs = vm.getFields(parent, 'deep') + checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) +end + +local function checkGlobal(ast, word, start, offset, parent, oop, results) + local locals = guide.getVisibleLocals(ast.ast, offset) + local refs = vm.getGlobalSets '*' + checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global') +end + +local function checkTableField(ast, word, start, results) + local source = guide.eachSourceContain(ast.ast, start, function (source) + if source.start == start + and source.parent + and source.parent.type == 'table' then + return source + end + end) + if not source then + return + end + local used = {} + guide.eachSourceType(ast.ast, 'tablefield', function (src) + if not src.field then + return + end + local key = src.field[1] + if not used[key] + and matchKey(word, key) + and src ~= source then + used[key] = true + results[#results+1] = { + label = key, + kind = define.CompletionItemKind.Property, + } + end + end) +end + +local function checkCommon(word, text, offset, results) + local used = {} + for _, result in ipairs(results) do + used[result.label] = true + end + for _, data in ipairs(keyWordMap) do + used[data[1]] = true + end + for str, pos in text:gmatch '([%a_][%w_]*)()' do + if not used[str] and pos - 1 ~= offset then + used[str] = true + if matchKey(word, str) then + results[#results+1] = { + label = str, + kind = define.CompletionItemKind.Text, + } + end + end + end +end + +local function isInString(ast, offset) + return guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'string' then + return true + end + end) +end + +local function checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results) + local snipType = config.config.completion.keywordSnippet + for _, data in ipairs(keyWordMap) do + local key = data[1] + local eq + if hasSpace then + eq = word == key + else + eq = matchKey(word, key) + end + if afterLocal and key ~= 'function' then + eq = false + end + if eq then + local replaced + local extra + if snipType == 'Both' or snipType == 'Replace' then + local func = data[2] + if func then + replaced = func(hasSpace, results) + extra = true + end + end + if snipType == 'Both' then + replaced = false + end + if not replaced then + if not hasSpace then + local item = { + label = key, + kind = define.CompletionItemKind.Keyword, + } + if extra then + table.insert(results, #results, item) + else + results[#results+1] = item + end + end + end + local checkStop = data[3] + if checkStop then + local stop = checkStop(ast, start) + if stop then + return true + end + end + end + end +end + +local function checkProvideLocal(ast, word, start, results) + local block + guide.eachSourceContain(ast.ast, start, function (source) + if source.type == 'function' + or source.type == 'main' then + block = source + end + end) + if not block then + return + end + local used = {} + guide.eachSourceType(block, 'getglobal', function (source) + if source.start > start + and not used[source[1]] + and matchKey(word, source[1]) then + used[source[1]] = true + results[#results+1] = { + label = source[1], + kind = define.CompletionItemKind.Variable, + } + end + end) + guide.eachSourceType(block, 'getlocal', function (source) + if source.start > start + and not used[source[1]] + and matchKey(word, source[1]) then + used[source[1]] = true + results[#results+1] = { + label = source[1], + kind = define.CompletionItemKind.Variable, + } + end + end) +end + +local function checkFunctionArgByDocParam(ast, word, start, results) + local func = guide.eachSourceContain(ast.ast, start, function (source) + if source.type == 'function' then + return source + end + end) + if not func then + return + end + local docs = func.bindDocs + if not docs then + return + end + local params = {} + for _, doc in ipairs(docs) do + if doc.type == 'doc.param' then + params[#params+1] = doc + end + end + local firstArg = func.args and func.args[1] + if not firstArg + or firstArg.start <= start and firstArg.finish >= start then + local firstParam = params[1] + if firstParam and matchKey(word, firstParam.param[1]) then + local label = {} + for _, param in ipairs(params) do + label[#label+1] = param.param[1] + end + results[#results+1] = { + label = table.concat(label, ', '), + kind = define.CompletionItemKind.Snippet, + } + end + end + for _, doc in ipairs(params) do + if matchKey(word, doc.param[1]) then + results[#results+1] = { + label = doc.param[1], + kind = define.CompletionItemKind.Interface, + } + end + end +end + +local function isAfterLocal(text, start) + local pos = skipSpace(text, start-1) + local word = findWord(text, pos) + return word == 'local' +end + +local function checkUri(ast, text, offset, results) + local collect = {} + local myUri = guide.getUri(ast.ast) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type ~= 'string' then + return + end + local callargs = source.parent + if not callargs or callargs.type ~= 'callargs' then + return + end + if callargs[1] ~= source then + return + end + local call = callargs.parent + local func = call.node + local literal = guide.getLiteral(source) + local libName = vm.getLibraryName(func) + if not libName then + return + end + if libName == 'require' then + for uri in files.eachFile() do + uri = files.getOriginUri(uri) + if files.eq(myUri, uri) then + goto CONTINUE + end + if vm.isMetaFile(uri) then + goto CONTINUE + end + local path = workspace.getRelativePath(uri) + local infos = rpath.getVisiblePath(path, config.config.runtime.path) + for _, info in ipairs(infos) do + if matchKey(literal, info.expect) then + if not collect[info.expect] then + collect[info.expect] = { + textEdit = { + start = source.start + #source[2], + finish = source.finish - #source[2], + } + } + end + collect[info.expect][#collect[info.expect]+1] = ([=[* [%s](%s) %s]=]):format( + path, + uri, + lang.script('HOVER_USE_LUA_PATH', info.searcher) + ) + end + end + ::CONTINUE:: + end + elseif libName == 'dofile' + or libName == 'loadfile' then + for uri in files.eachFile() do + uri = files.getOriginUri(uri) + if files.eq(myUri, uri) then + goto CONTINUE + end + if vm.isMetaFile(uri) then + goto CONTINUE + end + local path = workspace.getRelativePath(uri) + if matchKey(literal, path) then + if not collect[path] then + collect[path] = { + textEdit = { + start = source.start + #source[2], + finish = source.finish - #source[2], + } + } + end + collect[path][#collect[path]+1] = ([=[[%s](%s)]=]):format( + path, + uri + ) + end + ::CONTINUE:: + end + end + end) + for label, infos in util.sortPairs(collect) do + local mark = {} + local des = {} + for _, info in ipairs(infos) do + if not mark[info] then + mark[info] = true + des[#des+1] = info + end + end + results[#results+1] = { + label = label, + kind = define.CompletionItemKind.Reference, + description = table.concat(des, '\n'), + textEdit = infos.textEdit, + } + end +end + +local function checkLenPlusOne(ast, text, offset, results) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'getindex' + or source.type == 'setindex' then + local _, pos = text:find('%s*%[%s*%#', source.node.finish) + if not pos then + return + end + local nodeText = text:sub(source.node.start, source.node.finish) + local writingText = trim(text:sub(pos + 1, offset - 1)) or '' + if not matchKey(writingText, nodeText) then + return + end + if source.parent == guide.getParentBlock(source) then + -- state + local label = text:match('%#[ \t]*', pos) .. nodeText .. '+1' + local eq = text:find('^%s*%]?%s*%=', source.finish) + local newText = label .. ']' + if not eq then + newText = newText .. ' = ' + end + results[#results+1] = { + label = label, + kind = define.CompletionItemKind.Snippet, + textEdit = { + start = pos, + finish = source.finish, + newText = newText, + }, + } + else + -- exp + local label = text:match('%#[ \t]*', pos) .. nodeText + local newText = label .. ']' + results[#results+1] = { + label = label, + kind = define.CompletionItemKind.Snippet, + textEdit = { + start = pos, + finish = source.finish, + newText = newText, + }, + } + end + end + end) +end + +local function isFuncArg(ast, offset) + return guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'funcargs' then + return true + end + end) +end + +local function trySpecial(ast, text, offset, results) + if isInString(ast, offset) then + checkUri(ast, text, offset, results) + return + end + -- x[#x+1] + checkLenPlusOne(ast, text, offset, results) +end + +local function tryIndex(ast, text, offset, results) + local parent, oop = findParentInStringIndex(ast, text, offset) + if not parent then + return + end + local word = parent.next.index[1] + checkField(ast, word, offset, offset, parent, oop, results) +end + +local function tryWord(ast, text, offset, results) + local finish = skipSpace(text, offset) + local word, start = findWord(text, finish) + if not word then + return nil + end + local hasSpace = finish ~= offset + if isInString(ast, offset) then + else + local parent, oop = findParent(ast, text, start - 1) + if parent then + if not hasSpace then + checkField(ast, word, start, offset, parent, oop, results) + end + elseif isFuncArg(ast, offset) then + checkProvideLocal(ast, word, start, results) + checkFunctionArgByDocParam(ast, word, start, results) + else + local afterLocal = isAfterLocal(text, start) + local stop = checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results) + if stop then + return + end + if not hasSpace then + if afterLocal then + checkProvideLocal(ast, word, start, results) + else + checkLocal(ast, word, start, results) + checkTableField(ast, word, start, results) + local env = guide.getENV(ast.ast, start) + checkGlobal(ast, word, start, offset, env, false, results) + end + end + end + if not hasSpace then + checkCommon(word, text, offset, results) + end + end +end + +local function trySymbol(ast, text, offset, results) + local symbol, start = findSymbol(text, offset) + if not symbol then + return nil + end + if isInString(ast, offset) then + return nil + end + if symbol == '.' + or symbol == ':' then + local parent, oop = findParent(ast, text, start) + if parent then + checkField(ast, '', start, offset, parent, oop, results) + end + end + if symbol == '(' then + checkFunctionArgByDocParam(ast, '', start, results) + end +end + +local function getCallEnums(source, index) + if source.type == 'function' and source.bindDocs then + if not source.args then + return + end + local arg + if index <= #source.args then + arg = source.args[index] + else + local lastArg = source.args[#source.args] + if lastArg.type == '...' then + arg = lastArg + else + return + end + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.param' + and doc.param[1] == arg[1] then + local enums = {} + for _, enum in ipairs(vm.getDocEnums(doc.extends)) do + enums[#enums+1] = { + label = enum[1], + description = enum.comment, + kind = define.CompletionItemKind.EnumMember, + } + end + return enums + elseif doc.type == 'doc.vararg' + and arg.type == '...' then + local enums = {} + for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do + enums[#enums+1] = { + label = enum[1], + description = enum.comment, + kind = define.CompletionItemKind.EnumMember, + } + end + return enums + end + end + end +end + +local function tryLabelInString(label, arg) + if not arg or arg.type ~= 'string' then + return label + end + local str = parser:grammar(label, 'String') + if not str then + return label + end + if not matchKey(arg[1], str[1]) then + return nil + end + return util.viewString(str[1], arg[2]) +end + +local function mergeEnums(a, b, text, arg) + local mark = {} + for _, enum in ipairs(a) do + mark[enum.label] = true + end + for _, enum in ipairs(b) do + local label = tryLabelInString(enum.label, arg) + if label and not mark[label] then + mark[label] = true + local result = { + label = label, + kind = define.CompletionItemKind.EnumMember, + description = enum.description, + textEdit = arg and { + start = arg.start, + finish = arg.finish, + newText = label, + }, + } + a[#a+1] = result + end + end +end + +local function findCall(ast, text, offset) + local call + guide.eachSourceContain(ast.ast, offset, function (src) + if src.type == 'call' then + if not call or call.start < src.start then + call = src + end + end + end) + return call +end + +local function getCallArgInfo(call, text, offset) + if not call.args then + return 1, nil + end + for index, arg in ipairs(call.args) do + if arg.start <= offset and arg.finish >= offset then + return index, arg + end + end + return #call.args + 1, nil +end + +local function tryCallArg(ast, text, offset, results) + local call = findCall(ast, text, offset) + if not call then + return + end + local myResults = {} + local argIndex, arg = getCallArgInfo(call, text, offset) + if arg and arg.type == 'function' then + return + end + local defs = vm.getDefs(call.node, 'deep') + for _, def in ipairs(defs) do + def = guide.getObjectValue(def) or def + local enums = getCallEnums(def, argIndex) + if enums then + mergeEnums(myResults, enums, text, arg) + end + end + for _, enum in ipairs(myResults) do + results[#results+1] = enum + end +end + +local function getComment(ast, offset) + for _, comm in ipairs(ast.comms) do + if offset >= comm.start and offset <= comm.finish then + return comm + end + end + return nil +end + +local function tryLuaDocCate(line, results) + local word = line:sub(3) + for _, docType in ipairs { + 'class', + 'type', + 'alias', + 'param', + 'return', + 'field', + 'generic', + 'vararg', + 'overload', + 'deprecated', + 'meta', + 'version', + } do + if matchKey(word, docType) then + results[#results+1] = { + label = docType, + kind = define.CompletionItemKind.Event, + } + end + end +end + +local function getLuaDocByContain(ast, offset) + local result + local range = math.huge + guide.eachSourceContain(ast.ast.docs, offset, function (src) + if not src.start then + return + end + if range >= offset - src.start + and offset <= src.finish then + range = offset - src.start + result = src + end + end) + return result +end + +local function getLuaDocByErr(ast, text, start, offset) + local targetError + for _, err in ipairs(ast.errs) do + if err.finish <= offset + and err.start >= start then + if not text:sub(err.finish + 1, offset):find '%S' then + targetError = err + break + end + end + end + if not targetError then + return nil + end + local targetDoc + for i = #ast.ast.docs, 1, -1 do + local doc = ast.ast.docs[i] + if doc.finish <= targetError.start then + targetDoc = doc + break + end + end + return targetError, targetDoc +end + +local function tryLuaDocBySource(ast, offset, source, results) + if source.type == 'doc.extends.name' then + if source.parent.type == 'doc.class' then + for _, doc in ipairs(vm.getDocTypes '*') do + if doc.type == 'doc.class.name' + and doc.parent ~= source.parent + and matchKey(source[1], doc[1]) then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + end + elseif source.type == 'doc.type.name' then + for _, doc in ipairs(vm.getDocTypes '*') do + if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') + and doc.parent ~= source.parent + and matchKey(source[1], doc[1]) then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + elseif source.type == 'doc.param.name' then + local funcs = {} + guide.eachSourceBetween(ast.ast, offset, math.huge, function (src) + if src.type == 'function' and src.start > offset then + funcs[#funcs+1] = src + end + end) + table.sort(funcs, function (a, b) + return a.start < b.start + end) + local func = funcs[1] + if not func or not func.args then + return + end + for _, arg in ipairs(func.args) do + if arg[1] and matchKey(source[1], arg[1]) then + results[#results+1] = { + label = arg[1], + kind = define.CompletionItemKind.Interface, + } + end + end + end +end + +local function tryLuaDocByErr(ast, offset, err, docState, results) + if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then + for _, doc in ipairs(vm.getDocTypes '*') do + if doc.type == 'doc.class.name' + and doc.parent ~= docState then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + elseif err.type == 'LUADOC_MISS_TYPE_NAME' then + for _, doc in ipairs(vm.getDocTypes '*') do + if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + elseif err.type == 'LUADOC_MISS_PARAM_NAME' then + local funcs = {} + guide.eachSourceBetween(ast.ast, offset, math.huge, function (src) + if src.type == 'function' and src.start > offset then + funcs[#funcs+1] = src + end + end) + table.sort(funcs, function (a, b) + return a.start < b.start + end) + local func = funcs[1] + if not func or not func.args then + return + end + local label = {} + local insertText = {} + for i, arg in ipairs(func.args) do + if arg[1] then + label[#label+1] = arg[1] + if i == 1 then + insertText[i] = ('%s ${%d:any}'):format(arg[1], i) + else + insertText[i] = ('---@param %s ${%d:any}'):format(arg[1], i) + end + end + end + results[#results+1] = { + label = table.concat(label, ', '), + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = table.concat(insertText, '\n'), + } + for i, arg in ipairs(func.args) do + if arg[1] then + results[#results+1] = { + label = arg[1], + kind = define.CompletionItemKind.Interface, + } + end + end + end +end + +local function tryLuaDocFeatures(line, ast, comm, offset, results) +end + +local function tryLuaDoc(ast, text, offset, results) + local comm = getComment(ast, offset) + local line = text:sub(comm.start, offset) + if not line then + return + end + if line:sub(1, 2) ~= '-@' then + return + end + -- 尝试 ---@$ + local cate = line:match('%a*', 3) + if #cate + 2 >= #line then + tryLuaDocCate(line, results) + return + end + -- 尝试一些其他特征 + if tryLuaDocFeatures(line, ast, comm, offset, results) then + return + end + -- 根据输入中的source来补全 + local source = getLuaDocByContain(ast, offset) + if source then + tryLuaDocBySource(ast, offset, source, results) + return + end + -- 根据附近的错误消息来补全 + local err, doc = getLuaDocByErr(ast, text, comm.start, offset) + if err then + tryLuaDocByErr(ast, offset, err, doc, results) + return + end +end + +local function completion(uri, offset) + local ast = files.getAst(uri) + local text = files.getText(uri) + local results = {} + clearStack() + if ast then + if getComment(ast, offset) then + tryLuaDoc(ast, text, offset, results) + else + trySpecial(ast, text, offset, results) + tryWord(ast, text, offset, results) + tryIndex(ast, text, offset, results) + trySymbol(ast, text, offset, results) + tryCallArg(ast, text, offset, results) + end + else + local word = findWord(text, offset) + if word then + checkCommon(word, text, offset, results) + end + end + + if #results == 0 then + return nil + end + return results +end + +local function resolve(id) + return resolveStack(id) +end + +return { + completion = completion, + resolve = resolve, +} diff --git a/script/core/definition.lua b/script/core/definition.lua new file mode 100644 index 00000000..c143939d --- /dev/null +++ b/script/core/definition.lua @@ -0,0 +1,156 @@ +local guide = require 'parser.guide' +local workspace = require 'workspace' +local files = require 'files' +local vm = require 'vm' +local findSource = require 'core.find-source' + +local function sortResults(results) + -- 先按照顺序排序 + table.sort(results, function (a, b) + local u1 = guide.getUri(a.target) + local u2 = guide.getUri(b.target) + if u1 == u2 then + return a.target.start < b.target.start + else + return u1 < u2 + end + end) + -- 如果2个结果处于嵌套状态,则取范围小的那个 + local lf, lu + for i = #results, 1, -1 do + local res = results[i].target + local f = res.finish + local uri = guide.getUri(res) + if lf and f > lf and uri == lu then + table.remove(results, i) + else + lu = uri + lf = f + end + end +end + +local accept = { + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['label'] = true, + ['goto'] = true, + ['field'] = true, + ['method'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['string'] = true, + ['boolean'] = true, + ['number'] = true, + + ['doc.type.name'] = true, + ['doc.class.name'] = true, + ['doc.extends.name'] = true, + ['doc.alias.name'] = true, +} + +local function checkRequire(source, offset) + if source.type ~= 'string' then + return nil + end + local callargs = source.parent + if callargs.type ~= 'callargs' then + return + end + if callargs[1] ~= source then + return + end + local call = callargs.parent + local func = call.node + local literal = guide.getLiteral(source) + local libName = vm.getLibraryName(func) + if not libName then + return nil + end + if libName == 'require' then + return workspace.findUrisByRequirePath(literal) + elseif libName == 'dofile' + or libName == 'loadfile' then + return workspace.findUrisByFilePath(literal) + end + return nil +end + +local function convertIndex(source) + if not source then + return + end + if source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' then + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + return parent + end + end + return source +end + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local source = convertIndex(findSource(ast, offset, accept)) + if not source then + return nil + end + + local results = {} + local uris = checkRequire(source) + if uris then + for i, uri in ipairs(uris) do + results[#results+1] = { + uri = files.getOriginUri(uri), + source = source, + target = { + start = 0, + finish = 0, + uri = uri, + } + } + end + end + + for _, src in ipairs(vm.getDefs(source, 'deep')) do + local root = guide.getRoot(src) + if not root then + goto CONTINUE + end + src = src.field or src.method or src.index or src + if src.type == 'table' and src.parent.type ~= 'return' then + goto CONTINUE + end + if src.type == 'doc.class.name' + and source.type ~= 'doc.type.name' + and source.type ~= 'doc.extends.name' then + goto CONTINUE + end + results[#results+1] = { + target = src, + uri = files.getOriginUri(root.uri), + source = source, + } + ::CONTINUE:: + end + + if #results == 0 then + return nil + end + + sortResults(results) + + return results +end diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua new file mode 100644 index 00000000..37815fb5 --- /dev/null +++ b/script/core/diagnostics/ambiguity-1.lua @@ -0,0 +1,69 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +local opMap = { + ['+'] = true, + ['-'] = true, + ['*'] = true, + ['/'] = true, + ['//'] = true, + ['^'] = true, + ['<<'] = true, + ['>>'] = true, + ['&'] = true, + ['|'] = true, + ['~'] = true, + ['..'] = true, +} + +local literalMap = { + ['number'] = true, + ['boolean'] = true, + ['string'] = true, + ['table'] = true, +} + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local text = files.getText(uri) + guide.eachSourceType(ast.ast, 'binary', function (source) + if source.op.type ~= 'or' then + return + end + local first = source[1] + local second = source[2] + -- a + (b or 0) --> (a + b) or 0 + do + if opMap[first.op and first.op.type] + and first.type ~= 'unary' + and not second.op + and literalMap[second.type] + and not literalMap[first[2].type] + then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_AMBIGUITY_1', text:sub(first.start, first.finish)) + } + end + end + -- (a or 0) + c --> a or (0 + c) + do + if opMap[second.op and second.op.type] + and second.type ~= 'unary' + and not first.op + and literalMap[second[1].type] + then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_AMBIGUITY_1', text:sub(second.start, second.finish)) + } + end + end + end) +end diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua new file mode 100644 index 00000000..55179447 --- /dev/null +++ b/script/core/diagnostics/circle-doc-class.lua @@ -0,0 +1,54 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.class' then + if not doc.extends then + goto CONTINUE + end + local myName = guide.getName(doc) + local list = { doc } + local mark = {} + for i = 1, 999 do + local current = list[i] + if not current then + goto CONTINUE + end + if current.extends then + local newName = current.extends[1] + if newName == myName then + callback { + start = doc.start, + finish = doc.finish, + message = lang.script('DIAG_CIRCLE_DOC_CLASS', myName) + } + goto CONTINUE + end + if not mark[newName] then + mark[newName] = true + local docs = vm.getDocTypes(newName) + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' then + list[#list+1] = otherDoc.parent + end + end + end + end + end + ::CONTINUE:: + end + end +end diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua new file mode 100644 index 00000000..a2bac8a4 --- /dev/null +++ b/script/core/diagnostics/code-after-break.lua @@ -0,0 +1,34 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + local mark = {} + guide.eachSourceType(state.ast, 'break', function (source) + local list = source.parent + if mark[list] then + return + end + mark[list] = true + for i = #list, 1, -1 do + local src = list[i] + if src == source then + if i == #list then + return + end + callback { + start = list[i+1].start, + finish = list[#list].range or list[#list].finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_CODE_AFTER_BREAK, + } + end + end + end) +end diff --git a/script/core/diagnostics/doc-field-no-class.lua b/script/core/diagnostics/doc-field-no-class.lua new file mode 100644 index 00000000..f27bbb32 --- /dev/null +++ b/script/core/diagnostics/doc-field-no-class.lua @@ -0,0 +1,41 @@ +local files = require 'files' +local lang = require 'language' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type ~= 'doc.field' then + goto CONTINUE + end + local bindGroup = doc.bindGroup + if not bindGroup then + goto CONTINUE + end + local ok + for _, other in ipairs(bindGroup) do + if other.type == 'doc.class' then + ok = true + break + end + if other == doc then + break + end + end + if not ok then + callback { + start = doc.start, + finish = doc.finish, + message = lang.script('DIAG_DOC_FIELD_NO_CLASS'), + } + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua new file mode 100644 index 00000000..259c048b --- /dev/null +++ b/script/core/diagnostics/duplicate-doc-class.lua @@ -0,0 +1,46 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + local cache = {} + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.class' + or doc.type == 'doc.alias' then + local name = guide.getName(doc) + if not cache[name] then + local docs = vm.getDocTypes(name) + cache[name] = {} + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' + or otherDoc.type == 'doc.alias.name' then + cache[name][#cache[name]+1] = { + start = otherDoc.start, + finish = otherDoc.finish, + uri = guide.getUri(otherDoc), + } + end + end + end + if #cache[name] > 1 then + callback { + start = doc.start, + finish = doc.finish, + related = cache, + message = lang.script('DIAG_DUPLICATE_DOC_CLASS', name) + } + end + end + end +end diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua new file mode 100644 index 00000000..b621fd9e --- /dev/null +++ b/script/core/diagnostics/duplicate-doc-field.lua @@ -0,0 +1,34 @@ +local files = require 'files' +local lang = require 'language' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + local mark + for _, group in ipairs(state.ast.docs.groups) do + for _, doc in ipairs(group) do + if doc.type == 'doc.class' then + mark = {} + elseif doc.type == 'doc.field' then + if mark then + local name = doc.field[1] + if mark[name] then + callback { + start = doc.field.start, + finish = doc.field.finish, + message = lang.script('DIAG_DUPLICATE_DOC_FIELD', name), + } + end + mark[name] = true + end + end + end + end +end diff --git a/script/core/diagnostics/duplicate-doc-param.lua b/script/core/diagnostics/duplicate-doc-param.lua new file mode 100644 index 00000000..676a6fb4 --- /dev/null +++ b/script/core/diagnostics/duplicate-doc-param.lua @@ -0,0 +1,37 @@ +local files = require 'files' +local lang = require 'language' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type ~= 'doc.param' then + goto CONTINUE + end + local name = doc.param[1] + local bindGroup = doc.bindGroup + if not bindGroup then + goto CONTINUE + end + for _, other in ipairs(bindGroup) do + if other ~= doc + and other.type == 'doc.param' + and other.param[1] == name then + callback { + start = doc.param.start, + finish = doc.param.finish, + message = lang.script('DIAG_DUPLICATE_DOC_PARAM', name) + } + goto CONTINUE + end + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua new file mode 100644 index 00000000..dabe1b3c --- /dev/null +++ b/script/core/diagnostics/duplicate-index.lua @@ -0,0 +1,63 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'table', function (source) + local mark = {} + for _, obj in ipairs(source) do + if obj.type == 'tablefield' + or obj.type == 'tableindex' then + local name = vm.getKeyName(obj) + if name then + if not mark[name] then + mark[name] = {} + end + mark[name][#mark[name]+1] = obj.field or obj.index + end + end + end + + for name, defs in pairs(mark) do + local sname = name:match '^.|(.+)$' + if #defs > 1 and sname then + local related = {} + for i = 1, #defs do + local def = defs[i] + related[i] = { + start = def.start, + finish = def.finish, + uri = uri, + } + end + for i = 1, #defs - 1 do + local def = defs[i] + callback { + start = def.start, + finish = def.finish, + related = related, + message = lang.script('DIAG_DUPLICATE_INDEX', sname), + level = define.DiagnosticSeverity.Hint, + tags = { define.DiagnosticTag.Unnecessary }, + } + end + for i = #defs, #defs do + local def = defs[i] + callback { + start = def.start, + finish = def.finish, + related = related, + message = lang.script('DIAG_DUPLICATE_INDEX', sname), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua new file mode 100644 index 00000000..2024f4e3 --- /dev/null +++ b/script/core/diagnostics/empty-block.lua @@ -0,0 +1,49 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' + +-- 检查空代码块 +-- 但是排除忙等待(repeat/while) +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'if', function (source) + for _, block in ipairs(source) do + if #block > 0 then + return + end + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) + guide.eachSourceType(ast.ast, 'loop', function (source) + if #source > 0 then + return + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) + guide.eachSourceType(ast.ast, 'in', function (source) + if #source > 0 then + return + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) +end diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua new file mode 100644 index 00000000..9a0d4f35 --- /dev/null +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -0,0 +1,66 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +-- TODO: 检查路径是否可达 +local function mayRun(path) + return true +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local root = guide.getRoot(ast.ast) + local env = guide.getENV(root) + + local nilDefs = {} + if not env.ref then + return + end + for _, ref in ipairs(env.ref) do + if ref.type == 'setlocal' then + if ref.value and ref.value.type == 'nil' then + nilDefs[#nilDefs+1] = ref + end + end + end + + if #nilDefs == 0 then + return + end + + local function check(source) + local node = source.node + if node.tag == '_ENV' then + local ok + for _, nilDef in ipairs(nilDefs) do + local mode, pathA = guide.getPath(nilDef, source) + if mode == 'before' + and mayRun(pathA) then + ok = nilDef + break + end + end + if ok then + callback { + start = source.start, + finish = source.finish, + uri = uri, + message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, + related = { + { + start = ok.start, + finish = ok.finish, + uri = uri, + } + } + } + end + end + end + + guide.eachSourceType(ast.ast, 'getglobal', check) + guide.eachSourceType(ast.ast, 'setglobal', check) +end diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua new file mode 100644 index 00000000..a6b61e12 --- /dev/null +++ b/script/core/diagnostics/init.lua @@ -0,0 +1,56 @@ +local files = require 'files' +local define = require 'proto.define' +local config = require 'config' +local await = require 'await' + +-- 把耗时最长的诊断放到最后面 +local diagLevel = { + ['redundant-parameter'] = 100, +} + +local diagList = {} +for k in pairs(define.DiagnosticDefaultSeverity) do + diagList[#diagList+1] = k +end +table.sort(diagList, function (a, b) + return (diagLevel[a] or 0) < (diagLevel[b] or 0) +end) + +local function check(uri, name, results) + if config.config.diagnostics.disable[name] then + return + end + local level = config.config.diagnostics.severity[name] + or define.DiagnosticDefaultSeverity[name] + if level == 'Hint' and not files.isOpen(uri) then + return + end + local severity = define.DiagnosticSeverity[level] + local clock = os.clock() + require('core.diagnostics.' .. name)(uri, function (result) + result.level = severity or result.level + result.code = name + results[#results+1] = result + end, name) + local passed = os.clock() - clock + if passed >= 0.5 then + log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed)) + end +end + +return function (uri, response) + local vm = require 'vm' + local ast = files.getAst(uri) + if not ast then + return nil + end + + local isOpen = files.isOpen(uri) + + for _, name in ipairs(diagList) do + await.delay() + local results = {} + check(uri, name, results) + response(results) + end +end diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua new file mode 100644 index 00000000..fe5d1eca --- /dev/null +++ b/script/core/diagnostics/lowercase-global.lua @@ -0,0 +1,56 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local config = require 'config' +local vm = require 'vm' + +local function isDocClass(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' then + return true + end + end + return false +end + +-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删) +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + local definedGlobal = {} + for name in pairs(config.config.diagnostics.globals) do + definedGlobal[name] = true + end + + guide.eachSourceType(ast.ast, 'setglobal', function (source) + local name = guide.getName(source) + if definedGlobal[name] then + return + end + local first = name:match '%w' + if not first then + return + end + if not first:match '%l' then + return + end + -- 如果赋值被标记为 doc.class ,则认为是允许的 + if isDocClass(source) then + return + end + if vm.isGlobalLibraryName(name) then + return + end + callback { + start = source.start, + finish = source.finish, + message = lang.script.DIAG_LOWERCASE_GLOBAL, + } + end) +end diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua new file mode 100644 index 00000000..75681cbc --- /dev/null +++ b/script/core/diagnostics/newfield-call.lua @@ -0,0 +1,37 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + local lines = files.getLines(uri) + local text = files.getText(uri) + + guide.eachSourceType(ast.ast, 'table', function (source) + for i = 1, #source do + local field = source[i] + if field.type == 'call' then + local func = field.node + local args = field.args + if args then + local funcLine = guide.positionOf(lines, func.finish) + local argsLine = guide.positionOf(lines, args.start) + if argsLine > funcLine then + callback { + start = field.start, + finish = field.finish, + message = lang.script('DIAG_PREFIELD_CALL' + , text:sub(func.start, func.finish) + , text:sub(args.start, args.finish) + ) + } + end + end + end + end + end) +end diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua new file mode 100644 index 00000000..cb318380 --- /dev/null +++ b/script/core/diagnostics/newline-call.lua @@ -0,0 +1,38 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local lines = files.getLines(uri) + + guide.eachSourceType(ast.ast, 'call', function (source) + local node = source.node + local args = source.args + if not args then + return + end + + -- 必须有其他人在继续使用当前对象 + if not source.next then + return + end + + local nodeRow = guide.positionOf(lines, node.finish) + local argRow = guide.positionOf(lines, args.start) + if nodeRow == argRow then + return + end + + if #args == 1 then + callback { + start = args.start, + finish = args.finish, + message = lang.script.DIAG_PREVIOUS_CALL, + } + end + end) +end diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua new file mode 100644 index 00000000..5e53d837 --- /dev/null +++ b/script/core/diagnostics/redefined-local.lua @@ -0,0 +1,32 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'local', function (source) + local name = source[1] + if name == '_' + or name == ast.ENVMode then + return + end + local exist = guide.getLocal(source, name, source.start-1) + if exist then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_REDEFINED_LOCAL', name), + related = { + { + start = exist.start, + finish = exist.finish, + uri = uri, + } + }, + } + end + end) +end diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua new file mode 100644 index 00000000..2fae20e8 --- /dev/null +++ b/script/core/diagnostics/redundant-parameter.lua @@ -0,0 +1,82 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local define = require 'proto.define' +local await = require 'await' + +local function countCallArgs(source) + local result = 0 + if not source.args then + return 0 + end + if source.node and source.node.type == 'getmethod' then + result = result + 1 + end + result = result + #source.args + return result +end + +local function countFuncArgs(source) + local result = 0 + if source.parent and source.parent.type == 'setmethod' then + result = result + 1 + end + if not source.args then + return result + end + if source.args[#source.args].type == '...' then + return math.maxinteger + end + result = result + #source.args + return result +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'call', function (source) + local callArgs = countCallArgs(source) + if callArgs == 0 then + return + end + + local func = source.node + local funcArgs + local defs = vm.getDefs(func) + for _, def in ipairs(defs) do + if def.value then + def = def.value + end + if def.type == 'function' then + local args = countFuncArgs(def) + if not funcArgs or args > funcArgs then + funcArgs = args + end + end + end + + if not funcArgs then + return + end + + local delta = callArgs - funcArgs + if delta <= 0 then + return + end + for i = #source.args - delta + 1, #source.args do + local arg = source.args[i] + if arg then + callback { + start = arg.start, + finish = arg.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs) + } + end + end + end) +end diff --git a/script/core/diagnostics/redundant-value.lua b/script/core/diagnostics/redundant-value.lua new file mode 100644 index 00000000..be483448 --- /dev/null +++ b/script/core/diagnostics/redundant-value.lua @@ -0,0 +1,24 @@ +local files = require 'files' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback, code) + local ast = files.getAst(uri) + if not ast then + return + end + + local diags = ast.diags[code] + if not diags then + return + end + + for _, info in ipairs(diags) do + callback { + start = info.start, + finish = info.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_OVER_MAX_VALUES', info.max, info.passed) + } + end +end diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua new file mode 100644 index 00000000..e54a6e60 --- /dev/null +++ b/script/core/diagnostics/trailing-space.lua @@ -0,0 +1,55 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' + +local function isInString(ast, offset) + local result = false + guide.eachSourceType(ast, 'string', function (source) + if offset >= source.start and offset <= source.finish then + result = true + end + end) + return result +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local text = files.getText(uri) + local lines = files.getLines(uri) + for i = 1, #lines do + local start = lines[i].start + local range = lines[i].range + local lastChar = text:sub(range, range) + if lastChar ~= ' ' and lastChar ~= '\t' then + goto NEXT_LINE + end + if isInString(ast.ast, range) then + goto NEXT_LINE + end + local first = start + for n = range - 1, start, -1 do + local char = text:sub(n, n) + if char ~= ' ' and char ~= '\t' then + first = n + 1 + break + end + end + if first == start then + callback { + start = first, + finish = range, + message = lang.script.DIAG_LINE_ONLY_SPACE, + } + else + callback { + start = first, + finish = range, + message = lang.script.DIAG_LINE_POST_SPACE, + } + end + ::NEXT_LINE:: + end +end diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua new file mode 100644 index 00000000..bbfdceec --- /dev/null +++ b/script/core/diagnostics/undefined-doc-class.lua @@ -0,0 +1,46 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + local cache = {} + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.class' then + local ext = doc.extends + if not ext then + goto CONTINUE + end + local name = ext[1] + local docs = vm.getDocTypes(name) + if cache[name] == nil then + cache[name] = false + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' then + cache[name] = true + break + end + end + end + if not cache[name] then + callback { + start = ext.start, + finish = ext.finish, + related = cache, + message = lang.script('DIAG_UNDEFINED_DOC_CLASS', name) + } + end + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua new file mode 100644 index 00000000..5c1e8fbf --- /dev/null +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -0,0 +1,60 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +local function hasNameOfClassOrAlias(name) + local docs = vm.getDocTypes(name) + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' + or otherDoc.type == 'doc.alias.name' then + return true + end + end + return false +end + +local function hasNameOfGeneric(name, source) + if not source.typeGeneric then + return false + end + if not source.typeGeneric[name] then + return false + end + return true +end + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + guide.eachSource(state.ast.docs, function (source) + if source.type ~= 'doc.extends.name' + and source.type ~= 'doc.type.name' then + return + end + if source.parent.type == 'doc.class' then + return + end + local name = source[1] + if name == '...' then + return + end + if hasNameOfClassOrAlias(name) + or hasNameOfGeneric(name, source) then + return + end + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_UNDEFINED_DOC_NAME', name) + } + end) +end diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua new file mode 100644 index 00000000..af3e07bc --- /dev/null +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -0,0 +1,52 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +local function hasParamName(func, name) + if not func.args then + return false + end + for _, arg in ipairs(func.args) do + if arg[1] == name then + return true + end + end + return false +end + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type ~= 'doc.param' then + goto CONTINUE + end + local binds = doc.bindSources + if not binds then + goto CONTINUE + end + local param = doc.param + local name = param[1] + for _, source in ipairs(binds) do + if source.type == 'function' then + if not hasParamName(source, name) then + callback { + start = param.start, + finish = param.finish, + message = lang.script('DIAG_UNDEFINED_DOC_PARAM', name) + } + end + end + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua new file mode 100644 index 00000000..6b8c62f0 --- /dev/null +++ b/script/core/diagnostics/undefined-env-child.lua @@ -0,0 +1,27 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'getglobal', function (source) + -- 单独验证自己是否在重载过的 _ENV 中有定义 + if source.node.tag == '_ENV' then + return + end + local defs = guide.requestDefinition(source) + if #defs > 0 then + return + end + local key = source[1] + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_UNDEF_ENV_CHILD', key), + } + end) +end diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua new file mode 100644 index 00000000..778fc1f1 --- /dev/null +++ b/script/core/diagnostics/undefined-global.lua @@ -0,0 +1,40 @@ +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local config = require 'config' +local guide = require 'parser.guide' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + -- 遍历全局变量,检查所有没有 set 模式的全局变量 + guide.eachSourceType(ast.ast, 'getglobal', function (src) + local key = guide.getName(src) + if not key then + return + end + if config.config.diagnostics.globals[key] then + return + end + if #vm.getGlobalSets(guide.getKeyName(src)) > 0 then + return + end + local message = lang.script('DIAG_UNDEF_GLOBAL', key) + -- TODO check other version + local otherVersion + local customVersion + if otherVersion then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version)) + elseif customVersion then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customVersion, '/'))) + end + callback { + start = src.start, + finish = src.finish, + message = message, + } + end) +end diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua new file mode 100644 index 00000000..f0bca613 --- /dev/null +++ b/script/core/diagnostics/unused-function.lua @@ -0,0 +1,40 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local define = require 'proto.define' +local lang = require 'language' +local await = require 'await' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + -- 只检查局部函数 + guide.eachSourceType(ast.ast, 'function', function (source) + local parent = source.parent + if not parent then + return + end + if parent.type ~= 'local' + and parent.type ~= 'setlocal' then + return + end + local hasGet + local refs = vm.getRefs(source) + for _, src in ipairs(refs) do + if vm.isGet(src) then + hasGet = true + break + end + end + if not hasGet then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_FUNCTION, + } + end + end) +end diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua new file mode 100644 index 00000000..e6d998ba --- /dev/null +++ b/script/core/diagnostics/unused-label.lua @@ -0,0 +1,22 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'label', function (source) + if not source.ref then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LABEL', source[1]), + } + end + end) +end diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua new file mode 100644 index 00000000..873a70f2 --- /dev/null +++ b/script/core/diagnostics/unused-local.lua @@ -0,0 +1,93 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +local function hasGet(loc) + if not loc.ref then + return false + end + local weak + for _, ref in ipairs(loc.ref) do + if ref.type == 'getlocal' then + if not ref.next then + return 'strong' + end + local nextType = ref.next.type + if nextType ~= 'setmethod' + and nextType ~= 'setfield' + and nextType ~= 'setindex' then + return 'strong' + else + weak = true + end + end + end + if weak then + return 'weak' + else + return nil + end +end + +local function isMyTable(loc) + local value = loc.value + if value and value.type == 'table' then + return true + end + return false +end + +local function isClose(source) + if not source.attrs then + return false + end + for _, attr in ipairs(source.attrs) do + if attr[1] == 'close' then + return true + end + end + return false +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'local', function (source) + local name = source[1] + if name == '_' + or name == ast.ENVMode then + return + end + if isClose(source) then + return + end + local data = hasGet(source) + if data == 'strong' then + return + end + if data == 'weak' then + if not isMyTable(source) then + return + end + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LOCAL', name), + } + if source.ref then + for _, ref in ipairs(source.ref) do + callback { + start = ref.start, + finish = ref.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LOCAL', name), + } + end + end + end) +end diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua new file mode 100644 index 00000000..74cc08e7 --- /dev/null +++ b/script/core/diagnostics/unused-vararg.lua @@ -0,0 +1,31 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'function', function (source) + local args = source.args + if not args then + return + end + + for _, arg in ipairs(args) do + if arg.type == '...' then + if not arg.ref then + callback { + start = arg.start, + finish = arg.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_VARARG, + } + end + end + end + end) +end diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua new file mode 100644 index 00000000..7392b337 --- /dev/null +++ b/script/core/document-symbol.lua @@ -0,0 +1,307 @@ +local await = require 'await' +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local util = require 'utility' + +local function buildName(source, text) + if source.type == 'setmethod' + or source.type == 'getmethod' then + if source.method then + return text:sub(source.start, source.method.finish) + end + end + if source.type == 'setfield' + or source.type == 'tablefield' + or source.type == 'getfield' then + if source.field then + return text:sub(source.start, source.field.finish) + end + end + return text:sub(source.start, source.finish) +end + +local function buildFunctionParams(func) + if not func.args then + return '' + end + local params = {} + for i, arg in ipairs(func.args) do + if arg.type == '...' then + params[i] = '...' + else + params[i] = arg[1] or '' + end + end + return table.concat(params, ', ') +end + +local function buildFunction(source, text, symbols) + local name = buildName(source, text) + local func = source.value + if source.type == 'tablefield' + or source.type == 'setfield' then + source = source.field + if not source then + return + end + end + local range, kind + if func.start > source.finish then + -- a = function() + range = { source.start, func.finish } + else + -- function f() + range = { func.start, func.finish } + end + if source.type == 'setmethod' then + kind = define.SymbolKind.Method + else + kind = define.SymbolKind.Function + end + symbols[#symbols+1] = { + name = name, + detail = ('function (%s)'):format(buildFunctionParams(func)), + kind = kind, + range = range, + selectionRange = { source.start, source.finish }, + valueRange = { func.start, func.finish }, + } +end + +local function buildTable(tbl) + local buf = {} + for i = 1, 3 do + local field = tbl[i] + if not field then + break + end + if field.type == 'tablefield' then + buf[i] = ('%s'):format(field.field[1]) + end + end + return table.concat(buf, ', ') +end + +local function buildValue(source, text, symbols) + local name = buildName(source, text) + local range, sRange, valueRange, kind + local details = {} + if source.type == 'local' then + if source.parent.type == 'funcargs' then + details[1] = 'param' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Constant + else + details[1] = 'local' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Variable + end + elseif source.type == 'setlocal' then + details[1] = 'setlocal' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Variable + elseif source.type == 'setglobal' then + details[1] = 'global' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Class + elseif source.type == 'tablefield' then + if not source.field then + return + end + details[1] = 'field' + range = { source.field.start, source.field.finish } + sRange = { source.field.start, source.field.finish } + kind = define.SymbolKind.Property + elseif source.type == 'setfield' then + if not source.field then + return + end + details[1] = 'field' + range = { source.field.start, source.field.finish } + sRange = { source.field.start, source.field.finish } + kind = define.SymbolKind.Field + else + return + end + if source.value then + local literal = source.value[1] + if source.value.type == 'boolean' then + details[2] = ' boolean' + if literal ~= nil then + details[3] = ' = ' + details[4] = util.viewLiteral(source.value[1]) + end + elseif source.value.type == 'string' then + details[2] = ' string' + if literal ~= nil then + details[3] = ' = ' + details[4] = util.viewLiteral(source.value[1]) + end + elseif source.value.type == 'number' then + details[2] = ' number' + if literal ~= nil then + details[3] = ' = ' + details[4] = util.viewLiteral(source.value[1]) + end + elseif source.value.type == 'table' then + details[2] = ' {' + details[3] = buildTable(source.value) + details[4] = '}' + valueRange = { source.value.start, source.value.finish } + elseif source.value.type == 'select' then + if source.value.vararg and source.value.vararg.type == 'call' then + valueRange = { source.value.start, source.value.finish } + end + end + range = { range[1], source.value.finish } + end + symbols[#symbols+1] = { + name = name, + detail = table.concat(details), + kind = kind, + range = range, + selectionRange = sRange, + valueRange = valueRange, + } +end + +local function buildSet(source, text, used, symbols) + local value = source.value + if value and value.type == 'function' then + used[value] = true + buildFunction(source, text, symbols) + else + buildValue(source, text, symbols) + end +end + +local function buildAnonymousFunction(source, text, used, symbols) + if used[source] then + return + end + used[source] = true + local head = '' + local parent = source.parent + if parent.type == 'return' then + head = 'return ' + elseif parent.type == 'callargs' then + local call = parent.parent + local node = call.node + head = buildName(node, text) .. ' -> ' + end + symbols[#symbols+1] = { + name = '', + detail = ('%sfunction (%s)'):format(head, buildFunctionParams(source)), + kind = define.SymbolKind.Function, + range = { source.start, source.finish }, + selectionRange = { source.start, source.start }, + valueRange = { source.start, source.finish }, + } +end + +local function buildSource(source, text, used, symbols) + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'setfield' + or source.type == 'setmethod' + or source.type == 'tablefield' then + await.delay() + buildSet(source, text, used, symbols) + elseif source.type == 'function' then + await.delay() + buildAnonymousFunction(source, text, used, symbols) + end +end + +local function makeSymbol(uri) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local text = files.getText(uri) + local symbols = {} + local used = {} + guide.eachSource(ast.ast, function (source) + buildSource(source, text, used, symbols) + end) + + return symbols +end + +local function packChild(ranges, symbols) + await.delay() + table.sort(symbols, function (a, b) + return a.selectionRange[1] < b.selectionRange[1] + end) + await.delay() + local root = { + valueRange = { 0, math.maxinteger }, + children = {}, + } + local stacks = { root } + for _, symbol in ipairs(symbols) do + local parent = stacks[#stacks] + -- 移除已经超出生效范围的区间 + while symbol.selectionRange[1] > parent.valueRange[2] do + stacks[#stacks] = nil + parent = stacks[#stacks] + end + -- 向后看,找出当前可能生效的区间 + local nextRange + while #ranges > 0 + and symbol.selectionRange[1] >= ranges[#ranges].valueRange[1] do + if symbol.selectionRange[1] <= ranges[#ranges].valueRange[2] then + nextRange = ranges[#ranges] + end + ranges[#ranges] = nil + end + if nextRange then + stacks[#stacks+1] = nextRange + parent = nextRange + end + if parent == symbol then + -- function f() end 的情况,selectionRange 在 valueRange 内部, + -- 当前区间置为上一层 + parent = stacks[#stacks-1] + end + -- 把自己放到当前区间中 + if not parent.children then + parent.children = {} + end + parent.children[#parent.children+1] = symbol + end + return root.children +end + +local function packSymbols(symbols) + local ranges = {} + for _, symbol in ipairs(symbols) do + if symbol.valueRange then + ranges[#ranges+1] = symbol + end + end + await.delay() + table.sort(ranges, function (a, b) + return a.valueRange[1] > b.valueRange[1] + end) + -- 处理嵌套 + return packChild(ranges, symbols) +end + +return function (uri) + local symbols = makeSymbol(uri) + if not symbols then + return nil + end + + local packedSymbols = packSymbols(symbols) + + return packedSymbols +end diff --git a/script/core/find-source.lua b/script/core/find-source.lua new file mode 100644 index 00000000..32de102c --- /dev/null +++ b/script/core/find-source.lua @@ -0,0 +1,14 @@ +local guide = require 'parser.guide' + +return function (ast, offset, accept) + local len = math.huge + local result + guide.eachSourceContain(ast.ast, offset, function (source) + local start, finish = guide.getStartFinish(source) + if finish - start < len and accept[source.type] then + result = source + len = finish - start + end + end) + return result +end diff --git a/script/core/highlight.lua b/script/core/highlight.lua new file mode 100644 index 00000000..d7671df2 --- /dev/null +++ b/script/core/highlight.lua @@ -0,0 +1,252 @@ +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm' +local define = require 'proto.define' +local findSource = require 'core.find-source' + +local function eachRef(source, callback) + local results = guide.requestReference(source) + for i = 1, #results do + callback(results[i]) + end +end + +local function eachField(source, callback) + local isGlobal = guide.isGlobal(source) + local results = guide.requestReference(source) + for i = 1, #results do + local res = results[i] + if isGlobal == guide.isGlobal(res) then + callback(res) + end + end +end + +local function eachLocal(source, callback) + callback(source) + if source.ref then + for _, ref in ipairs(source.ref) do + callback(ref) + end + end +end + +local function find(source, uri, callback) + if source.type == 'local' then + eachLocal(source, callback) + elseif source.type == 'getlocal' + or source.type == 'setlocal' then + eachLocal(source.node, callback) + elseif source.type == 'field' + or source.type == 'method' then + eachField(source.parent, callback) + elseif source.type == 'getindex' + or source.type == 'setindex' + or source.type == 'tableindex' then + eachField(source, callback) + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + eachField(source, callback) + elseif source.type == 'goto' + or source.type == 'label' then + eachRef(source, callback) + elseif source.type == 'string' + and source.parent.index == source then + eachField(source.parent, callback) + elseif source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' + or source.type == 'nil' then + callback(source) + end +end + +local function checkInIf(source, text, offset) + -- 检查 end + local endA = source.finish - #'end' + 1 + local endB = source.finish + if offset >= endA + and offset <= endB + and text:sub(endA, endB) == 'end' then + return true + end + -- 检查每个子模块 + for _, block in ipairs(source) do + for i = 1, #block.keyword, 2 do + local start = block.keyword[i] + local finish = block.keyword[i+1] + if offset >= start and offset <= finish then + return true + end + end + end + return false +end + +local function makeIf(source, text, callback) + -- end + local endA = source.finish - #'end' + 1 + local endB = source.finish + if text:sub(endA, endB) == 'end' then + callback(endA, endB) + end + -- 每个子模块 + for _, block in ipairs(source) do + for i = 1, #block.keyword, 2 do + local start = block.keyword[i] + local finish = block.keyword[i+1] + callback(start, finish) + end + end + return false +end + +local function findKeyWord(ast, text, offset, callback) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'do' + or source.type == 'function' + or source.type == 'loop' + or source.type == 'in' + or source.type == 'while' + or source.type == 'repeat' then + local ok + for i = 1, #source.keyword, 2 do + local start = source.keyword[i] + local finish = source.keyword[i+1] + if offset >= start and offset <= finish then + ok = true + break + end + end + if ok then + for i = 1, #source.keyword, 2 do + local start = source.keyword[i] + local finish = source.keyword[i+1] + callback(start, finish) + end + end + elseif source.type == 'if' then + local ok = checkInIf(source, text, offset) + if ok then + makeIf(source, text, callback) + end + end + end) +end + +local accept = { + ['label'] = true, + ['goto'] = true, + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['field'] = true, + ['method'] = true, + ['tablefield'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['string'] = true, + ['boolean'] = true, + ['number'] = true, + ['nil'] = true, +} + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local text = files.getText(uri) + local results = {} + local mark = {} + + local source = findSource(ast, offset, accept) + if source then + find(source, uri, function (target) + local kind + if target.type == 'getfield' then + target = target.field + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setfield' + or target.type == 'tablefield' then + target = target.field + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getmethod' then + target = target.method + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setmethod' then + target = target.method + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getindex' then + target = target.index + kind = define.DocumentHighlightKind.Read + elseif target.type == 'field' then + if target.parent.type == 'getfield' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'method' then + if target.parent.type == 'getmethod' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'index' then + if target.parent.type == 'getindex' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'index' then + if target.parent.type == 'getindex' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'setindex' + or target.type == 'tableindex' then + target = target.index + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getlocal' + or target.type == 'getglobal' + or target.type == 'goto' then + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setlocal' + or target.type == 'local' + or target.type == 'setglobal' + or target.type == 'label' then + kind = define.DocumentHighlightKind.Write + elseif target.type == 'string' + or target.type == 'boolean' + or target.type == 'number' + or target.type == 'nil' then + kind = define.DocumentHighlightKind.Text + else + return + end + if mark[target] then + return + end + mark[target] = true + results[#results+1] = { + start = target.start, + finish = target.finish, + kind = kind, + } + end) + end + + findKeyWord(ast, text, offset, function (start, finish) + results[#results+1] = { + start = start, + finish = finish, + kind = define.DocumentHighlightKind.Write + } + end) + + if #results == 0 then + return nil + end + return results +end diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua new file mode 100644 index 00000000..9cd19f02 --- /dev/null +++ b/script/core/hover/arg.lua @@ -0,0 +1,71 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local function optionalArg(arg) + if not arg.bindDocs then + return false + end + local name = arg[1] + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.param' and doc.param[1] == name then + return doc.optional + end + end +end + +local function asFunction(source, oop) + if not source.args then + return '' + end + local args = {} + for i = 1, #source.args do + local arg = source.args[i] + local name = arg.name or guide.getName(arg) + if name then + args[i] = ('%s%s: %s'):format( + name, + optionalArg(arg) and '?' or '', + vm.getInferType(arg) + ) + else + args[i] = ('%s'):format(vm.getInferType(arg)) + end + end + local methodDef + local parent = source.parent + if parent and parent.type == 'setmethod' then + methodDef = true + end + if not methodDef and oop then + return table.concat(args, ', ', 2) + else + return table.concat(args, ', ') + end +end + +local function asDocFunction(source) + if not source.args then + return '' + end + local args = {} + for i = 1, #source.args do + local arg = source.args[i] + local name = arg.name[1] + args[i] = ('%s%s: %s'):format( + name, + arg.optional and '?' or '', + vm.getInferType(arg.extends) + ) + end + return table.concat(args, ', ') +end + +return function (source, oop) + if source.type == 'function' then + return asFunction(source, oop) + end + if source.type == 'doc.type.function' then + return asDocFunction(source) + end + return '' +end diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua new file mode 100644 index 00000000..7d89ee6c --- /dev/null +++ b/script/core/hover/description.lua @@ -0,0 +1,204 @@ +local vm = require 'vm' +local ws = require 'workspace' +local furi = require 'file-uri' +local files = require 'files' +local guide = require 'parser.guide' +local markdown = require 'provider.markdown' +local config = require 'config' +local lang = require 'language' + +local function asStringInRequire(source, literal) + local rootPath = ws.path or '' + local parent = source.parent + if parent and parent.type == 'callargs' then + local result, searchers + local call = parent.parent + local func = call.node + local libName = vm.getLibraryName(func) + if not libName then + return + end + if libName == 'require' then + result, searchers = ws.findUrisByRequirePath(literal) + elseif libName == 'dofile' + or libName == 'loadfile' then + result = ws.findUrisByFilePath(literal) + end + if result and #result > 0 then + for i, uri in ipairs(result) do + local searcher = searchers and furi.decode(searchers[uri]) + uri = files.getOriginUri(uri) + local path = furi.decode(uri) + if files.eq(path:sub(1, #rootPath), rootPath) then + path = path:sub(#rootPath + 1) + end + path = path:gsub('^[/\\]*', '') + if vm.isMetaFile(uri) then + result[i] = ('* [[meta]](%s)'):format(uri) + elseif searcher then + searcher = searcher:sub(#rootPath + 1) + searcher = ws.normalize(searcher) + result[i] = ('* [%s](%s) %s'):format(path, uri, lang.script('HOVER_USE_LUA_PATH', searcher)) + else + result[i] = ('* [%s](%s)'):format(path, uri) + end + end + table.sort(result) + local md = markdown() + md:add('md', table.concat(result, '\n')) + return md:string() + end + end +end + +local function asStringView(source, literal) + -- 内部包含转义符? + local rawLen = source.finish - source.start - 2 * #source[2] + 1 + if config.config.hover.viewString + and (source[2] == '"' or source[2] == "'") + and rawLen > #literal then + local view = literal + local max = config.config.hover.viewStringMax + if #view > max then + view = view:sub(1, max) .. '...' + end + local md = markdown() + md:add('txt', view) + return md:string() + end +end + +local function asString(source) + local literal = guide.getLiteral(source) + if type(literal) ~= 'string' then + return nil + end + return asStringInRequire(source, literal) + or asStringView(source, literal) +end + +local function getBindComment(docGroup, base) + local lines = {} + for _, doc in ipairs(docGroup) do + if doc.type == 'doc.comment' then + lines[#lines+1] = doc.comment.text:sub(2) + elseif #lines > 0 and not base then + break + elseif doc == base then + break + else + lines = {} + end + end + if #lines == 0 then + return nil + end + return table.concat(lines, '\n') +end + +local function buildEnumChunk(docType, name) + local enums = vm.getDocEnums(docType) + if #enums == 0 then + return + end + local types = {} + for _, tp in ipairs(docType.types) do + types[#types+1] = tp[1] + end + local lines = {} + lines[#lines+1] = ('%s: %s'):format(name, table.concat(types)) + for _, enum in ipairs(enums) do + lines[#lines+1] = (' %s %s%s'):format( + (enum.default and '->') + or (enum.additional and '+>') + or ' |', + enum[1], + enum.comment and (' -- %s'):format(enum.comment) or '' + ) + end + return table.concat(lines, '\n') +end + +local function getBindEnums(docGroup) + local mark = {} + local chunks = {} + local returnIndex = 0 + for _, doc in ipairs(docGroup) do + if doc.type == 'doc.param' then + local name = doc.param[1] + if mark[name] then + goto CONTINUE + end + mark[name] = true + chunks[#chunks+1] = buildEnumChunk(doc.extends, name) + elseif doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returnIndex = returnIndex + 1 + local name = rtn.name and rtn.name[1] or ('(return %d)'):format(returnIndex) + if mark[name] then + goto CONTINUE + end + mark[name] = true + chunks[#chunks+1] = buildEnumChunk(rtn, name) + end + end + ::CONTINUE:: + end + if #chunks == 0 then + return nil + end + return table.concat(chunks, '\n\n') +end + +local function tryDocFieldUpComment(source) + if source.type ~= 'doc.field' then + return + end + if not source.bindGroup then + return + end + local comment = getBindComment(source.bindGroup, source) + return comment +end + +local function tryDocComment(source) + if not source.bindDocs then + return + end + local comment = getBindComment(source.bindDocs) + local enums = getBindEnums(source.bindDocs) + local md = markdown() + if comment then + md:add('md', comment) + end + if enums then + md:add('lua', enums) + end + return md:string() +end + +local function tryDocOverloadToComment(source) + if source.type ~= 'doc.type.function' then + return + end + local doc = source.parent + if doc.type ~= 'doc.overload' + or not doc.bindSources then + return + end + for _, src in ipairs(doc.bindSources) do + local md = tryDocComment(src) + if md then + return md + end + end +end + +return function (source) + if source.type == 'string' then + return asString(source) + end + return tryDocOverloadToComment(source) + or tryDocFieldUpComment(source) + or tryDocComment(source) +end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua new file mode 100644 index 00000000..96e01ab5 --- /dev/null +++ b/script/core/hover/init.lua @@ -0,0 +1,164 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local getLabel = require 'core.hover.label' +local getDesc = require 'core.hover.description' +local util = require 'utility' +local findSource = require 'core.find-source' +local lang = require 'language' + +local function eachFunctionAndOverload(value, callback) + callback(value) + if not value.bindDocs then + return + end + for _, doc in ipairs(value.bindDocs) do + if doc.type == 'doc.overload' then + callback(doc.overload) + end + end +end + +local function getHoverAsFunction(source) + local values = vm.getDefs(source, 'deep') + local desc = getDesc(source) + local labels = {} + local defs = 0 + local protos = 0 + local other = 0 + local oop = source.type == 'method' + or source.type == 'getmethod' + or source.type == 'setmethod' + local mark = {} + for _, def in ipairs(values) do + def = guide.getObjectValue(def) or def + if def.type == 'function' + or def.type == 'doc.type.function' then + eachFunctionAndOverload(def, function (value) + if mark[value] then + return + end + mark[value] =true + local label = getLabel(value, oop) + if label then + defs = defs + 1 + labels[label] = (labels[label] or 0) + 1 + if labels[label] == 1 then + protos = protos + 1 + end + end + desc = desc or getDesc(value) + end) + elseif def.type == 'table' + or def.type == 'boolean' + or def.type == 'string' + or def.type == 'number' then + other = other + 1 + desc = desc or getDesc(def) + end + end + + if defs == 1 and other == 0 then + return { + label = next(labels), + source = source, + description = desc, + } + end + + local lines = {} + if defs > 1 then + lines[#lines+1] = lang.script('HOVER_MULTI_DEF_PROTO', defs, protos) + end + if other > 0 then + lines[#lines+1] = lang.script('HOVER_MULTI_PROTO_NOT_FUNC', other) + end + if defs > 1 then + for label, count in util.sortPairs(labels) do + lines[#lines+1] = ('(%d) %s'):format(count, label) + end + else + lines[#lines+1] = next(labels) + end + local label = table.concat(lines, '\n') + return { + label = label, + source = source, + description = desc, + } +end + +local function getHoverAsValue(source) + local oop = source.type == 'method' + or source.type == 'getmethod' + or source.type == 'setmethod' + local label = getLabel(source, oop) + local desc = getDesc(source) + if not desc then + local values = vm.getDefs(source, 'deep') + for _, def in ipairs(values) do + desc = getDesc(def) + if desc then + break + end + end + end + return { + label = label, + source = source, + description = desc, + } +end + +local function getHoverAsDocName(source) + local label = getLabel(source) + local desc = getDesc(source) + return { + label = label, + source = source, + description = desc, + } +end + +local function getHover(source) + if source.type == 'doc.type.name' then + return getHoverAsDocName(source) + end + local isFunction = vm.hasInferType(source, 'function', 'deep') + if isFunction then + return getHoverAsFunction(source) + else + return getHoverAsValue(source) + end +end + +local accept = { + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['field'] = true, + ['method'] = true, + ['string'] = true, + ['number'] = true, + ['doc.type.name'] = true, +} + +local function getHoverByUri(uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local source = findSource(ast, offset, accept) + if not source then + return nil + end + local hover = getHover(source) + return hover +end + +return { + get = getHover, + byUri = getHoverByUri, +} diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua new file mode 100644 index 00000000..d785bc27 --- /dev/null +++ b/script/core/hover/label.lua @@ -0,0 +1,211 @@ +local buildName = require 'core.hover.name' +local buildArg = require 'core.hover.arg' +local buildReturn = require 'core.hover.return' +local buildTable = require 'core.hover.table' +local vm = require 'vm' +local util = require 'utility' +local guide = require 'parser.guide' +local lang = require 'language' +local config = require 'config' +local files = require 'files' + +local function asFunction(source, oop) + local name = buildName(source, oop) + local arg = buildArg(source, oop) + local rtn = buildReturn(source) + local lines = {} + lines[1] = ('function %s(%s)'):format(name, arg) + lines[2] = rtn + return table.concat(lines, '\n') +end + +local function asDocFunction(source) + local name = buildName(source) + local arg = buildArg(source) + local rtn = buildReturn(source) + local lines = {} + lines[1] = ('function %s(%s)'):format(name, arg) + lines[2] = rtn + return table.concat(lines, '\n') +end + +local function asDocTypeName(source) + for _, doc in ipairs(vm.getDocTypes(source[1])) do + if doc.type == 'doc.class.name' then + return 'class ' .. source[1] + end + if doc.type == 'doc.alias.name' then + local extends = doc.parent.extends + return lang.script('HOVER_EXTENDS', vm.getInferType(extends)) + end + end +end + +local function asValue(source, title) + local name = buildName(source) + local infers = vm.getInfers(source, 'deep') + local type = vm.getInferType(source, 'deep') + local class = vm.getClass(source, 'deep') + local literal = vm.getInferLiteral(source, 'deep') + local cont + if type ~= 'string' and not type:find('%[%]$') then + if #vm.getFields(source, 'deep') > 0 + or vm.hasInferType(source, 'table', 'deep') then + cont = buildTable(source) + end + end + local pack = {} + pack[#pack+1] = title + pack[#pack+1] = name .. ':' + if cont and type == 'table' then + type = nil + end + if class then + pack[#pack+1] = class + else + pack[#pack+1] = type + end + if literal then + pack[#pack+1] = '=' + pack[#pack+1] = literal + end + if cont then + pack[#pack+1] = cont + end + return table.concat(pack, ' ') +end + +local function asLocal(source) + return asValue(source, 'local') +end + +local function asGlobal(source) + return asValue(source, 'global') +end + +local function isGlobalField(source) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + if source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' then + local node = source.node + if node.type == 'setglobal' + or node.type == 'getglobal' then + return true + end + return isGlobalField(node) + elseif source.type == 'tablefield' then + local parent = source.parent + if parent.type == 'setglobal' + or parent.type == 'getglobal' then + return true + end + return isGlobalField(parent) + else + return false + end +end + +local function asField(source) + if isGlobalField(source) then + return asGlobal(source) + end + return asValue(source, 'field') +end + +local function asDocField(source) + local name = source.field[1] + local class + for _, doc in ipairs(source.bindGroup) do + if doc.type == 'doc.class' then + class = doc + break + end + end + if not class then + return ('field ?.%s: %s'):format( + name, + vm.getInferType(source.extends) + ) + end + return ('field %s.%s: %s'):format( + class.class[1], + name, + vm.getInferType(source.extends) + ) +end + +local function asString(source) + local str = source[1] + if type(str) ~= 'string' then + return '' + end + local len = #str + local charLen = util.utf8Len(str, 1, -1) + if len == charLen then + return lang.script('HOVER_STRING_BYTES', len) + else + return lang.script('HOVER_STRING_CHARACTERS', len, charLen) + end +end + +local function formatNumber(n) + local str = ('%.10f'):format(n) + str = str:gsub('%.?0*$', '') + return str +end + +local function asNumber(source) + if not config.config.hover.viewNumber then + return nil + end + local num = source[1] + if type(num) ~= 'number' then + return nil + end + local uri = guide.getUri(source) + local text = files.getText(uri) + if not text then + return nil + end + local raw = text:sub(source.start, source.finish) + if not raw or not raw:find '[^%-%d%.]' then + return nil + end + return formatNumber(num) +end + +return function (source, oop) + if source.type == 'function' then + return asFunction(source, oop) + elseif source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' then + return asLocal(source) + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + return asGlobal(source) + elseif source.type == 'getfield' + or source.type == 'setfield' + or source.type == 'getmethod' + or source.type == 'setmethod' + or source.type == 'tablefield' + or source.type == 'field' + or source.type == 'method' then + return asField(source) + elseif source.type == 'string' then + return asString(source) + elseif source.type == 'number' then + return asNumber(source) + elseif source.type == 'doc.type.function' then + return asDocFunction(source) + elseif source.type == 'doc.type.name' then + return asDocTypeName(source) + elseif source.type == 'doc.field' then + return asDocField(source) + end +end diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua new file mode 100644 index 00000000..9ad32e09 --- /dev/null +++ b/script/core/hover/name.lua @@ -0,0 +1,101 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local buildName + +local function asLocal(source) + local name = guide.getName(source) + if not source.attrs then + return name + end + local label = {} + label[#label+1] = name + for _, attr in ipairs(source.attrs) do + label[#label+1] = ('<%s>'):format(attr[1]) + end + return table.concat(label, ' ') +end + +local function asField(source, oop) + local class + if source.node.type ~= 'getglobal' then + class = vm.getClass(source.node, 'deep') + end + local node = class or guide.getName(source.node) or '?' + local method = guide.getName(source) + if oop then + return ('%s:%s'):format(node, method) + else + return ('%s.%s'):format(node, method) + end +end + +local function asTableField(source) + if not source.field then + return + end + return guide.getName(source.field) +end + +local function asGlobal(source) + return guide.getName(source) +end + +local function asDocFunction(source) + local doc = guide.getParentType(source, 'doc.type') + or guide.getParentType(source, 'doc.overload') + if not doc or not doc.bindSources then + return '' + end + for _, src in ipairs(doc.bindSources) do + local name = buildName(src) + if name ~= '' then + return name + end + end + return '' +end + +local function asDocField(source) + return source.field[1] +end + +function buildName(source, oop) + if oop == nil then + oop = source.type == 'setmethod' + or source.type == 'getmethod' + end + if source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' then + return asLocal(source) or '' + end + if source.type == 'setglobal' + or source.type == 'getglobal' then + return asGlobal(source) or '' + end + if source.type == 'setmethod' + or source.type == 'getmethod' then + return asField(source, true) or '' + end + if source.type == 'setfield' + or source.type == 'getfield' then + return asField(source, oop) or '' + end + if source.type == 'tablefield' then + return asTableField(source) or '' + end + if source.type == 'doc.type.function' then + return asDocFunction(source) + end + if source.type == 'doc.field' then + return asDocField(source) + end + local parent = source.parent + if parent then + return buildName(parent, oop) + end + return '' +end + +return buildName diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua new file mode 100644 index 00000000..3829dbed --- /dev/null +++ b/script/core/hover/return.lua @@ -0,0 +1,125 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local function mergeTypes(returns) + if type(returns) == 'string' then + return returns + end + return guide.mergeTypes(returns) +end + +local function getReturnDualByDoc(source) + local docs = source.bindDocs + if not docs then + return + end + local dual + for _, doc in ipairs(docs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + if not dual then + dual = {} + end + dual[#dual+1] = { rtn } + end + end + end + return dual +end + +local function getReturnDualByGrammar(source) + if not source.returns then + return nil + end + local dual + for _, rtn in ipairs(source.returns) do + if not dual then + dual = {} + end + for n = 1, #rtn do + if not dual[n] then + dual[n] = {} + end + dual[n][#dual[n]+1] = rtn[n] + end + end + return dual +end + +local function asFunction(source) + local dual = getReturnDualByDoc(source) + or getReturnDualByGrammar(source) + if not dual then + return + end + local returns = {} + for i, rtn in ipairs(dual) do + local line = {} + local types = {} + if i == 1 then + line[#line+1] = ' -> ' + else + line[#line+1] = ('% 3d. '):format(i) + end + for n = 1, #rtn do + local values = vm.getInfers(rtn[n]) + for _, value in ipairs(values) do + if value.type then + for tp in value.type:gmatch '[^|]+' do + types[#types+1] = tp + end + end + end + end + if #types > 0 or rtn[1] then + local tp = mergeTypes(types) or 'any' + if rtn[1].name then + line[#line+1] = ('%s%s: %s'):format( + rtn[1].name[1], + rtn[1].optional and '?' or '', + tp + ) + else + line[#line+1] = ('%s%s'):format( + tp, + rtn[1].optional and '?' or '' + ) + end + else + break + end + returns[i] = table.concat(line) + end + if #returns == 0 then + return nil + end + return table.concat(returns, '\n') +end + +local function asDocFunction(source) + if not source.returns or #source.returns == 0 then + return nil + end + local returns = {} + for i, rtn in ipairs(source.returns) do + local rtnText = ('%s%s'):format( + vm.getInferType(rtn), + rtn.optional and '?' or '' + ) + if i == 1 then + returns[#returns+1] = (' -> %s'):format(rtnText) + else + returns[#returns+1] = ('% 3d. %s'):format(i, rtnText) + end + end + return table.concat(returns, '\n') +end + +return function (source) + if source.type == 'function' then + return asFunction(source) + end + if source.type == 'doc.type.function' then + return asDocFunction(source) + end +end diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua new file mode 100644 index 00000000..02be5271 --- /dev/null +++ b/script/core/hover/table.lua @@ -0,0 +1,257 @@ +local vm = require 'vm' +local util = require 'utility' +local guide = require 'parser.guide' +local config = require 'config' +local lang = require 'language' + +local function getKey(src) + local key = vm.getKeyName(src) + if not key or #key <= 2 then + if not src.index then + return '[any]' + end + local class = vm.getClass(src.index) + if class then + return ('[%s]'):format(class) + end + local tp = vm.getInferType(src.index) + if tp then + return ('[%s]'):format(tp) + end + return '[any]' + end + local ktype = key:sub(1, 2) + key = key:sub(3) + if ktype == 's|' then + if key:match '^[%a_][%w_]*$' then + return key + else + return ('[%s]'):format(util.viewLiteral(key)) + end + end + return ('[%s]'):format(key) +end + +local function getFieldFast(src) + local value = guide.getObjectValue(src) or src + if not value then + return 'any' + end + if value.type == 'boolean' then + return value.type, util.viewLiteral(value[1]) + end + if value.type == 'number' + or value.type == 'integer' then + if math.tointeger(value[1]) then + if config.config.runtime.version == 'Lua 5.3' + or config.config.runtime.version == 'Lua 5.4' then + return 'integer', util.viewLiteral(value[1]) + end + end + return value.type, util.viewLiteral(value[1]) + end + if value.type == 'table' + or value.type == 'function' then + return value.type + end + if value.type == 'string' then + local literal = value[1] + if type(literal) == 'string' and #literal >= 50 then + literal = literal:sub(1, 47) .. '...' + end + return value.type, util.viewLiteral(literal) + end +end + +local function getFieldFull(src) + local tp = vm.getInferType(src) + --local class = vm.getClass(src) + local literal = vm.getInferLiteral(src) + if type(literal) == 'string' and #literal >= 50 then + literal = literal:sub(1, 47) .. '...' + end + return tp, literal +end + +local function getField(src, timeUp, mark, key) + if src.type == 'table' + or src.type == 'function' then + return nil + end + if src.parent then + if src.type == 'string' + or src.type == 'boolean' + or src.type == 'number' + or src.type == 'integer' then + if src.parent.type == 'tableindex' + or src.parent.type == 'setindex' + or src.parent.type == 'getindex' then + if src.parent.index == src then + src = src.parent + end + end + end + end + local tp, literal + tp, literal = getFieldFast(src) + if tp then + return tp, literal + end + if timeUp or mark[key] then + return nil + end + mark[key] = true + tp, literal = getFieldFull(src) + if tp then + return tp, literal + end + return nil +end + +local function buildAsHash(classes, literals) + local keys = {} + for k in pairs(classes) do + keys[#keys+1] = k + end + table.sort(keys) + local lines = {} + lines[#lines+1] = '{' + for _, key in ipairs(keys) do + local class = classes[key] + local literal = literals[key] + if literal then + lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + else + lines[#lines+1] = (' %s: %s,'):format(key, class) + end + end + lines[#lines+1] = '}' + return table.concat(lines, '\n') +end + +local function buildAsConst(classes, literals) + local keys = {} + for k in pairs(classes) do + keys[#keys+1] = k + end + table.sort(keys, function (a, b) + return tonumber(literals[a]) < tonumber(literals[b]) + end) + local lines = {} + lines[#lines+1] = '{' + for _, key in ipairs(keys) do + local class = classes[key] + local literal = literals[key] + if literal then + lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + else + lines[#lines+1] = (' %s: %s,'):format(key, class) + end + end + lines[#lines+1] = '}' + return table.concat(lines, '\n') +end + +local function mergeLiteral(literals) + local results = {} + local mark = {} + for _, value in ipairs(literals) do + if not mark[value] then + mark[value] = true + results[#results+1] = value + end + end + if #results == 0 then + return nil + end + table.sort(results) + return table.concat(results, '|') +end + +local function mergeTypes(types) + local results = {} + local mark = { + -- 讲道理table的keyvalue不会是nil + ['nil'] = true, + } + for _, tv in ipairs(types) do + for tp in tv:gmatch '[^|]+' do + if not mark[tp] then + mark[tp] = true + results[#results+1] = tp + end + end + end + return guide.mergeTypes(results) +end + +local function clearClasses(classes) + classes['[nil]'] = nil + classes['[any]'] = nil + classes['[string]'] = nil +end + +return function (source) + local literals = {} + local classes = {} + local clock = os.clock() + local timeUp + local mark = {} + local fields = vm.getFields(source, 'deep') + local keyCount = 0 + for _, src in ipairs(fields) do + local key = getKey(src) + if not key then + goto CONTINUE + end + if not classes[key] then + classes[key] = {} + keyCount = keyCount + 1 + end + if not literals[key] then + literals[key] = {} + end + if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then + timeUp = true + end + local class, literal = getField(src, timeUp, mark, key) + if literal == 'nil' then + literal = nil + end + classes[key][#classes[key]+1] = class + literals[key][#literals[key]+1] = literal + if keyCount >= 1000 then + break + end + ::CONTINUE:: + end + + clearClasses(classes) + + for key, class in pairs(classes) do + literals[key] = mergeLiteral(literals[key]) + classes[key] = mergeTypes(class) + end + + if not next(classes) then + return '{}' + end + + local intValue = true + for key, class in pairs(classes) do + if class ~= 'integer' or not tonumber(literals[key]) then + intValue = false + break + end + end + local result + if intValue then + result = buildAsConst(classes, literals) + else + result = buildAsHash(classes, literals) + end + if timeUp then + result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + end + return result +end diff --git a/script/core/keyword.lua b/script/core/keyword.lua new file mode 100644 index 00000000..1cbeb78d --- /dev/null +++ b/script/core/keyword.lua @@ -0,0 +1,264 @@ +local define = require 'proto.define' +local guide = require 'parser.guide' + +local keyWordMap = { + {'do', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'do .. end', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[$0 end]], + } + else + results[#results+1] = { + label = 'do .. end', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +do + $0 +end]], + } + end + return true + end, function (ast, start) + return guide.eachSourceContain(ast.ast, start, function (source) + if source.type == 'while' + or source.type == 'in' + or source.type == 'loop' then + for i = 1, #source.keyword do + if start == source.keyword[i] then + return true + end + end + end + end) + end}, + {'and'}, + {'break'}, + {'else'}, + {'elseif', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'elseif .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[$1 then]], + } + else + results[#results+1] = { + label = 'elseif .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[elseif $1 then]], + } + end + return true + end}, + {'end'}, + {'false'}, + {'for', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'for .. in', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:key, value} in ${2:pairs(${3:t})} do + $0 +end]] + } + results[#results+1] = { + label = 'for i = ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:i} = ${2:1}, ${3:10, 1} do + $0 +end]] + } + else + results[#results+1] = { + label = 'for .. in', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +for ${1:key, value} in ${2:pairs(${3:t})} do + $0 +end]] + } + results[#results+1] = { + label = 'for i = ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +for ${1:i} = ${2:1}, ${3:10, 1} do + $0 +end]] + } + end + return true + end}, + {'function', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'function ()', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +$1($2) + $0 +end]] + } + else + results[#results+1] = { + label = 'function ()', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +function $1($2) + $0 +end]] + } + end + return true + end}, + {'goto'}, + {'if', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'if .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +$1 then + $0 +end]] + } + else + results[#results+1] = { + label = 'if .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +if $1 then + $0 +end]] + } + end + return true + end}, + {'in', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'in ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:pairs(${2:t})} do + $0 +end]] + } + else + results[#results+1] = { + label = 'in ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +in ${1:pairs(${2:t})} do + $0 +end]] + } + end + return true + end}, + {'local', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'local function', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +function $1($2) + $0 +end]] + } + else + results[#results+1] = { + label = 'local function', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +local function $1($2) + $0 +end]] + } + end + return false + end}, + {'nil'}, + {'not'}, + {'or'}, + {'repeat', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'repeat .. until', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[$0 until $1]] + } + else + results[#results+1] = { + label = 'repeat .. until', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +repeat + $0 +until $1]] + } + end + return true + end}, + {'return', function (hasSpace, results) + if not hasSpace then + results[#results+1] = { + label = 'do return end', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[do return $1end]] + } + end + return false + end}, + {'then'}, + {'true'}, + {'until'}, + {'while', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'while .. do', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:true} do + $0 +end]] + } + else + results[#results+1] = { + label = 'while .. do', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +while ${1:true} do + $0 +end]] + } + end + return true + end}, +} + +return keyWordMap diff --git a/script/core/matchkey.lua b/script/core/matchkey.lua new file mode 100644 index 00000000..45c86eff --- /dev/null +++ b/script/core/matchkey.lua @@ -0,0 +1,33 @@ +return function (me, other, fast) + if me == other then + return true + end + if me == '' then + return true + end + if #me > #other then + return false + end + local lMe = me:lower() + local lOther = other:lower() + if lMe == lOther:sub(1, #lMe) then + return true + end + if fast and me:sub(1, 1) ~= other:sub(1, 1) then + return false + end + local chars = {} + for i = 1, #lOther do + local c = lOther:sub(i, i) + chars[c] = (chars[c] or 0) + 1 + end + for i = 1, #lMe do + local c = lMe:sub(i, i) + if chars[c] and chars[c] > 0 then + chars[c] = chars[c] - 1 + else + return false + end + end + return true +end diff --git a/script/core/reference.lua b/script/core/reference.lua new file mode 100644 index 00000000..d7d3df03 --- /dev/null +++ b/script/core/reference.lua @@ -0,0 +1,116 @@ +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm' +local findSource = require 'core.find-source' + +local function isValidFunction(source, offset) + -- 必须点在 `function` 这个单词上才能查找函数引用 + return offset >= source.start and offset < source.start + #'function' +end + +local function sortResults(results) + -- 先按照顺序排序 + table.sort(results, function (a, b) + local u1 = guide.getUri(a.target) + local u2 = guide.getUri(b.target) + if u1 == u2 then + return a.target.start < b.target.start + else + return u1 < u2 + end + end) + -- 如果2个结果处于嵌套状态,则取范围小的那个 + local lf, lu + for i = #results, 1, -1 do + local res = results[i].target + local f = res.finish + local uri = guide.getUri(res) + if lf and f > lf and uri == lu then + table.remove(results, i) + else + lu = uri + lf = f + end + end +end + +local accept = { + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['label'] = true, + ['goto'] = true, + ['field'] = true, + ['method'] = true, + ['setindex'] = true, + ['getindex'] = true, + ['tableindex'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['function'] = true, + + ['doc.type.name'] = true, + ['doc.class.name'] = true, + ['doc.extends.name'] = true, + ['doc.alias.name'] = true, +} + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local source = findSource(ast, offset, accept) + if not source then + return nil + end + if source.type == 'function' and not isValidFunction(source, offset) and not TEST then + return nil + end + + local results = {} + for _, src in ipairs(vm.getRefs(source, 'deep')) do + local root = guide.getRoot(src) + if not root then + goto CONTINUE + end + if vm.isMetaFile(root.uri) then + goto CONTINUE + end + if ( src.type == 'doc.class.name' + or src.type == 'doc.type.name' + ) + and source.type ~= 'doc.type.name' + and source.type ~= 'doc.class.name' then + goto CONTINUE + end + if src.type == 'setfield' + or src.type == 'getfield' + or src.type == 'tablefield' then + src = src.field + elseif src.type == 'setindex' + or src.type == 'getindex' + or src.type == 'tableindex' then + src = src.index + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + src = src.method + elseif src.type == 'table' and src.parent.type ~= 'return' then + goto CONTINUE + end + results[#results+1] = { + target = src, + uri = files.getOriginUri(root.uri), + } + ::CONTINUE:: + end + + if #results == 0 then + return nil + end + + sortResults(results) + + return results +end diff --git a/script/core/rename.lua b/script/core/rename.lua new file mode 100644 index 00000000..89298bdd --- /dev/null +++ b/script/core/rename.lua @@ -0,0 +1,448 @@ +local files = require 'files' +local vm = require 'vm' +local guide = require 'parser.guide' +local proto = require 'proto' +local define = require 'proto.define' +local util = require 'utility' +local findSource = require 'core.find-source' +local ws = require 'workspace' + +local Forcing + +local function askForcing(str) + -- TODO 总是可以替换 + do return true end + if TEST then + return true + end + if Forcing ~= nil then + return Forcing + end + local version = files.globalVersion + -- TODO + local item = proto.awaitRequest('window/showMessageRequest', { + type = define.MessageType.Warning, + message = ('[%s]不是有效的标识符,是否强制替换?'):format(str), + actions = { + { + title = '强制替换', + }, + { + title = '取消', + }, + } + }) + if version ~= files.globalVersion then + Forcing = false + proto.notify('window/showMessage', { + type = define.MessageType.Warning, + message = '文件发生了变化,替换取消。' + }) + return false + end + if not item then + Forcing = false + return false + end + if item.title == '强制替换' then + Forcing = true + return true + else + Forcing = false + return false + end +end + +local function askForMultiChange(results, newname) + -- TODO 总是可以替换 + do return true end + if TEST then + return true + end + local uris = {} + for _, result in ipairs(results) do + local uri = result.uri + if not uris[uri] then + uris[uri] = 0 + uris[#uris+1] = uri + end + uris[uri] = uris[uri] + 1 + end + if #uris <= 1 then + return true + end + + local version = files.globalVersion + -- TODO + local item = proto.awaitRequest('window/showMessageRequest', { + type = define.MessageType.Warning, + message = ('将修改 %d 个文件,共 %d 处。'):format( + #uris, + #results + ), + actions = { + { + title = '继续', + }, + { + title = '放弃', + }, + } + }) + if version ~= files.globalVersion then + proto.notify('window/showMessage', { + type = define.MessageType.Warning, + message = '文件发生了变化,替换取消。' + }) + return false + end + if item and item.title == '继续' then + local fileList = {} + for _, uri in ipairs(uris) do + fileList[#fileList+1] = ('%s (%d)'):format(uri, uris[uri]) + end + + log.debug(('Renamed [%s]\r\n%s'):format(newname, table.concat(fileList, '\r\n'))) + return true + end + return false +end + +local function trim(str) + return str:match '^%s*(%S+)%s*$' +end + +local function isValidName(str) + return str:match '^[%a_][%w_]*$' +end + +local function isValidGlobal(str) + for s in str:gmatch '[^%.]*' do + if not isValidName(trim(s)) then + return false + end + end + return true +end + +local function isValidFunctionName(str) + if isValidGlobal(str) then + return true + end + local pos = str:find(':', 1, true) + if not pos then + return false + end + return isValidGlobal(trim(str:sub(1, pos-1))) + and isValidName(trim(str:sub(pos+1))) +end + +local function isFunctionGlobalName(source) + local parent = source.parent + if parent.type ~= 'setglobal' then + return false + end + local value = parent.value + if not value.type ~= 'function' then + return false + end + return value.start <= parent.start +end + +local function renameLocal(source, newname, callback) + if isValidName(newname) then + callback(source, source.start, source.finish, newname) + return + end + if askForcing(newname) then + callback(source, source.start, source.finish, newname) + end +end + +local function renameField(source, newname, callback) + if isValidName(newname) then + callback(source, source.start, source.finish, newname) + return true + end + local parent = source.parent + if parent.type == 'setfield' + or parent.type == 'getfield' then + local dot = parent.dot + local newstr = '[' .. util.viewString(newname) .. ']' + callback(source, dot.start, source.finish, newstr) + elseif parent.type == 'tablefield' then + local newstr = '[' .. util.viewString(newname) .. ']' + callback(source, source.start, source.finish, newstr) + elseif parent.type == 'getmethod' then + if not askForcing(newname) then + return false + end + callback(source, source.start, source.finish, newname) + elseif parent.type == 'setmethod' then + local uri = guide.getUri(source) + local text = files.getText(uri) + local func = parent.value + -- function mt:name () end --> mt['newname'] = function (self) end + local newstr = string.format('%s[%s] = function ' + , text:sub(parent.start, parent.node.finish) + , util.viewString(newname) + ) + callback(source, func.start, parent.finish, newstr) + local pl = text:find('(', parent.finish, true) + if pl then + if func.args then + callback(source, pl + 1, pl, 'self, ') + else + callback(source, pl + 1, pl, 'self') + end + end + end + return true +end + +local function renameGlobal(source, newname, callback) + if isValidGlobal(newname) then + callback(source, source.start, source.finish, newname) + return true + end + if isValidFunctionName(newname) then + if not isFunctionGlobalName(source) then + askForcing(newname) + end + callback(source, source.start, source.finish, newname) + return true + end + local newstr = '_ENV[' .. util.viewString(newname) .. ']' + -- function name () end --> _ENV['newname'] = function () end + if source.value and source.value.type == 'function' + and source.value.start < source.start then + callback(source, source.value.start, source.finish, newstr .. ' = function ') + return true + end + callback(source, source.start, source.finish, newstr) + return true +end + +local function ofLocal(source, newname, callback) + renameLocal(source, newname, callback) + if source.ref then + for _, ref in ipairs(source.ref) do + renameLocal(ref, newname, callback) + end + end +end + +local function ofFieldThen(key, src, newname, callback) + if vm.getKeyName(src) ~= key then + return + end + if src.type == 'tablefield' + or src.type == 'getfield' + or src.type == 'setfield' then + src = src.field + elseif src.type == 'tableindex' + or src.type == 'getindex' + or src.type == 'setindex' then + src = src.index + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + src = src.method + end + if src.type == 'string' then + local quo = src[2] + local text = util.viewString(newname, quo) + callback(src, src.start, src.finish, text) + return + elseif src.type == 'field' + or src.type == 'method' then + local suc = renameField(src, newname, callback) + if not suc then + return + end + elseif src.type == 'setglobal' + or src.type == 'getglobal' then + local suc = renameGlobal(src, newname, callback) + if not suc then + return + end + end +end + +local function ofField(source, newname, callback) + local key = guide.getKeyName(source) + local node + if source.type == 'tablefield' + or source.type == 'tableindex' then + node = source.parent + else + node = source.node + end + for _, src in ipairs(vm.getFields(node, 'deep')) do + ofFieldThen(key, src, newname, callback) + end +end + +local function ofGlobal(source, newname, callback) + local key = guide.getKeyName(source) + for _, src in ipairs(vm.getRefs(source, 'deep')) do + ofFieldThen(key, src, newname, callback) + end +end + +local function ofLabel(source, newname, callback) + if not isValidName(newname) and not askForcing(newname)then + return false + end + for _, src in ipairs(vm.getRefs(source, 'deep')) do + callback(src, src.start, src.finish, newname) + end +end + +local function rename(source, newname, callback) + if source.type == 'label' + or source.type == 'goto' then + return ofLabel(source, newname, callback) + elseif source.type == 'local' then + return ofLocal(source, newname, callback) + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return ofLocal(source.node, newname, callback) + elseif source.type == 'field' + or source.type == 'method' + or source.type == 'index' then + return ofField(source.parent, newname, callback) + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + return ofGlobal(source, newname, callback) + elseif source.type == 'string' + or source.type == 'number' + or source.type == 'boolean' then + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + return ofField(parent, newname, callback) + end + end + return +end + +local function prepareRename(source) + if source.type == 'label' + or source.type == 'goto' + or source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'field' + or source.type == 'method' + or source.type == 'tablefield' + or source.type == 'setglobal' + or source.type == 'getglobal' then + return source, source[1] + elseif source.type == 'string' + or source.type == 'number' + or source.type == 'boolean' then + local parent = source.parent + if not parent then + return nil + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + return source, source[1] + end + return nil + end + return nil +end + +local accept = { + ['label'] = true, + ['goto'] = true, + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['field'] = true, + ['method'] = true, + ['tablefield'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['string'] = true, + ['boolean'] = true, + ['number'] = true, +} + +local m = {} + +function m.rename(uri, pos, newname) + local ast = files.getAst(uri) + if not ast then + return nil + end + local source = findSource(ast, pos, accept) + if not source then + return nil + end + local results = {} + local mark = {} + + rename(source, newname, function (target, start, finish, text) + local turi = files.getOriginUri(guide.getUri(target)) + local uid = turi .. start + if mark[uid] then + return + end + mark[uid] = true + if files.isLibrary(turi) then + return + end + results[#results+1] = { + start = start, + finish = finish, + text = text, + uri = turi, + } + end) + + if Forcing == false then + Forcing = nil + return nil + end + + if #results == 0 then + return nil + end + + if not askForMultiChange(results, newname) then + return nil + end + + return results +end + +function m.prepareRename(uri, pos) + local ast = files.getAst(uri) + if not ast then + return nil + end + local source = findSource(ast, pos, accept) + if not source then + return + end + + local res, text = prepareRename(source) + if not res then + return nil + end + + return { + start = source.start, + finish = source.finish, + text = text, + } +end + +return m diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua new file mode 100644 index 00000000..e6b35cdd --- /dev/null +++ b/script/core/semantic-tokens.lua @@ -0,0 +1,161 @@ +local files = require 'files' +local guide = require 'parser.guide' +local await = require 'await' +local define = require 'proto.define' +local vm = require 'vm' +local util = require 'utility' + +local Care = {} +Care['setglobal'] = function (source, results) + local isLib = vm.isGlobalLibraryName(source[1]) + if not isLib then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.namespace, + modifieres = define.TokenModifiers.deprecated, + } + end +end +Care['getglobal'] = function (source, results) + local isLib = vm.isGlobalLibraryName(source[1]) + if not isLib then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.namespace, + modifieres = define.TokenModifiers.deprecated, + } + end +end +Care['tablefield'] = function (source, results) + local field = source.field + if not field then + return + end + results[#results+1] = { + start = field.start, + finish = field.finish, + type = define.TokenTypes.property, + modifieres = define.TokenModifiers.declaration, + } +end +Care['getlocal'] = function (source, results) + local loc = source.node + -- 1. 值为函数的局部变量 + local hasFunc + local node = loc.node + if node then + for _, ref in ipairs(node.ref) do + local def = ref.value + if def.type == 'function' then + hasFunc = true + break + end + end + end + if hasFunc then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.interface, + modifieres = define.TokenModifiers.declaration, + } + return + end + -- 2. 对象 + if source.parent.type == 'getmethod' + and source.parent.node == source then + return + end + -- 3. 函数的参数 + if loc.parent and loc.parent.type == 'funcargs' then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.parameter, + modifieres = define.TokenModifiers.declaration, + } + return + end + -- 4. 特殊变量 + if source[1] == '_ENV' + or source[1] == 'self' then + return + end + -- 5. 其他 + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.variable, + } +end +Care['setlocal'] = Care['getlocal'] +Care['doc.return.name'] = function (source, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.parameter, + } +end + +local function buildTokens(results, text, lines) + local tokens = {} + local lastLine = 0 + local lastStartChar = 0 + for i, source in ipairs(results) do + local row, col = guide.positionOf(lines, source.start) + local start = guide.lineRange(lines, row) + local ucol = util.utf8Len(text, start, start + col - 1) + local line = row - 1 + local startChar = ucol - 1 + local deltaLine = line - lastLine + local deltaStartChar + if deltaLine == 0 then + deltaStartChar = startChar - lastStartChar + else + deltaStartChar = startChar + end + lastLine = line + lastStartChar = startChar + -- see https://microsoft.github.io/language-server-protocol/specifications/specification-3-16/#textDocument_semanticTokens + local len = i * 5 - 5 + tokens[len + 1] = deltaLine + tokens[len + 2] = deltaStartChar + tokens[len + 3] = source.finish - source.start + 1 -- length + tokens[len + 4] = source.type + tokens[len + 5] = source.modifieres or 0 + end + return tokens +end + +return function (uri, start, finish) + local ast = files.getAst(uri) + local lines = files.getLines(uri) + local text = files.getText(uri) + if not ast then + return nil + end + + local results = {} + local count = 0 + guide.eachSourceBetween(ast.ast, start, finish, function (source) + local method = Care[source.type] + if not method then + return + end + method(source, results) + count = count + 1 + if count % 100 == 0 then + await.delay() + end + end) + + table.sort(results, function (a, b) + return a.start < b.start + end) + + local tokens = buildTokens(results, text, lines) + + return tokens +end diff --git a/script/core/signature.lua b/script/core/signature.lua new file mode 100644 index 00000000..dad38924 --- /dev/null +++ b/script/core/signature.lua @@ -0,0 +1,106 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local hoverLabel = require 'core.hover.label' +local hoverDesc = require 'core.hover.description' + +local function findNearCall(uri, ast, pos) + local text = files.getText(uri) + -- 检查 `f()$` 的情况,注意要区别于 `f($` + if text:sub(pos, pos) == ')' then + return nil + end + + local nearCall + guide.eachSourceContain(ast.ast, pos, function (src) + if src.type == 'call' + or src.type == 'table' + or src.type == 'function' then + if not nearCall or nearCall.start < src.start then + nearCall = src + end + end + end) + if not nearCall then + return nil + end + if nearCall.type ~= 'call' then + return nil + end + return nearCall +end + +local function makeOneSignature(source, oop, index) + local label = hoverLabel(source, oop) + -- 去掉返回值 + label = label:gsub('%s*->.+', '') + local params = {} + local i = 0 + for start, finish in label:gmatch '[%(%)%,]%s*().-()%s*%f[%(%)%,%[%]]' do + i = i + 1 + params[i] = { + label = {start, finish-1}, + } + end + -- 不定参数 + if index > i and i > 0 then + local lastLabel = params[i].label + local text = label:sub(lastLabel[1], lastLabel[2]) + if text == '...' then + index = i + end + end + return { + label = label, + params = params, + index = index, + description = hoverDesc(source), + } +end + +local function makeSignatures(call, pos) + local node = call.node + local oop = node.type == 'method' + or node.type == 'getmethod' + or node.type == 'setmethod' + local index + local args = call.args + if args then + for i, arg in ipairs(args) do + if arg.start <= pos and arg.finish >= pos then + index = i + break + end + end + if not index then + index = #args + 1 + end + else + index = 1 + end + local signs = {} + local defs = vm.getDefs(node, 'deep') + for _, src in ipairs(defs) do + if src.type == 'function' + or src.type == 'doc.type.function' then + signs[#signs+1] = makeOneSignature(src, oop, index) + end + end + return signs +end + +return function (uri, pos) + local ast = files.getAst(uri) + if not ast then + return nil + end + local call = findNearCall(uri, ast, pos) + if not call then + return nil + end + local signs = makeSignatures(call, pos) + if not signs or #signs == 0 then + return nil + end + return signs +end diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua new file mode 100644 index 00000000..4fc6a854 --- /dev/null +++ b/script/core/workspace-symbol.lua @@ -0,0 +1,69 @@ +local files = require 'files' +local guide = require 'parser.guide' +local matchKey = require 'core.matchkey' +local define = require 'proto.define' +local await = require 'await' + +local function buildSource(uri, source, key, results) + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' then + local name = source[1] + if matchKey(key, name) then + results[#results+1] = { + name = name, + kind = define.SymbolKind.Variable, + uri = uri, + range = { source.start, source.finish }, + } + end + elseif source.type == 'setfield' + or source.type == 'tablefield' then + local field = source.field + local name = field[1] + if matchKey(key, name) then + results[#results+1] = { + name = name, + kind = define.SymbolKind.Field, + uri = uri, + range = { field.start, field.finish }, + } + end + elseif source.type == 'setmethod' then + local method = source.method + local name = method[1] + if matchKey(key, name) then + results[#results+1] = { + name = name, + kind = define.SymbolKind.Method, + uri = uri, + range = { method.start, method.finish }, + } + end + end +end + +local function searchFile(uri, key, results) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSource(ast.ast, function (source) + buildSource(uri, source, key, results) + end) +end + +return function (key) + local results = {} + + for uri in files.eachFile() do + searchFile(files.getOriginUri(uri), key, results) + if #results > 1000 then + break + end + await.delay() + end + + return results +end |