diff options
author | CppCXY <812125110@qq.com> | 2022-08-11 19:36:36 +0800 |
---|---|---|
committer | CppCXY <812125110@qq.com> | 2022-08-11 19:36:36 +0800 |
commit | ff9103ae4001d8e520171b99cd192997fc689bc9 (patch) | |
tree | 04c0b685e81aac48210604dc12d24b91862a36d9 /script/core | |
parent | 40f191a85ea21bb64c427f9dab4bc597e2a0ea1b (diff) | |
parent | 82bcfef9037c26681993c94b2f92b68d335de3c6 (diff) | |
download | lua-language-server-ff9103ae4001d8e520171b99cd192997fc689bc9.zip |
Merge branch 'master' of github.com:CppCXY/lua-language-server
Diffstat (limited to 'script/core')
74 files changed, 2081 insertions, 789 deletions
diff --git a/script/core/code-action.lua b/script/core/code-action.lua index 6bba0a82..4eb21ff8 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -5,11 +5,18 @@ local sp = require 'bee.subprocess' local guide = require "parser.guide" local converter = require 'proto.converter' +---@param uri uri +---@param row integer +---@param mode string +---@param code string local function checkDisableByLuaDocExits(uri, row, mode, code) if row < 0 then return nil end local state = files.getState(uri) + if not state then + return nil + end local lines = state.lines if state.ast.docs and lines then return guide.eachSourceBetween( @@ -124,9 +131,12 @@ local function changeVersion(uri, version, results) end local function solveUndefinedGlobal(uri, diag, results) - local ast = files.getState(uri) - local start = converter.unpackRange(uri, diag.range) - guide.eachSourceContain(ast.ast, start, function (source) + local state = files.getState(uri) + if not state then + return + end + local start = converter.unpackRange(uri, diag.range) + guide.eachSourceContain(state.ast, start, function (source) if source.type ~= 'getglobal' then return end @@ -143,9 +153,12 @@ local function solveUndefinedGlobal(uri, diag, results) end local function solveLowercaseGlobal(uri, diag, results) - local ast = files.getState(uri) - local start = converter.unpackRange(uri, diag.range) - guide.eachSourceContain(ast.ast, start, function (source) + local state = files.getState(uri) + if not state then + return + end + local start = converter.unpackRange(uri, diag.range) + guide.eachSourceContain(state.ast, start, function (source) if source.type ~= 'setglobal' then return end @@ -156,8 +169,11 @@ local function solveLowercaseGlobal(uri, diag, results) end local function findSyntax(uri, diag) - local ast = files.getState(uri) - for _, err in ipairs(ast.errs) do + local state = files.getState(uri) + if not state then + return + end + for _, err in ipairs(state.errs) do if err.type:lower():gsub('_', '-') == diag.code then local range = converter.packRange(uri, err.start, err.finish) if util.equal(range, diag.range) then @@ -333,6 +349,8 @@ local function solveAwaitInSync(uri, diag, results) end local row = guide.rowColOf(parentFunction.start) local pos = guide.positionOf(row, 0) + local offset = guide.positionToOffset(state, pos + 1) + local space = state.lua:match('[ \t]*', offset) results[#results+1] = { title = lang.script.ACTION_MARK_ASYNC, kind = 'quickfix', @@ -342,7 +360,7 @@ local function solveAwaitInSync(uri, diag, results) { start = pos, finish = pos, - newText = '---@async\n', + newText = space .. '---@async\n', } } } @@ -350,6 +368,51 @@ local function solveAwaitInSync(uri, diag, results) } end +local function solveSpell(uri, diag, results) + local spell = require 'provider.spell' + local word = diag.data + if word == nil then + return + end + + results[#results+1] = { + title = lang.script('ACTION_ADD_DICT', word), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_ADD_DICT, + command = 'lua.setConfig', + arguments = { + { + key = 'Lua.spell.dict', + action = 'add', + value = word, + uri = uri, + } + } + } + } + + local suggests = spell.getSpellSuggest(word) + for _, suggest in ipairs(suggests) do + results[#results+1] = { + title = suggest, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + start = converter.unpackPosition(uri, diag.range.start), + finish = converter.unpackPosition(uri, diag.range["end"]), + newText = suggest + } + } + } + } + } + end + +end + local function solveDiagnostic(uri, diag, start, results) if diag.source == lang.script.DIAG_SYNTAX_CHECK then solveSyntax(uri, diag, results) @@ -370,6 +433,8 @@ local function solveDiagnostic(uri, diag, start, results) solveTrailingSpace(uri, diag, results) elseif diag.code == 'await-in-sync' then solveAwaitInSync(uri, diag, results) + elseif diag.code == 'spell-check' then + solveSpell(uri, diag, results) end disableDiagnostic(uri, diag.code, start, results) end @@ -386,7 +451,7 @@ end local function checkSwapParams(results, uri, start, finish) local state = files.getState(uri) local text = files.getText(uri) - if not state then + if not state or not text then return end local args = {} @@ -554,6 +619,9 @@ end local function checkJsonToLua(results, uri, start, finish) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return + end local startOffset = guide.positionToOffset(state, start) local finishOffset = guide.positionToOffset(state, finish) local jsonStart = text:match('()[%{%[]', startOffset + 1) diff --git a/script/core/collector.lua b/script/core/collector.lua deleted file mode 100644 index a2e3ca08..00000000 --- a/script/core/collector.lua +++ /dev/null @@ -1,188 +0,0 @@ -local scope = require 'workspace.scope' - ----@class collector ----@field subscribed table<uri, table<string, any>> ----@field collect table<string, table<uri, any>> -local mt = {} -mt.__index = mt - ---- 订阅一个名字 ----@param uri uri ----@param name string ----@param value any -function mt:subscribe(uri, name, value) - uri = uri or '<fallback>' - -- 订阅部分 - local uriSubscribed = self.subscribed[uri] - if not uriSubscribed then - uriSubscribed = {} - self.subscribed[uri] = uriSubscribed - end - uriSubscribed[name] = true - -- 收集部分 - local nameCollect = self.collect[name] - if not nameCollect then - nameCollect = {} - self.collect[name] = nameCollect - end - if value == nil then - value = true - end - nameCollect[uri] = value -end - ---- 丢弃掉某个 uri 中收集的所有信息 ----@param uri uri -function mt:dropUri(uri) - uri = uri or '<fallback>' - local uriSubscribed = self.subscribed[uri] - if not uriSubscribed then - return - end - self.subscribed[uri] = nil - for name in pairs(uriSubscribed) do - self.collect[name][uri] = nil - if not next(self.collect[name]) then - self.collect[name] = nil - end - end -end - -function mt:dropAll() - self.subscribed = {} - self.collect = {} -end - ---- 是否包含某个名字的订阅 ----@param uri uri ----@param name string ----@return boolean -function mt:has(uri, name) - if self:each(uri, name)() then - return true - else - return false - end -end - -local DUMMY_FUNCTION = function () end - ----@param scp scope -local function eachOfFolder(nameCollect, scp) - local curi, value - - local function getNext() - curi, value = next(nameCollect, curi) - if not curi then - return nil, nil - end - if scp:isChildUri(curi) - or scp:isLinkedUri(curi) then - return value, curi - end - return getNext() - end - - return getNext -end - ----@param scp scope -local function eachOfLinked(nameCollect, scp) - local curi, value - - local function getNext() - curi, value = next(nameCollect, curi) - if not curi then - return nil, nil - end - if scp:isChildUri(curi) - and scp:isLinkedUri(curi) then - return value, curi - end - - local cscp = scope.getFolder(curi) - or scope.getLinkedScope(curi) - or scope.fallback - - if cscp == scp - or cscp:isChildUri(scp.uri) - or cscp:isLinkedUri(scp.uri) then - return value, curi - end - - return getNext() - end - - return getNext -end - ----@param scp scope -local function eachOfFallback(nameCollect, scp) - local curi, value - - local function getNext() - curi, value = next(nameCollect, curi) - if not curi then - return nil, nil - end - if scp:isLinkedUri(curi) then - return value, curi - end - - local cscp = scope.getFolder(curi) - or scope.getLinkedScope(curi) - or scope.fallback - - if cscp == scp then - return value, curi - end - - return getNext() - end - - return getNext -end - ---- 迭代某个名字的订阅 ----@param uri uri ----@param name string -function mt:each(uri, name) - uri = uri or '<fallback>' - local nameCollect = self.collect[name] - if not nameCollect then - return DUMMY_FUNCTION - end - - local scp = scope.getFolder(uri) - - if scp then - return eachOfFolder(nameCollect, scp) - end - - scp = scope.getLinkedScope(uri) - - if scp then - return eachOfLinked(nameCollect, scp) - end - - return eachOfFallback(nameCollect, scope.fallback) -end - -local collectors = {} - -local function new() - return setmetatable({ - collect = {}, - subscribed = {}, - }, mt) -end - ----@return collector -return function (name) - if name then - collectors[name] = collectors[name] or new() - return collectors[name] - else - return new() - end -end diff --git a/script/core/color.lua b/script/core/color.lua new file mode 100644 index 00000000..2cbcce11 --- /dev/null +++ b/script/core/color.lua @@ -0,0 +1,79 @@ +local files = require "files" +local guide = require "parser.guide" + +local colorPattern = string.rep('%x', 8) +---@param source parser.object +---@return boolean +local function isColor(source) + ---@type string + local text = source[1] + if text:len() ~= 8 then + return false + end + return text:match(colorPattern) +end + + +---@param colorText string +---@return Color +local function textToColor(colorText) + return { + alpha = tonumber(colorText:sub(1, 2), 16) / 255, + red = tonumber(colorText:sub(3, 4), 16) / 255, + green = tonumber(colorText:sub(5, 6), 16) / 255, + blue = tonumber(colorText:sub(7, 8), 16) / 255, + } +end + + +---@param color Color +---@return string +local function colorToText(color) + return string.format('%02X%02X%02X%02X' + , math.floor(color.alpha * 255) + , math.floor(color.red * 255) + , math.floor(color.green * 255) + , math.floor(color.blue * 255) + ) +end + +---@class Color +---@field red number +---@field green number +---@field blue number +---@field alpha number + +---@class ColorValue +---@field color Color +---@field start integer +---@field finish integer + +---@async +local function colors(uri) + local state = files.getState(uri) + local text = files.getText(uri) + if not state or not text then + return nil + end + ---@type ColorValue[] + local colorValues = {} + + guide.eachSource(state.ast, function (source) ---@async + if source.type == 'string' and isColor(source) then + ---@type string + local colorText = source[1] + + colorValues[#colorValues+1] = { + start = source.start + 1, + finish = source.finish - 1, + color = textToColor(colorText) + } + end + end) + return colorValues +end + +return { + colors = colors, + colorToText = colorToText +} diff --git a/script/core/command/autoRequire.lua b/script/core/command/autoRequire.lua index c0deecfc..32911d92 100644 --- a/script/core/command/autoRequire.lua +++ b/script/core/command/autoRequire.lua @@ -21,6 +21,9 @@ end local function findInsertRow(uri) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return + end local lines = state.lines local fmt = { pair = false, @@ -68,7 +71,7 @@ local function askAutoRequire(uri, visiblePaths) local selects = {} local nameMap = {} for _, visible in ipairs(visiblePaths) do - local expect = visible.expect + local expect = visible.name local select = lang.script(expect) if not nameMap[select] then nameMap[select] = expect @@ -143,7 +146,7 @@ return function (data) return end table.sort(visiblePaths, function (a, b) - return #a.expect < #b.expect + return #a.name < #b.name end) local result = askAutoRequire(uri, visiblePaths) diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua index aa565f7f..992a0705 100644 --- a/script/core/command/removeSpace.lua +++ b/script/core/command/removeSpace.lua @@ -4,20 +4,12 @@ local proto = require 'proto' local lang = require 'language' local converter = require 'proto.converter' -local function isInString(ast, offset) - return guide.eachSourceContain(ast.ast, offset, function (source) - if source.type == 'string' then - return true - end - end) or false -end - ---@async return function (data) local uri = data.uri local text = files.getText(uri) local state = files.getState(uri) - if not state then + if not state or not text then return end @@ -32,7 +24,8 @@ return function (data) goto NEXT_LINE end local lastPos = guide.offsetToPosition(state, lastOffset) - if isInString(state.ast, lastPos) then + if guide.isInString(state.ast, lastPos) + or guide.isInComment(state.ast, lastPos) then goto NEXT_LINE end local firstOffset = startOffset diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua index 8065aa9d..98ceaa58 100644 --- a/script/core/command/solve.lua +++ b/script/core/command/solve.lua @@ -32,7 +32,7 @@ return function (data) local uri = data.uri local text = files.getText(uri) local state = files.getState(uri) - if not state then + if not state or not text then return end diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index d4c20c60..8f28e450 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -18,6 +18,7 @@ local lookBackward = require 'core.look-backward' local guide = require 'parser.guide' local await = require 'await' local postfix = require 'core.completion.postfix' +local diag = require 'proto.diagnostic' local diagnosticModes = { 'disable-next-line', @@ -56,6 +57,7 @@ local function trim(str) end local function findNearestSource(state, position) + ---@type parser.object local source guide.eachSourceContain(state.ast, position, function (src) source = src @@ -66,6 +68,9 @@ end local function findNearestTableField(state, position) local uri = state.uri local text = files.getText(uri) + if not text then + return nil + end local offset = guide.positionToOffset(state, position) local soffset = lookBackward.findAnyOffset(text, offset) if not soffset then @@ -155,36 +160,24 @@ local function buildFunctionSnip(source, value, oop) if oop then table.remove(args, 1) end - local len = #args - local truncated = false - if len > 0 and args[len]:match('^%s*%.%.%.:') then - table.remove(args) - truncated = true - end - for i = #args, 1, -1 do - if args[i]:match('^%s*[^?]+%?:') then - table.remove(args) - truncated = true - else - break - end - end local snipArgs = {} for id, arg in ipairs(args) do - local str = arg:gsub('^(%s*)(.+)', function (sp, word) + local str, count = arg:gsub('^(%s*)(%.%.%.)(.+)', function (sp, word) return ('%s${%d:%s}'):format(sp, id, word) end) + if count == 0 then + str = arg:gsub('^(%s*)([^:]+)(.+)', function (sp, word) + return ('%s${%d:%s}'):format(sp, id, word) + end) + end table.insert(snipArgs, str) end - if truncated and #snipArgs == 0 then - snipArgs = {'$1'} - end return ('%s(%s)'):format(name, table.concat(snipArgs, ', ')) end local function buildDetail(source) - local types = vm.getInfer(source):view() + local types = vm.getInfer(source):view(guide.getUri(source)) local literals = vm.getInfer(source):viewLiterals() if literals then return types .. ' = ' .. literals @@ -204,6 +197,9 @@ local function getSnip(source) local uri = guide.getUri(def) local text = files.getText(uri) local state = files.getState(uri) + if not state then + goto CONTINUE + end local lines = state.lines if not text then goto CONTINUE @@ -302,7 +298,7 @@ local function checkLocal(state, word, position, results) if name:sub(1, 1) == '@' then goto CONTINUE end - if vm.getInfer(source):hasFunction() then + if vm.getInfer(source):hasFunction(state.uri) then local defs = vm.getDefs(source) -- make sure `function` is before `doc.type.function` local orders = {} @@ -356,6 +352,7 @@ local function checkModule(state, word, position, results) if not config.get(state.uri, 'Lua.completion.autoRequire') then return end + local globals = util.arrayToHash(config.get(state.uri, 'Lua.diagnostics.globals')) local locals = guide.getVisibleLocals(state.ast, position) for uri in files.eachFile(state.uri) do if uri == guide.getUri(state.ast) then @@ -366,7 +363,7 @@ local function checkModule(state, word, position, results) local stemName = fileName:gsub('%..+', '') if not locals[stemName] and not vm.hasGlobalSets(state.uri, 'variable', stemName) - and not config.get(state.uri, 'Lua.diagnostics.globals')[stemName] + and not globals[stemName] and stemName:match '^[%a_][%w_]*$' and matchKey(word, stemName) then local targetState = files.getState(uri) @@ -488,7 +485,7 @@ local function checkFieldFromFieldToIndex(state, name, src, parent, word, startP end local function checkFieldThen(state, name, src, word, startPos, position, parent, oop, results) - local value = vm.getObjectValue(src) or src + local value = vm.getObjectFunctionValue(src) or src local kind = define.CompletionItemKind.Field if value.type == 'function' or value.type == 'doc.type.function' then @@ -512,7 +509,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent }) return end - if oop and not vm.getInfer(src):hasFunction() then + if oop and not vm.getInfer(src):hasFunction(state.uri) then return end local literal = guide.getLiteral(value) @@ -568,7 +565,8 @@ local function checkFieldOfRefs(refs, state, word, startPos, position, parent, o end local funcLabel if config.get(state.uri, 'Lua.completion.showParams') then - local value = vm.getObjectValue(src) or src + --- TODO determine if getlocal should be a function here too + local value = vm.getObjectFunctionValue(src) or src if value.type == 'function' or value.type == 'doc.type.function' then funcLabel = name .. getParams(value, oop) @@ -916,24 +914,24 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position goto CONTINUE end local path = furi.decode(uri) - local infos = rpath.getVisiblePath(uri, path) + local infos = rpath.getVisiblePath(myUri, path) local relative = workspace.getRelativePath(path) for _, info in ipairs(infos) do - if matchKey(literal, info.expect) then - if not collect[info.expect] then - collect[info.expect] = { + if matchKey(literal, info.name) then + if not collect[info.name] then + collect[info.name] = { textEdit = { start = smark and (source.start + #smark) or position, finish = smark and (source.finish - #smark) or position, - newText = smark and info.expect or util.viewString(info.expect), + newText = smark and info.name or util.viewString(info.name), }, path = relative, } end if vm.isMetaFile(uri) then - collect[info.expect][#collect[info.expect]+1] = ('* [[meta]](%s)'):format(uri) + collect[info.name][#collect[info.name]+1] = ('* [[meta]](%s)'):format(uri) else - collect[info.expect][#collect[info.expect]+1] = ([=[* [%s](%s) %s]=]):format( + collect[info.name][#collect[info.name]+1] = ([=[* [%s](%s) %s]=]):format( relative, uri, lang.script('HOVER_USE_LUA_PATH', info.searcher) @@ -1098,11 +1096,11 @@ local function tryLabelInString(label, source) if not source or source.type ~= 'string' then return label end - local state = parser.parse(label, 'String') + local state = parser.compile(label, 'String') if not state or not state.ast then return label end - if not matchKey(source[1], state.ast[1]) then + if not matchKey(source[1], state.ast[1]--[[@as string]]) then return nil end return util.viewString(state.ast[1], source[2]) @@ -1124,18 +1122,112 @@ local function cleanEnums(enums, source) return enums end -local function checkTypingEnum(state, position, defs, str, results) +---@param state parser.state +---@param pos integer +---@param doc vm.node.object +---@param enums table[] +---@return table[]? +local function insertDocEnum(state, pos, doc, enums) + local tbl = doc.bindSource + if not tbl then + return nil + end + local parent = tbl.parent + local parentName + if parent._globalNode then + parentName = parent._globalNode:getName() + else + local locals = guide.getVisibleLocals(state.ast, pos) + for _, loc in pairs(locals) do + if util.arrayHas(vm.getDefs(loc), tbl) then + parentName = loc[1] + break + end + end + end + local valueEnums = {} + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + if field.value.type == 'integer' + or field.value.type == 'string' then + if parentName then + enums[#enums+1] = { + label = parentName .. '.' .. key, + kind = define.CompletionItemKind.EnumMember, + id = stack(function () ---@async + return { + detail = buildDetail(field), + description = buildDesc(field), + } + end), + } + end + valueEnums[#valueEnums+1] = { + label = util.viewLiteral(field.value[1]), + kind = define.CompletionItemKind.EnumMember, + id = stack(function () ---@async + return { + detail = buildDetail(field), + description = buildDesc(field), + } + end), + } + end + ::CONTINUE:: + end + end + for _, enum in ipairs(valueEnums) do + enums[#enums+1] = enum + end + return enums +end + +---@param state parser.state +---@param pos integer +---@param src vm.node.object +---@param enums table[] +---@param isInArray boolean? +local function insertEnum(state, pos, src, enums, isInArray) + if src.type == 'doc.type.string' + or src.type == 'doc.type.integer' + or src.type == 'doc.type.boolean' then + ---@cast src parser.object + enums[#enums+1] = { + label = vm.viewObject(src, state.uri), + description = src.comment, + kind = define.CompletionItemKind.EnumMember, + } + elseif src.type == 'doc.type.code' then + enums[#enums+1] = { + label = src[1], + description = src.comment, + kind = define.CompletionItemKind.EnumMember, + } + elseif isInArray and src.type == 'doc.type.array' then + for i, d in ipairs(vm.getDefs(src.node)) do + insertEnum(state, pos, d, enums, isInArray) + end + elseif src.type == 'global' and src.cate == 'type' then + for _, set in ipairs(src:getSets(state.uri)) do + if set.type == 'doc.enum' then + insertDocEnum(state, pos, set, enums) + end + end + end +end + +local function checkTypingEnum(state, position, defs, str, results, isInArray) local enums = {} for _, def in ipairs(defs) do - if def.type == 'doc.type.string' - or def.type == 'doc.type.integer' - or def.type == 'doc.type.boolean' then - enums[#enums+1] = { - label = vm.viewObject(def), - description = def.comment and def.comment.text, - kind = define.CompletionItemKind.EnumMember, - } - end + insertEnum(state, position, def, enums, isInArray) end cleanEnums(enums, str) for _, res in ipairs(enums) do @@ -1143,7 +1235,7 @@ local function checkTypingEnum(state, position, defs, str, results) end end -local function checkEqualEnumLeft(state, position, source, results) +local function checkEqualEnumLeft(state, position, source, results, isInArray) if not source then return end @@ -1153,7 +1245,7 @@ local function checkEqualEnumLeft(state, position, source, results) end end) local defs = vm.getDefs(source) - checkTypingEnum(state, position, defs, str, results) + checkTypingEnum(state, position, defs, str, results, isInArray) end local function checkEqualEnum(state, position, results) @@ -1197,15 +1289,24 @@ local function checkEqualEnumInString(state, position, results) end checkEqualEnumLeft(state, position, parent[1], results) end + if (parent.type == 'tableexp') then + checkEqualEnumLeft(state, position, parent.parent.parent, results, true) + return + end if parent.type == 'local' then checkEqualEnumLeft(state, position, parent, results) end + if parent.type == 'setlocal' or parent.type == 'setglobal' or parent.type == 'setfield' or parent.type == 'setindex' then checkEqualEnumLeft(state, position, parent.node, results) end + if parent.type == 'tablefield' + or parent.type == 'tableindex' then + checkEqualEnumLeft(state, position, parent, results) + end end local function isFuncArg(state, position) @@ -1234,7 +1335,10 @@ local function tryIndex(state, position, results) if not parent then return end - local word = parent.next.index[1] + local word = parent.next and parent.next.index and parent.next.index[1] + if not word then + return + end checkField(state, word, position, position, parent, oop, results) end @@ -1414,18 +1518,12 @@ local function tryCallArg(state, position, results) if not node then return end + local enums = {} for src in node:eachObject() do - if src.type == 'doc.type.string' - or src.type == 'doc.type.integer' - or src.type == 'doc.type.boolean' then - enums[#enums+1] = { - label = vm.viewObject(src), - description = src.comment, - kind = define.CompletionItemKind.EnumMember, - } - end + insertEnum(state, position, src, enums, arg and arg.type == 'table') if src.type == 'doc.type.function' then + ---@cast src parser.object local insertText = buildInsertDocFunction(src) local description if src.comment then @@ -1439,7 +1537,7 @@ local function tryCallArg(state, position, results) : string() end enums[#enums+1] = { - label = vm.getInfer(src):view(), + label = vm.getInfer(src):view(state.uri), description = description, kind = define.CompletionItemKind.Function, insertText = insertText, @@ -1467,6 +1565,7 @@ local function tryTable(state, position, results) if source.type ~= 'table' then tbl = source.parent end + local defs = vm.getFields(tbl) for _, field in ipairs(defs) do local name = guide.getKeyName(field) @@ -1478,9 +1577,28 @@ local function tryTable(state, position, results) checkTableLiteralField(state, position, tbl, fields, results) end +local function tryArray(state, position, results) + local source = findNearestSource(state, position) + if not source then + return + end + if source.type ~= 'table' and (not source.parent or source.parent.type ~= 'table') then + return + end + local tbl = source + if source.type ~= 'table' then + tbl = source.parent + end + if source.parent.type == 'callargs' and source.parent.parent.type == 'call' then + return + end + -- { } inside when enum + checkEqualEnumLeft(state, position, tbl, results, true) +end + local function getComment(state, position) local offset = guide.positionToOffset(state, position) - local symbolOffset = lookBackward.findAnyOffset(state.lua, offset) + local symbolOffset = lookBackward.findAnyOffset(state.lua, offset, true) if not symbolOffset then return end @@ -1493,9 +1611,9 @@ local function getComment(state, position) return nil end -local function getluaDoc(state, position) +local function getLuaDoc(state, position) local offset = guide.positionToOffset(state, position) - local symbolOffset = lookBackward.findAnyOffset(state.lua, offset) + local symbolOffset = lookBackward.findAnyOffset(state.lua, offset, true) if not symbolOffset then return end @@ -1528,11 +1646,15 @@ local function tryluaDocCate(word, results) 'async', 'nodiscard', 'cast', + 'operator', + 'source', + 'enum', } do if matchKey(word, docType) then results[#results+1] = { label = docType, kind = define.CompletionItemKind.Event, + description = lang.script('LUADOC_DESC_' .. docType:upper()) } end end @@ -1608,6 +1730,7 @@ local function tryluaDocBySource(state, position, source, results) for _, doc in ipairs(vm.getDocSets(state.uri)) do local name = (doc.type == 'doc.class' and doc.class[1]) or (doc.type == 'doc.alias' and doc.alias[1]) + or (doc.type == 'doc.enum' and doc.enum[1]) if name and not used[name] and matchKey(source[1], name) then @@ -1697,6 +1820,35 @@ local function tryluaDocBySource(state, position, source, results) end end return true + elseif source.type == 'doc.operator.name' then + for _, name in ipairs(vm.UNARY_OP) do + if matchKey(source[1], name) then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_UNARY_MAP[name]), + } + end + end + for _, name in ipairs(vm.BINARY_OP) do + if matchKey(source[1], name) then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_BINARY_MAP[name]), + } + end + end + for _, name in ipairs(vm.OTHER_OP) do + if matchKey(source[1], name) then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_OTHER_MAP[name]), + } + end + end + return true end return false end @@ -1734,6 +1886,14 @@ local function tryluaDocByErr(state, position, err, docState, results) kind = define.CompletionItemKind.Class, } end + if doc.type == 'doc.enum' + and not used[doc.enum[1]] then + used[doc.enum[1]] = true + results[#results+1] = { + label = doc.enum[1], + kind = define.CompletionItemKind.Enum, + } + end end elseif err.type == 'LUADOC_MISS_PARAM_NAME' then local funcs = {} @@ -1783,7 +1943,7 @@ local function tryluaDocByErr(state, position, err, docState, results) } end elseif err.type == 'LUADOC_MISS_DIAG_NAME' then - for name in util.sortPairs(define.DiagnosticDefaultSeverity) do + for name in util.sortPairs(diag.getDiagAndErrNameMap()) do results[#results+1] = { label = name, kind = define.CompletionItemKind.Value, @@ -1807,6 +1967,28 @@ local function tryluaDocByErr(state, position, err, docState, results) } end end + elseif err.type == 'LUADOC_MISS_OPERATOR_NAME' then + for _, name in ipairs(vm.UNARY_OP) do + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_UNARY_MAP[name]), + } + end + for _, name in ipairs(vm.BINARY_OP) do + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_BINARY_MAP[name]), + } + end + for _, name in ipairs(vm.OTHER_OP) do + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_OTHER_MAP[name]), + } + end end end @@ -1818,14 +2000,14 @@ local function buildluaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = vm.getInfer(arg):view() + args[#args+1] = vm.getInfer(arg):view(guide.getUri(func)) end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = vm.getInfer(rtns[n]):view() + returns[n] = vm.getInfer(rtns[n]):view(guide.getUri(func)) end end end @@ -1853,25 +2035,27 @@ local function buildluaDocOfFunction(func) end local function tryluaDocOfFunction(doc, results) - if not doc.bindSources then + if not doc.bindSource then return end - local func - for _, source in ipairs(doc.bindSources) do - if source.type == 'function' then - func = source - break - end - end + local func = (doc.bindSource.type == 'function' and doc.bindSource) + or (doc.bindSource.value and doc.bindSource.value.type == 'function' and doc.bindSource.value) + or nil if not func then return end for _, otherDoc in ipairs(doc.bindGroup) do - if otherDoc.type == 'doc.param' - or otherDoc.type == 'doc.return' then + if otherDoc.type == 'doc.return' then return end end + if func.args then + for _, param in ipairs(func.args) do + if param.bindDocs then + return + end + end + end local insertText = buildluaDocOfFunction(func) results[#results+1] = { label = '@param;@return', @@ -1882,8 +2066,8 @@ local function tryluaDocOfFunction(doc, results) } end -local function tryluaDoc(state, position, results) - local doc = getluaDoc(state, position) +local function tryLuaDoc(state, position, results) + local doc = getLuaDoc(state, position) if not doc then return end @@ -1922,7 +2106,7 @@ local function tryComment(state, position, results) return end local word = lookBackward.findWord(state.lua, guide.positionToOffset(state, position)) - local doc = getluaDoc(state, position) + local doc = getLuaDoc(state, position) if not word then local comment = getComment(state, position) if not comment then @@ -1961,7 +2145,7 @@ local function tryCompletions(state, position, triggerCharacter, results) return end if getComment(state, position) then - tryluaDoc(state, position, results) + tryLuaDoc(state, position, results) tryComment(state, position, results) return end @@ -1971,6 +2155,7 @@ local function tryCompletions(state, position, triggerCharacter, results) trySpecial(state, position, results) tryCallArg(state, position, results) tryTable(state, position, results) + tryArray(state, position, results) tryWord(state, position, triggerCharacter, results) tryIndex(state, position, results) trySymbol(state, position, results) @@ -1983,8 +2168,6 @@ local function completion(uri, position, triggerCharacter) return nil end clearStack() - vm.lockCache() - local _ <close> = vm.unlockCache local results = {} tracy.ZoneBeginN 'completion #2' tryCompletions(state, position, triggerCharacter, results) diff --git a/script/core/definition.lua b/script/core/definition.lua index e4868532..866e8f84 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -4,6 +4,7 @@ local vm = require 'vm' local findSource = require 'core.find-source' local guide = require 'parser.guide' local rpath = require 'workspace.require-path' +local jumpSource = require 'core.jump-source' local function sortResults(results) -- 先按照顺序排序 @@ -54,6 +55,7 @@ local accept = { ['doc.see.name'] = true, ['doc.see.field'] = true, ['doc.cast.name'] = true, + ['doc.enum.name'] = true, } local function checkRequire(source, offset) @@ -75,7 +77,7 @@ local function checkRequire(source, offset) return nil end if libName == 'require' then - return rpath.findUrisByRequirePath(guide.getUri(source), literal) + return rpath.findUrisByRequireName(guide.getUri(source), literal) elseif libName == 'dofile' or libName == 'loadfile' then return workspace.findUrisByFilePath(literal) @@ -169,8 +171,12 @@ return function (uri, offset) if src.type == 'doc.alias' then src = src.alias end + if src.type == 'doc.enum' then + src = src.enum + end if src.type == 'doc.class.name' - or src.type == 'doc.alias.name' then + or src.type == 'doc.alias.name' + or src.type == 'doc.enum.name' then if source.type ~= 'doc.type.name' and source.type ~= 'doc.extends.name' and source.type ~= 'doc.see.name' then @@ -197,6 +203,7 @@ return function (uri, offset) end sortResults(results) + jumpSource(results) return results end diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua index f03f4361..830b2f2f 100644 --- a/script/core/diagnostics/ambiguity-1.lua +++ b/script/core/diagnostics/ambiguity-1.lua @@ -27,10 +27,10 @@ local literalMap = { return function (uri, callback) local state = files.getState(uri) - if not state then + local text = files.getText(uri) + if not state or not text then return end - local text = files.getText(uri) guide.eachSourceType(state.ast, 'binary', function (source) if source.op.type ~= 'or' then return diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua new file mode 100644 index 00000000..2d5c3f98 --- /dev/null +++ b/script/core/diagnostics/assign-type-mismatch.lua @@ -0,0 +1,117 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +local checkTypes = { + 'local', + 'setlocal', + 'setglobal', + 'setfield', + 'setindex', + 'setmethod', + 'tablefield', + 'tableindex' +} + +---@param source parser.object +---@return boolean +local function hasMarkType(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' + or doc.type == 'doc.class' then + return true + end + end + return false +end + +---@param source parser.object +---@return boolean +local function hasMarkClass(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' then + return true + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceTypes(state.ast, checkTypes, function (source) + local value = source.value + if not value then + return + end + await.delay() + if source.type == 'setlocal' then + local locNode = vm.compileNode(source.node) + if not locNode:getData 'hasDefined' then + return + end + end + if value.type == 'nil' then + --[[ + ---@class A + local mt + ---@type X + mt._x = nil -- don't warn this + ]] + if hasMarkType(source) then + return + end + if source.type == 'setfield' + or source.type == 'setindex' then + return + end + end + + local valueNode = vm.compileNode(value) + if source.type == 'setindex' then + -- boolean[1] = nil + valueNode = valueNode:copy():removeOptional() + end + + if value.type == 'getfield' + or value.type == 'getindex' then + -- 由于无法对字段进行类型收窄, + -- 因此将假值移除再进行检查 + valueNode = valueNode:copy():setTruthy() + end + + local varNode = vm.compileNode(source) + if vm.canCastType(uri, varNode, valueNode) then + return + end + + -- local Cat = setmetatable({}, {__index = Animal}) 允许逆变 + if hasMarkClass(source) then + if vm.canCastType(uri, valueNode:copy():remove 'table', varNode) then + return + end + end + + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_ASSIGN_TYPE_MISMATCH', { + def = vm.getInfer(varNode):view(uri), + ref = vm.getInfer(valueNode):view(uri), + }), + } + end) +end diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua new file mode 100644 index 00000000..c3d6e1bb --- /dev/null +++ b/script/core/diagnostics/cast-local-type.lua @@ -0,0 +1,50 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'local', function (loc) + if not loc.ref then + return + end + await.delay() + local locNode = vm.compileNode(loc) + if not locNode:getData 'hasDefined' then + return + end + for _, ref in ipairs(loc.ref) do + if ref.type == 'setlocal' and ref.value then + await.delay() + local refNode = vm.compileNode(ref) + local value = ref.value + + if value.type == 'getfield' + or value.type == 'getindex' then + -- 由于无法对字段进行类型收窄, + -- 因此将假值移除再进行检查 + refNode = refNode:copy():setTruthy() + end + + if not vm.canCastType(uri, locNode, refNode) then + callback { + start = ref.start, + finish = ref.finish, + message = lang.script('DIAG_CAST_LOCAL_TYPE', { + def = vm.getInfer(locNode):view(uri), + ref = vm.getInfer(refNode):view(uri), + }), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/cast-type-mismatch.lua b/script/core/diagnostics/cast-type-mismatch.lua new file mode 100644 index 00000000..a48e6cca --- /dev/null +++ b/script/core/diagnostics/cast-type-mismatch.lua @@ -0,0 +1,45 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.cast' and doc.loc then + await.delay() + local defs = vm.getDefs(doc.loc) + local loc = defs[1] + if loc then + local defNode = vm.compileNode(loc) + if defNode:getData 'hasDefined' then + for _, cast in ipairs(doc.casts) do + if not cast.mode and cast.extends then + local refNode = vm.compileNode(cast.extends) + if not vm.canCastType(uri, defNode, refNode) then + callback { + start = cast.extends.start, + finish = cast.extends.finish, + message = lang.script('DIAG_CAST_TYPE_MISMATCH', { + def = vm.getInfer(defNode):view(uri), + ref = vm.getInfer(refNode):view(uri), + }) + } + end + end + end + end + end + end + end +end diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua index 40d4afeb..fcd2021d 100644 --- a/script/core/diagnostics/circle-doc-class.lua +++ b/script/core/diagnostics/circle-doc-class.lua @@ -2,7 +2,9 @@ local files = require 'files' local lang = require 'language' local vm = require 'vm' local guide = require 'parser.guide' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -18,6 +20,7 @@ return function (uri, callback) if not doc.extends then goto CONTINUE end + await.delay() local myName = guide.getKeyName(doc) local list = { doc } local mark = {} diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index c97014fa..1a42b800 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -25,10 +25,11 @@ return function (uri, callback) return end local infer = vm.getInfer(source.value) - if not infer:hasClass() - and not infer:hasType 'nil' - and not infer:hasType 'table' - and infer:view('any', uri) ~= 'any' then + if not infer:hasClass(uri) + and not infer:hasType(uri, 'nil') + and not infer:hasType(uri, 'table') + and not infer:hasUnknown(uri) + and not infer:hasAny(uri) then callback { start = source.value.start, finish = source.value.finish, diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua index 21f7e83a..963fd9ed 100644 --- a/script/core/diagnostics/code-after-break.lua +++ b/script/core/diagnostics/code-after-break.lua @@ -2,7 +2,9 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -10,12 +12,14 @@ return function (uri, callback) end local mark = {} + ---@async guide.eachSourceType(state.ast, 'break', function (source) local list = source.parent if mark[list] then return end mark[list] = true + await.delay() for i = #list, 1, -1 do local src = list[i] if src == source then diff --git a/script/core/diagnostics/codestyle-check.lua b/script/core/diagnostics/codestyle-check.lua index 34d55ee2..25603b4b 100644 --- a/script/core/diagnostics/codestyle-check.lua +++ b/script/core/diagnostics/codestyle-check.lua @@ -7,7 +7,7 @@ local pformatting = require 'provider.formatting' ---@async return function(uri, callback) - local text = files.getText(uri) + local text = files.getOriginText(uri) if not text then return end diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua index 9bc4b273..bd6e5ee3 100644 --- a/script/core/diagnostics/count-down-loop.lua +++ b/script/core/diagnostics/count-down-loop.lua @@ -10,12 +10,15 @@ return function (uri, callback) end guide.eachSourceType(state.ast, 'loop', function (source) - local maxNumer = source.max and tonumber(source.max[1]) - if maxNumer ~= 1 then + local maxNumber = source.max and tonumber(source.max[1]) + if not maxNumber then return end local minNumber = source.init and tonumber(source.init[1]) - if minNumber and minNumber <= 1 then + if minNumber and maxNumber and minNumber <= maxNumber then + return + end + if not minNumber and maxNumber ~= 1 then return end if not source.step then @@ -24,7 +27,7 @@ return function (uri, callback) finish = source.max.finish, message = lang.script('DIAG_COUNT_DOWN_LOOP' , ('%s, %s'):format(text:sub( - guide.positionToOffset(state, source.init.start), + guide.positionToOffset(state, source.init.start + 1), guide.positionToOffset(state, source.max.finish) ), '-1') ) @@ -37,7 +40,7 @@ return function (uri, callback) finish = source.step.finish, message = lang.script('DIAG_COUNT_DOWN_LOOP' , ('%s, -%s'):format(text:sub( - guide.positionToOffset(state, source.init.start), + guide.positionToOffset(state, source.init.start + 1), guide.positionToOffset(state, source.max.finish) ), source.step[1]) ) diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua index 27920c43..85ae2d19 100644 --- a/script/core/diagnostics/deprecated.lua +++ b/script/core/diagnostics/deprecated.lua @@ -15,7 +15,7 @@ return function (uri, callback) return end - local dglobals = config.get(uri, 'Lua.diagnostics.globals') + local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) local rspecial = config.get(uri, 'Lua.runtime.special') guide.eachSourceTypes(ast.ast, types, function (src) ---@async diff --git a/script/core/diagnostics/different-requires.lua b/script/core/diagnostics/different-requires.lua index de063c9f..22e3e681 100644 --- a/script/core/diagnostics/different-requires.lua +++ b/script/core/diagnostics/different-requires.lua @@ -21,7 +21,7 @@ return function (uri, callback) return end local literal = arg1[1] - local results = rpath.findUrisByRequirePath(uri, literal) + local results = rpath.findUrisByRequireName(uri, literal) if not results or #results ~= 1 then return end diff --git a/script/core/diagnostics/duplicate-doc-alias.lua b/script/core/diagnostics/duplicate-doc-alias.lua index 3df6f972..360358e4 100644 --- a/script/core/diagnostics/duplicate-doc-alias.lua +++ b/script/core/diagnostics/duplicate-doc-alias.lua @@ -2,7 +2,9 @@ local files = require 'files' local lang = require 'language' local vm = require 'vm' local guide = require 'parser.guide' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -15,14 +17,20 @@ return function (uri, callback) local cache = {} for _, doc in ipairs(state.ast.docs) do - if doc.type == 'doc.alias' then + if doc.type == 'doc.alias' + or doc.type == 'doc.enum' then local name = guide.getKeyName(doc) + if not name then + return + end + await.delay() if not cache[name] then local docs = vm.getDocSets(uri, name) cache[name] = {} for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.alias' - or otherDoc.type == 'doc.class' then + or otherDoc.type == 'doc.class' + or otherDoc.type == 'doc.enum' then cache[name][#cache[name]+1] = { start = otherDoc.start, finish = otherDoc.finish, @@ -33,10 +41,10 @@ return function (uri, callback) end if #cache[name] > 1 then callback { - start = doc.alias.start, - finish = doc.alias.finish, + start = (doc.alias or doc.enum).start, + finish = (doc.alias or doc.enum).finish, related = cache, - message = lang.script('DIAG_DUPLICATE_DOC_CLASS', name) + message = lang.script('DIAG_DUPLICATE_DOC_ALIAS', name) } end end diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua index d4116b9b..a30dfa88 100644 --- a/script/core/diagnostics/duplicate-doc-field.lua +++ b/script/core/diagnostics/duplicate-doc-field.lua @@ -1,5 +1,7 @@ local files = require 'files' local lang = require 'language' +local vm = require 'vm.vm' +local await = require 'await' local function getFieldEventName(doc) if not doc.extends then @@ -28,6 +30,7 @@ local function getFieldEventName(doc) return nil end +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -45,7 +48,13 @@ return function (uri, callback) mark = {} elseif doc.type == 'doc.field' then if mark then - local name = ('%q'):format(doc.field[1]) + await.delay() + local name + if doc.field.type == 'doc.type' then + name = ('[%s]'):format(vm.getInfer(doc.field):view(uri)) + else + name = ('%q'):format(doc.field[1]) + end local eventName = getFieldEventName(doc) if eventName then name = name .. '|' .. eventName diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua index 5097ab3a..dfd9bd4b 100644 --- a/script/core/diagnostics/duplicate-index.lua +++ b/script/core/diagnostics/duplicate-index.lua @@ -2,14 +2,17 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' +local await = require 'await' +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end - + ---@async guide.eachSourceType(ast.ast, 'table', function (source) + await.delay() local mark = {} for _, obj in ipairs(source) do if obj.type == 'tablefield' diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua index 8052c420..ce67ab46 100644 --- a/script/core/diagnostics/duplicate-set-field.lua +++ b/script/core/diagnostics/duplicate-set-field.lua @@ -3,17 +3,21 @@ local lang = require 'language' local define = require 'proto.define' local guide = require 'parser.guide' local vm = require 'vm' +local await = require 'await' +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end + ---@async guide.eachSourceType(ast.ast, 'local', function (source) if not source.ref then return end + await.delay() local sets = {} for _, ref in ipairs(source.ref) do if ref.type ~= 'getlocal' then @@ -48,10 +52,12 @@ return function (uri, callback) local blocks = {} for _, value in ipairs(values) do local block = guide.getBlock(value) - if not blocks[block] then - blocks[block] = {} + if block then + if not blocks[block] then + blocks[block] = {} + end + blocks[block][#blocks[block]+1] = value end - blocks[block][#blocks[block]+1] = value end for _, defs in pairs(blocks) do if #defs <= 1 then diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua index fc205d7e..e05b6aef 100644 --- a/script/core/diagnostics/empty-block.lua +++ b/script/core/diagnostics/empty-block.lua @@ -2,15 +2,18 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' +local await = require 'await' --- 检查空代码块 +-- 检查空代码块 -- 但是排除忙等待(repeat/while) +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end + await.delay() guide.eachSourceType(ast.ast, 'if', function (source) for _, block in ipairs(source) do if #block > 0 then @@ -24,6 +27,7 @@ return function (uri, callback) message = lang.script.DIAG_EMPTY_BLOCK, } end) + await.delay() guide.eachSourceType(ast.ast, 'loop', function (source) if #source > 0 then return @@ -35,6 +39,7 @@ return function (uri, callback) message = lang.script.DIAG_EMPTY_BLOCK, } end) + await.delay() guide.eachSourceType(ast.ast, 'in', function (source) if #source > 0 then return diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index 334fd81a..e154080c 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -2,65 +2,35 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' --- TODO: 检查路径是否可达 -local function mayRun(path) - return true -end - return function (uri, callback) - local ast = files.getState(uri) - if not ast then - return - end - local root = guide.getRoot(ast.ast) - local env = guide.getENV(root) - - local nilDefs = {} - if not env or not env.ref then - return - end - for _, ref in ipairs(env.ref) do - if ref.type == 'setlocal' then - if ref.value and ref.value.type == 'nil' then - nilDefs[#nilDefs+1] = ref - end - end - end - - if #nilDefs == 0 then + local state = files.getState(uri) + if not state then return end local function check(source) local node = source.node if node.tag == '_ENV' then - local ok - for _, nilDef in ipairs(nilDefs) do - local mode, pathA = guide.getPath(nilDef, source) - if mode == 'before' - and mayRun(pathA) then - ok = nilDef - break - end - end - if ok then - callback { - start = source.start, - finish = source.finish, - uri = uri, - message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, - related = { - { - start = ok.start, - finish = ok.finish, - uri = uri, - } + return + end + + if not node.value or node.value.type == 'nil' then + callback { + start = source.start, + finish = source.finish, + uri = uri, + message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, + related = { + { + start = node.start, + finish = node.finish, + uri = uri, } } - end + } end end - guide.eachSourceType(ast.ast, 'getglobal', check) - guide.eachSourceType(ast.ast, 'setglobal', check) + guide.eachSourceType(state.ast, 'getglobal', check) + guide.eachSourceType(state.ast, 'setglobal', check) end diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua index b4ae3715..c33de6ce 100644 --- a/script/core/diagnostics/init.lua +++ b/script/core/diagnostics/init.lua @@ -3,14 +3,22 @@ local define = require 'proto.define' local config = require 'config' local await = require 'await' local vm = require "vm.vm" +local util = require 'utility' +local diagd = require 'proto.diagnostic' -- 把耗时最长的诊断放到最后面 local diagSort = { - ['redundant-value'] = 96, - ['not-yieldable'] = 97, - ['deprecated'] = 98, - ['undefined-field'] = 99, - ['redundant-parameter'] = 100, + ['redundant-value'] = 100, + ['not-yieldable'] = 100, + ['deprecated'] = 100, + ['undefined-field'] = 110, + ['redundant-parameter'] = 110, + ['cast-local-type'] = 120, + ['assign-type-mismatch'] = 120, + ['param-type-mismatch'] = 120, + ['missing-return'] = 120, + ['missing-return-value'] = 120, + ['redundant-return-value'] = 120, } local diagList = {} @@ -46,30 +54,86 @@ local function checkSleep(uri, passed) sleepRest = sleepRest - sleeped end +---@param uri uri +---@param name string +---@return string +local function getSeverity(uri, name) + local severity = config.get(uri, 'Lua.diagnostics.severity')[name] + or define.DiagnosticDefaultSeverity[name] + if severity:sub(-1) == '!' then + return severity:sub(1, -2) + end + local groupSeverity = config.get(uri, 'Lua.diagnostics.groupSeverity') + local groups = diagd.getGroups(name) + local groupLevel = 999 + for _, groupName in ipairs(groups) do + local gseverity = groupSeverity[groupName] + if gseverity and gseverity ~= 'Fallback' then + groupLevel = math.min(groupLevel, define.DiagnosticSeverity[gseverity]) + end + end + if groupLevel == 999 then + return severity + end + for severityName, level in pairs(define.DiagnosticSeverity) do + if level == groupLevel then + return severityName + end + end + return severity +end + +---@param uri uri +---@param name string +---@return string +local function getStatus(uri, name) + local status = config.get(uri, 'Lua.diagnostics.neededFileStatus')[name] + or define.DiagnosticDefaultNeededFileStatus[name] + if status:sub(-1) == '!' then + return status:sub(1, -2) + end + local groupStatus = config.get(uri, 'Lua.diagnostics.groupFileStatus') + local groups = diagd.getGroups(name) + local groupLevel = 0 + for _, groupName in ipairs(groups) do + local gstatus = groupStatus[groupName] + if gstatus and gstatus ~= 'Fallback' then + groupLevel = math.max(groupLevel, define.DiagnosticFileStatus[gstatus]) + end + end + if groupLevel == 0 then + return status + end + for statusName, level in pairs(define.DiagnosticFileStatus) do + if level == groupLevel then + return statusName + end + end + return status +end + ---@async ---@param uri uri ---@param name string ---@param isScopeDiag boolean ---@param response async fun(result: any) local function check(uri, name, isScopeDiag, response) - if config.get(uri, 'Lua.diagnostics.disable')[name] then + local disables = config.get(uri, 'Lua.diagnostics.disable') + if util.arrayHas(disables, name) then return end - local level = config.get(uri, 'Lua.diagnostics.severity')[name] - or define.DiagnosticDefaultSeverity[name] - - local neededFileStatus = config.get(uri, 'Lua.diagnostics.neededFileStatus')[name] - or define.DiagnosticDefaultNeededFileStatus[name] + local severity = getSeverity(uri, name) + local status = getStatus(uri, name) - if neededFileStatus == 'None' then + if status == 'None' then return end - if neededFileStatus == 'Opened' and not files.isOpen(uri) then + if status == 'Opened' and not files.isOpen(uri) then return end - local severity = define.DiagnosticSeverity[level] + local level = define.DiagnosticSeverity[severity] local clock = os.clock() local mark = {} ---@async @@ -85,7 +149,7 @@ local function check(uri, name, isScopeDiag, response) end mark[result.start] = true - result.level = severity or result.level + result.level = level or result.level result.code = name response(result) end, name) diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua index d03e8c70..68bec234 100644 --- a/script/core/diagnostics/lowercase-global.lua +++ b/script/core/diagnostics/lowercase-global.lua @@ -3,6 +3,7 @@ local guide = require 'parser.guide' local lang = require 'language' local config = require 'config' local vm = require 'vm' +local util = require 'utility' local function isDocClass(source) if not source.bindDocs then @@ -23,10 +24,7 @@ return function (uri, callback) return end - local definedGlobal = {} - for name in pairs(config.get(uri, 'Lua.diagnostics.globals')) do - definedGlobal[name] = true - end + local definedGlobal = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) guide.eachSourceType(ast.ast, 'setglobal', function (source) local name = guide.getKeyName(source) diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua index 698680ca..78b94a09 100644 --- a/script/core/diagnostics/missing-parameter.lua +++ b/script/core/diagnostics/missing-parameter.lua @@ -2,68 +2,27 @@ local files = require 'files' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' +local await = require 'await' -local function countCallArgs(source) - local result = 0 - if not source.args then - return 0 - end - result = result + #source.args - return result -end - ----@return integer -local function countFuncArgs(source) - if not source.args or #source.args == 0 then - return 0 - end - local count = 0 - for i = #source.args, 1, -1 do - local arg = source.args[i] - if arg.type ~= '...' - and not (arg.name and arg.name[1] =='...') - and not vm.compileNode(arg):isNullable() then - return i - end - end - return count -end - -local function getFuncArgs(func) - local funcArgs - local defs = vm.getDefs(func) - for _, def in ipairs(defs) do - if def.type == 'function' - or def.type == 'doc.type.function' then - local args = countFuncArgs(def) - if not funcArgs or args < funcArgs then - funcArgs = args - end - end - end - return funcArgs -end - +---@async return function (uri, callback) local state = files.getState(uri) if not state then return end + ---@async guide.eachSourceType(state.ast, 'call', function (source) - local callArgs = countCallArgs(source) + await.delay() + local _, callArgs = vm.countList(source.args) - local func = source.node - local funcArgs = getFuncArgs(func) + local funcNode = vm.compileNode(source.node) + local funcArgs = vm.countParamsOfNode(funcNode) - if not funcArgs then + if callArgs >= funcArgs then return end - local delta = callArgs - funcArgs - if delta >= 0 then - return - end callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/missing-return-value.lua b/script/core/diagnostics/missing-return-value.lua new file mode 100644 index 00000000..2156d66c --- /dev/null +++ b/script/core/diagnostics/missing-return-value.lua @@ -0,0 +1,66 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' + +local function hasDocReturn(func) + if not func.bindDocs then + return false + end + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + return true + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + await.delay() + if not hasDocReturn(source) then + return + end + local min = vm.countReturnsOfFunction(source) + if min == 0 then + return + end + local returns = source.returns + if not returns then + return + end + for _, ret in ipairs(returns) do + local rmin, rmax = vm.countList(ret) + if rmax < min then + if rmin == rmax then + callback { + start = ret.start, + finish = ret.start + #'return', + message = lang.script('DIAG_MISSING_RETURN_VALUE', { + min = min, + rmax = rmax, + }), + } + else + callback { + start = ret.start, + finish = ret.start + #'return', + message = lang.script('DIAG_MISSING_RETURN_VALUE_RANGE', { + min = min, + rmin = rmin, + rmax = rmax, + }), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/missing-return.lua b/script/core/diagnostics/missing-return.lua new file mode 100644 index 00000000..e3539ac0 --- /dev/null +++ b/script/core/diagnostics/missing-return.lua @@ -0,0 +1,86 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' + +---@param block parser.object +---@return boolean +local function hasReturn(block) + if block.hasReturn or block.hasError then + return true + end + if block.type == 'if' then + local hasElse + for _, subBlock in ipairs(block) do + if not hasReturn(subBlock) then + return false + end + if subBlock.type == 'elseblock' then + hasElse = true + end + end + return hasElse == true + else + if block.type == 'while' then + if vm.testCondition(block.filter) then + return true + end + end + for _, action in ipairs(block) do + if guide.isBlockType(action) then + if hasReturn(action) then + return true + end + end + end + end + return false +end + +---@param func parser.object +---@return boolean +local function isEmptyFunction(func) + if #func > 0 then + return false + end + local startRow = guide.rowColOf(func.start) + local finishRow = guide.rowColOf(func.finish) + return finishRow - startRow <= 1 +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + -- check declare only + if isEmptyFunction(source) then + return + end + await.delay() + if vm.countReturnsOfFunction(source, true) == 0 then + return + end + if hasReturn(source) then + return + end + local lastAction = source[#source] + local pos + if lastAction then + pos = lastAction.range or lastAction.finish + else + local row = guide.rowColOf(source.finish) + pos = guide.positionOf(row - 1, 0) + end + callback { + start = pos, + finish = pos, + message = lang.script('DIAG_MISSING_RETURN'), + } + end) +end diff --git a/script/core/diagnostics/need-check-nil.lua b/script/core/diagnostics/need-check-nil.lua index 98fdfd08..9c86939a 100644 --- a/script/core/diagnostics/need-check-nil.lua +++ b/script/core/diagnostics/need-check-nil.lua @@ -2,14 +2,18 @@ local files = require 'files' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then return end + ---@async guide.eachSourceType(state.ast, 'getlocal', function (src) + await.delay() local checkNil local nxt = src.next if nxt then @@ -24,11 +28,15 @@ return function (uri, callback) if call and call.type == 'call' and call.node == src then checkNil = true end + local setIndex = src.parent + if setIndex and setIndex.type == 'setindex' and setIndex.index == src then + checkNil = true + end if not checkNil then return end local node = vm.compileNode(src) - if node:hasFalsy() then + if node:hasFalsy() and not vm.getInfer(src):hasType(uri, 'any') then callback { start = src.start, finish = src.finish, diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua index 669ed2bb..bd114959 100644 --- a/script/core/diagnostics/newfield-call.lua +++ b/script/core/diagnostics/newfield-call.lua @@ -1,16 +1,20 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' +local await = require 'await' +local sub = require 'core.substring' +---@async return function (uri, callback) - local ast = files.getState(uri) - if not ast then + local state = files.getState(uri) + local text = files.getText(uri) + if not state or not text then return end - local text = files.getText(uri) - - guide.eachSourceType(ast.ast, 'table', function (source) + ---@async + guide.eachSourceType(state.ast, 'table', function (source) + await.delay() for i = 1, #source do local field = source[i] if field.type ~= 'tableexp' then @@ -33,8 +37,8 @@ return function (uri, callback) start = call.start, finish = call.finish, message = lang.script('DIAG_PREFIELD_CALL' - , text:sub(func.start, func.finish) - , text:sub(args.start, args.finish) + , sub(state)(func.start + 1, func.finish) + , sub(state)(args.start + 1, args.finish) ) } end diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua index 3f2d5ca5..2ba2ce03 100644 --- a/script/core/diagnostics/newline-call.lua +++ b/script/core/diagnostics/newline-call.lua @@ -1,14 +1,18 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' +local await = require 'await' +local sub = require 'core.substring' +---@async return function (uri, callback) local state = files.getState(uri) local text = files.getText(uri) - if not state then + if not state or not text then return end + ---@async guide.eachSourceType(state.ast, 'call', function (source) local node = source.node local args = source.args @@ -20,6 +24,9 @@ return function (uri, callback) if not source.next then return end + + await.delay() + local startOffset = guide.positionToOffset(state, args.start) + 1 local finishOffset = guide.positionToOffset(state, args.finish) if text:sub(startOffset, startOffset) ~= '(' @@ -38,8 +45,8 @@ return function (uri, callback) start = node.start, finish = args.finish, message = lang.script('DIAG_PREVIOUS_CALL' - , text:sub(node.start, node.finish) - , text:sub(args.start, args.finish) + , sub(state)(node.start + 1, node.finish) + , sub(state)(args.start + 1, args.finish) ), } end diff --git a/script/core/diagnostics/no-unknown.lua b/script/core/diagnostics/no-unknown.lua index 48aab5da..e706931a 100644 --- a/script/core/diagnostics/no-unknown.lua +++ b/script/core/diagnostics/no-unknown.lua @@ -2,25 +2,30 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local vm = require 'vm' +local await = require 'await' +local types = { + 'local', + 'setlocal', + 'setglobal', + 'getglobal', + 'setfield', + 'setindex', + 'tablefield', + 'tableindex', +} + +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end - guide.eachSource(ast.ast, function (source) - if source.type ~= 'local' - and source.type ~= 'setlocal' - and source.type ~= 'setglobal' - and source.type ~= 'getglobal' - and source.type ~= 'setfield' - and source.type ~= 'setindex' - and source.type ~= 'tablefield' - and source.type ~= 'tableindex' then - return - end - if vm.getInfer(source):view() == 'unknown' then + ---@async + guide.eachSourceTypes(ast.ast, types, function (source) + await.delay() + if vm.getInfer(source):view(uri) == 'unknown' then callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/not-yieldable.lua b/script/core/diagnostics/not-yieldable.lua index a1c84276..055025d4 100644 --- a/script/core/diagnostics/not-yieldable.lua +++ b/script/core/diagnostics/not-yieldable.lua @@ -11,7 +11,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if vm.getInfer(arg):hasType 'any' + if vm.getInfer(arg):hasType(guide.getUri(def), 'any') or vm.isAsync(arg, true) or arg.type == '...' then return true @@ -22,7 +22,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if vm.getInfer(arg.extends):hasType 'any' + if vm.getInfer(arg.extends):hasType(guide.getUri(def), 'any') or vm.isAsync(arg.extends, true) then return true end diff --git a/script/core/diagnostics/param-type-mismatch.lua b/script/core/diagnostics/param-type-mismatch.lua new file mode 100644 index 00000000..6f34f579 --- /dev/null +++ b/script/core/diagnostics/param-type-mismatch.lua @@ -0,0 +1,72 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@param funcNode vm.node + ---@param i integer + ---@return vm.node? + local function getDefNode(funcNode, i) + local defNode = vm.createNode() + for f in funcNode:eachObject() do + if f.type == 'function' + or f.type == 'doc.type.function' then + local param = f.args and f.args[i] + if param then + defNode:merge(vm.compileNode(param)) + if param[1] == '...' then + defNode:addOptional() + end + end + end + end + if defNode:isEmpty() then + return nil + end + return defNode + end + + ---@async + guide.eachSourceType(state.ast, 'call', function (source) + if not source.args then + return + end + await.delay() + local funcNode = vm.compileNode(source.node) + for i, arg in ipairs(source.args) do + if i == 1 and source.node.type == 'getmethod' then + goto CONTINUE + end + local refNode = vm.compileNode(arg) + local defNode = getDefNode(funcNode, i) + if not defNode then + goto CONTINUE + end + if arg.type == 'getfield' + or arg.type == 'getindex' then + -- 由于无法对字段进行类型收窄, + -- 因此将假值移除再进行检查 + refNode = refNode:copy():setTruthy() + end + if not vm.canCastType(uri, defNode, refNode) then + callback { + start = arg.start, + finish = arg.finish, + message = lang.script('DIAG_PARAM_TYPE_MISMATCH', { + def = vm.getInfer(defNode):view(uri), + ref = vm.getInfer(refNode):view(uri), + }) + } + end + ::CONTINUE:: + end + end) +end diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua index 2157ae71..1fb3ca6b 100644 --- a/script/core/diagnostics/redefined-local.lua +++ b/script/core/diagnostics/redefined-local.lua @@ -1,18 +1,23 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' +local await = require 'await' +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end + + ---@async guide.eachSourceType(ast.ast, 'local', function (source) local name = source[1] if name == '_' or name == ast.ENVMode then return end + await.delay() local exist = guide.getLocal(source, name, source.start-1) if exist then callback { diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index 41781df8..9898d9bd 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -2,73 +2,48 @@ local files = require 'files' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' +local await = require 'await' -local function countCallArgs(source) - local result = 0 - if not source.args then - return 0 - end - result = result + #source.args - return result -end - -local function countFuncArgs(source) - if not source.args or #source.args == 0 then - return 0 - end - local lastArg = source.args[#source.args] - if lastArg.type == '...' - or (lastArg.name and lastArg.name[1] == '...') then - return math.maxinteger - else - return #source.args - end -end - -local function getFuncArgs(func) - local funcArgs - local defs = vm.getDefs(func) - for _, def in ipairs(defs) do - if def.type == 'function' - or def.type == 'doc.type.function' then - local args = countFuncArgs(def) - if not funcArgs or args > funcArgs then - funcArgs = args - end - end - end - return funcArgs -end - +---@async return function (uri, callback) local state = files.getState(uri) if not state then return end + ---@async guide.eachSourceType(state.ast, 'call', function (source) - local callArgs = countCallArgs(source) + await.delay() + local callArgs = vm.countList(source.args) if callArgs == 0 then return end - local func = source.node - local funcArgs = getFuncArgs(func) + local funcNode = vm.compileNode(source.node) + local _, funcArgs = vm.countParamsOfNode(funcNode) - if not funcArgs then - return - end - - local delta = callArgs - funcArgs - if delta <= 0 then + if callArgs <= funcArgs then return end if callArgs == 1 and source.node.type == 'getmethod' then return end - for i = #source.args - delta + 1, #source.args do - local arg = source.args[i] - if arg then + if funcArgs + 1 > #source.args then + local lastArg = source.args[#source.args] + if lastArg.type == 'call' and funcArgs > 0 then + -- 如果函数接收至少一个参数,那么调用方最后一个参数是函数调用 + -- 导致的参数数量太多可以忽略。 + -- 如果函数不接收任何参数,那么任何参数都是错误的。 + return + end + callback { + start = lastArg.start, + finish = lastArg.finish, + message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs) + } + else + for i = funcArgs + 1, #source.args do + local arg = source.args[i] callback { start = arg.start, finish = arg.finish, diff --git a/script/core/diagnostics/redundant-return-value.lua b/script/core/diagnostics/redundant-return-value.lua new file mode 100644 index 00000000..36432f98 --- /dev/null +++ b/script/core/diagnostics/redundant-return-value.lua @@ -0,0 +1,73 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' + +local function hasDocReturn(func) + if not func.bindDocs then + return false + end + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + return true + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + await.delay() + if not hasDocReturn(source) then + return + end + local _, max = vm.countReturnsOfFunction(source) + local returns = source.returns + if not returns then + return + end + for _, ret in ipairs(returns) do + local rmin, rmax = vm.countList(ret) + if rmin > max then + for i = max + 1, #ret - 1 do + callback { + start = ret[i].start, + finish = ret[i].finish, + message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', { + max = max, + rmax = i, + }), + } + end + if #ret == rmax then + callback { + start = ret[#ret].start, + finish = ret[#ret].finish, + message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', { + max = max, + rmax = rmax, + }), + } + else + callback { + start = ret[#ret].start, + finish = ret[#ret].finish, + message = lang.script('DIAG_REDUNDANT_RETURN_VALUE_RANGE', { + max = max, + rmin = #ret, + rmax = rmax, + }), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/return-type-mismatch.lua b/script/core/diagnostics/return-type-mismatch.lua new file mode 100644 index 00000000..cce4aad8 --- /dev/null +++ b/script/core/diagnostics/return-type-mismatch.lua @@ -0,0 +1,76 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +---@param func parser.object +---@return vm.node[]? +local function getDocReturns(func) + if not func.bindDocs then + return nil + end + local returns = {} + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + for _, ret in ipairs(doc.returns) do + returns[ret.returnIndex] = vm.compileNode(ret) + end + end + end + if #returns == 0 then + return nil + end + return returns +end +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@param docReturns vm.node[] + ---@param rets parser.object + local function checkReturn(docReturns, rets) + for i, docRet in ipairs(docReturns) do + local retNode, exp = vm.selectNode(rets, i) + if not exp then + break + end + if retNode:hasName 'nil' then + if exp.type == 'getfield' + or exp.type == 'getindex' then + retNode = retNode:copy():removeOptional() + end + end + if not vm.canCastType(uri, docRet, retNode) then + callback { + start = exp.start, + finish = exp.finish, + message = lang.script('DIAG_RETURN_TYPE_MISMATCH', { + def = vm.getInfer(docRet):view(uri), + ref = vm.getInfer(retNode):view(uri), + index = i, + }), + } + end + end + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + if not source.returns then + return + end + await.delay() + local docReturns = getDocReturns(source) + if not docReturns then + return + end + for _, ret in ipairs(source.returns) do + checkReturn(docReturns, ret) + await.delay() + end + end) +end diff --git a/script/core/diagnostics/spell-check.lua b/script/core/diagnostics/spell-check.lua new file mode 100644 index 00000000..7369a235 --- /dev/null +++ b/script/core/diagnostics/spell-check.lua @@ -0,0 +1,34 @@ +local files = require 'files' +local converter = require 'proto.converter' +local log = require 'log' +local spell = require 'provider.spell' + + +---@async +return function(uri, callback) + local text = files.getOriginText(uri) + if not text then + return + end + + local status, diagnosticInfos = spell.spellCheck(uri, text) + + if not status then + if diagnosticInfos ~= nil then + log.error(diagnosticInfos) + end + + return + end + + if diagnosticInfos then + for _, diagnosticInfo in ipairs(diagnosticInfos) do + callback { + start = converter.unpackPosition(uri, diagnosticInfo.range.start), + finish = converter.unpackPosition(uri, diagnosticInfo.range["end"]), + message = diagnosticInfo.message, + data = diagnosticInfo.data + } + end + end +end diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua index cc51cf77..2e0398b2 100644 --- a/script/core/diagnostics/trailing-space.lua +++ b/script/core/diagnostics/trailing-space.lua @@ -1,25 +1,18 @@ local files = require 'files' local lang = require 'language' local guide = require 'parser.guide' +local await = require 'await' -local function isInString(ast, offset) - local result = false - guide.eachSourceType(ast, 'string', function (source) - if offset >= source.start and offset <= source.finish then - result = true - end - end) - return result -end - +---@async return function (uri, callback) local state = files.getState(uri) - if not state then + local text = files.getText(uri) + if not state or not text then return end - local text = files.getText(uri) local lines = state.lines for i = 0, #lines do + await.delay() local startOffset = lines[i] local finishOffset = text:find('[\r\n]', startOffset) or (#text + 1) local lastOffset = finishOffset - 1 @@ -28,7 +21,8 @@ return function (uri, callback) goto NEXT_LINE end local lastPos = guide.offsetToPosition(state, lastOffset) - if isInString(state.ast, lastPos) then + if guide.isInString(state.ast, lastPos) + or guide.isInComment(state.ast, lastPos) then goto NEXT_LINE end local firstOffset = startOffset diff --git a/script/core/diagnostics/type-check.lua b/script/core/diagnostics/type-check.lua deleted file mode 100644 index cc2b3228..00000000 --- a/script/core/diagnostics/type-check.lua +++ /dev/null @@ -1,3 +0,0 @@ ----@async -return function(uri, callback) -end diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua index df71f0c9..c21ca993 100644 --- a/script/core/diagnostics/unbalanced-assignments.lua +++ b/script/core/diagnostics/unbalanced-assignments.lua @@ -2,7 +2,17 @@ local files = require 'files' local define = require 'proto.define' local lang = require 'language' local guide = require 'parser.guide' +local await = require 'await' +local types = { + 'local', + 'setlocal', + 'setglobal', + 'setfield', + 'setindex' , +} + +---@async return function (uri, callback, code) local ast = files.getState(uri) if not ast then @@ -31,13 +41,9 @@ return function (uri, callback, code) end end - guide.eachSource(ast.ast, function (source) - if source.type == 'local' - or source.type == 'setlocal' - or source.type == 'setglobal' - or source.type == 'setfield' - or source.type == 'setindex' then - checkSet(source) - end + ---@async + guide.eachSourceTypes(ast.ast, types, function (source) + await.delay() + checkSet(source) end) end diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index 69edb380..bacd4288 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -32,7 +32,7 @@ return function (uri, callback) return end local name = source[1] - if name == '...' then + if name == '...' or name == '_' then return end if #vm.getDocSets(uri, name) > 0 diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua index 98919284..7a60a74f 100644 --- a/script/core/diagnostics/undefined-doc-param.lua +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -1,21 +1,6 @@ local files = require 'files' local lang = require 'language' -local function hasParamName(func, name) - if not func.args then - return false - end - for _, arg in ipairs(func.args) do - if arg[1] == name then - return true - end - if arg.type == '...' and name == '...' then - return true - end - end - return false -end - return function (uri, callback) local state = files.getState(uri) if not state then @@ -27,26 +12,13 @@ return function (uri, callback) end for _, doc in ipairs(state.ast.docs) do - if doc.type ~= 'doc.param' then - goto CONTINUE - end - local binds = doc.bindSources - if not binds then - goto CONTINUE - end - local param = doc.param - local name = param[1] - for _, source in ipairs(binds) do - if source.type == 'function' then - if not hasParamName(source, name) then - callback { - start = param.start, - finish = param.finish, - message = lang.script('DIAG_UNDEFINED_DOC_PARAM', name) - } - end - end + if doc.type == 'doc.param' + and not doc.bindSource then + callback { + start = doc.param.start, + finish = doc.param.finish, + message = lang.script('DIAG_UNDEFINED_DOC_PARAM', doc.param[1]) + } end - ::CONTINUE:: end end diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua index 2f559697..1dff575b 100644 --- a/script/core/diagnostics/undefined-env-child.lua +++ b/script/core/diagnostics/undefined-env-child.lua @@ -3,20 +3,40 @@ local guide = require 'parser.guide' local lang = require 'language' local vm = require "vm.vm" +---@param source parser.object +---@return boolean +local function isBindDoc(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' + or doc.type == 'doc.class' then + return true + end + end + return false +end + return function (uri, callback) - local ast = files.getState(uri) - if not ast then + local state = files.getState(uri) + if not state then return end - guide.eachSourceType(ast.ast, 'getglobal', function (source) - -- 单独验证自己是否在重载过的 _ENV 中有定义 + + guide.eachSourceType(state.ast, 'getglobal', function (source) if source.node.tag == '_ENV' then return end - local defs = vm.getDefs(source) - if #defs > 0 then + + if not isBindDoc(source.node) then return end + + if #vm.getDefs(source) > 0 then + return + end + local key = source[1] callback { start = source.start, diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index 41fcda48..a83241f5 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -34,11 +34,11 @@ return function (uri, callback) local node = src.node if node then local ok - for view in vm.getInfer(node):eachView() do - if not skipCheckClass[view] then - ok = true - break + for view in vm.getInfer(node):eachView(uri) do + if skipCheckClass[view] then + return end + ok = true end if not ok then return diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index bd0aae69..179c9204 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -4,6 +4,7 @@ local lang = require 'language' local config = require 'config' local guide = require 'parser.guide' local await = require 'await' +local util = require 'utility' local requireLike = { ['include'] = true, @@ -14,17 +15,17 @@ local requireLike = { ---@async return function (uri, callback) - local ast = files.getState(uri) - if not ast then + local state = files.getState(uri) + if not state then return end - local dglobals = config.get(uri, 'Lua.diagnostics.globals') + local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) local rspecial = config.get(uri, 'Lua.runtime.special') local cache = {} -- 遍历全局变量,检查所有没有 set 模式的全局变量 - guide.eachSourceType(ast.ast, 'getglobal', function (src) ---@async + guide.eachSourceType(state.ast, 'getglobal', function (src) ---@async local key = src[1] if not key then return @@ -40,6 +41,7 @@ return function (uri, callback) return end if cache[key] == nil then + await.delay() cache[key] = vm.hasGlobalSets(uri, 'variable', key) end if cache[key] then diff --git a/script/core/diagnostics/unknown-cast-variable.lua b/script/core/diagnostics/unknown-cast-variable.lua new file mode 100644 index 00000000..3f082a50 --- /dev/null +++ b/script/core/diagnostics/unknown-cast-variable.lua @@ -0,0 +1,32 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.cast' and doc.loc then + await.delay() + local defs = vm.getDefs(doc.loc) + local loc = defs[1] + if not loc then + callback { + start = doc.loc.start, + finish = doc.loc.finish, + message = lang.script('DIAG_UNKNOWN_CAST_VARIABLE', doc.loc[1]) + } + end + end + end +end diff --git a/script/core/diagnostics/unknown-diag-code.lua b/script/core/diagnostics/unknown-diag-code.lua index 9e492a29..07128a27 100644 --- a/script/core/diagnostics/unknown-diag-code.lua +++ b/script/core/diagnostics/unknown-diag-code.lua @@ -1,6 +1,6 @@ local files = require 'files' local lang = require 'language' -local define = require 'proto.define' +local diag = require 'proto.diagnostic' return function (uri, callback) local state = files.getState(uri) @@ -17,7 +17,7 @@ return function (uri, callback) if doc.names then for _, nameUnit in ipairs(doc.names) do local code = nameUnit[1] - if not define.DiagnosticDefaultSeverity[code] then + if not diag.getDiagAndErrNameMap()[code] then callback { start = nameUnit.start, finish = nameUnit.finish, diff --git a/script/core/diagnostics/unknown-operator.lua b/script/core/diagnostics/unknown-operator.lua new file mode 100644 index 00000000..7404b5ef --- /dev/null +++ b/script/core/diagnostics/unknown-operator.lua @@ -0,0 +1,36 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local vm = require 'vm' +local await = require 'await' +local util = require 'utility' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.operator' then + local op = doc.op + if op then + local opName = op[1] + if not vm.OP_BINARY_MAP[opName] + and not vm.OP_UNARY_MAP[opName] + and not vm.OP_OTHER_MAP[opName] then + callback { + start = doc.op.start, + finish = doc.op.finish, + message = lang.script('DIAG_UNKNOWN_OPERATOR', opName) + } + end + end + end + end +end diff --git a/script/core/diagnostics/unreachable-code.lua b/script/core/diagnostics/unreachable-code.lua new file mode 100644 index 00000000..4f0a38b7 --- /dev/null +++ b/script/core/diagnostics/unreachable-code.lua @@ -0,0 +1,84 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' +local define = require 'proto.define' + +---@param source parser.object +---@return boolean +local function allLiteral(source) + local result = true + guide.eachSource(source, function (src) + if src.type ~= 'unary' + and src.type ~= 'binary' + and not guide.isLiteral(src) then + result = false + return false + end + end) + return result +end + +---@param block parser.object +---@return boolean +local function hasReturn(block) + if block.hasReturn or block.hasError then + return true + end + if block.type == 'if' then + local hasElse + for _, subBlock in ipairs(block) do + if not hasReturn(subBlock) then + return false + end + if subBlock.type == 'elseblock' then + hasElse = true + end + end + return hasElse == true + else + if block.type == 'while' then + if vm.testCondition(block.filter) + and not block.breaks + and allLiteral(block.filter) then + return true + end + end + for _, action in ipairs(block) do + if guide.isBlockType(action) then + if hasReturn(action) then + return true + end + end + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceTypes(state.ast, {'main', 'function'}, function (source) + await.delay() + for i, action in ipairs(source) do + if guide.isBlockType(action) + and hasReturn(action) then + if i < #source then + callback { + start = source[i+1].start, + finish = source[#source].finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNREACHABLE_CODE'), + } + end + return + end + end + end) +end diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua index 813ac804..a873375f 100644 --- a/script/core/diagnostics/unused-function.lua +++ b/script/core/diagnostics/unused-function.lua @@ -18,7 +18,8 @@ local function isToBeClosed(source) return false end ----@param source parser.object +---@param source parser.object? +---@return boolean local function isValidFunction(source) if not source then return false @@ -55,7 +56,7 @@ local function collect(ast, white, roots, links) for _, ref in ipairs(loc.ref or {}) do if ref.type == 'getlocal' then local func = guide.getParentFunction(ref) - if not isValidFunction(func) or roots[func] then + if not func or not isValidFunction(func) or roots[func] then roots[src] = true return end diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua index d12ceb2b..8f2ee217 100644 --- a/script/core/diagnostics/unused-local.lua +++ b/script/core/diagnostics/unused-local.lua @@ -3,6 +3,8 @@ local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' local vm = require 'vm.vm' +local config = require 'config.config' +local glob = require 'glob' local function hasGet(loc) if not loc.ref then @@ -63,18 +65,24 @@ local function isDocClass(source) return false end -local function isDocParam(source) - if not source.bindDocs then +---@param func parser.object +---@return boolean +local function isEmptyFunction(func) + if #func > 0 then return false end - for _, doc in ipairs(source.bindDocs) do - if doc.type == 'doc.param' then - if doc.param[1] == source[1] then - return true - end - end + local startRow = guide.rowColOf(func.start) + local finishRow = guide.rowColOf(func.finish) + return finishRow - startRow <= 1 +end + +---@param source parser.object +local function isDeclareFunctionParam(source) + if source.parent.type ~= 'funcargs' then + return false end - return false + local func = source.parent.parent + return isEmptyFunction(func) end return function (uri, callback) @@ -82,19 +90,24 @@ return function (uri, callback) if not ast then return end + local ignorePatterns = config.get(uri, 'Lua.diagnostics.unusedLocalExclude') + local ignore = glob.glob(ignorePatterns) guide.eachSourceType(ast.ast, 'local', function (source) local name = source[1] if name == '_' or name == ast.ENVMode then return end + if ignore(name) then + return + end if isToBeClosed(source) then return end if isDocClass(source) then return end - if vm.isMetaFile(uri) and isDocParam(source) then + if isDeclareFunctionParam(source) then return end local data = hasGet(source) diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua index ce033cf3..08f12c4d 100644 --- a/script/core/diagnostics/unused-vararg.lua +++ b/script/core/diagnostics/unused-vararg.lua @@ -15,6 +15,9 @@ return function (uri, callback) end guide.eachSourceType(ast.ast, 'function', function (source) + if #source == 0 then + return + end local args = source.args if not args then return diff --git a/script/core/find-source.lua b/script/core/find-source.lua index 26a411e5..99013b31 100644 --- a/script/core/find-source.lua +++ b/script/core/find-source.lua @@ -21,7 +21,7 @@ return function (ast, position, accept) end end local start, finish = guide.getStartFinish(source) - if finish - start < len and accept[source.type] then + if finish - start <= len and accept[source.type] then result = source len = finish - start end diff --git a/script/core/folding.lua b/script/core/folding.lua index 4f93aed9..7f59636e 100644 --- a/script/core/folding.lua +++ b/script/core/folding.lua @@ -66,7 +66,8 @@ local care = { ['repeat'] = function (source, text, results) local start = source.start local finish = source.keyword[#source.keyword] - if text:sub(finish - #'until' + 1, finish) ~= 'until' then + -- must end with 'until' + if #source.keyword ~= 4 then return end local folding = { @@ -143,6 +144,15 @@ local care = { } results[#results+1] = folding end, + ['doc.alias'] = function (source, text, results) + local folding = { + start = source.start, + finish = source.bindGroup[#source.bindGroup].finish, + kind = 'comment', + hideLastLine = true, + } + results[#results+1] = folding + end, } ---@async diff --git a/script/core/formatting.lua b/script/core/formatting.lua index b52854a4..fb5ca9c7 100644 --- a/script/core/formatting.lua +++ b/script/core/formatting.lua @@ -4,7 +4,10 @@ local log = require("log") return function(uri, options) local text = files.getOriginText(uri) - local ast = files.getState(uri) + local state = files.getState(uri) + if not state then + return + end local status, formattedText = codeFormat.format(uri, text, options) if not status then @@ -17,8 +20,8 @@ return function(uri, options) return { { - start = ast.ast.start, - finish = ast.ast.finish, + start = state.ast.start, + finish = state.ast.finish, text = formattedText, } } diff --git a/script/core/hint.lua b/script/core/hint.lua index f97cdcec..767e531e 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -5,6 +5,7 @@ local guide = require 'parser.guide' local await = require 'await' local define = require 'proto.define' local lang = require 'language' +local substr = require 'core.substring' ---@async local function typeHint(uri, results, start, finish) @@ -38,7 +39,7 @@ local function typeHint(uri, results, start, finish) end end await.delay() - local view = vm.getInfer(source):view() + local view = vm.getInfer(source):view(uri) if view == 'any' or view == 'unknown' or view == 'nil' then @@ -189,24 +190,44 @@ local function arrayIndex(uri, results, start, finish) end ---@async - guide.eachSourceBetween(state.ast, start, finish, function (source) - if source.type ~= 'tableexp' then + guide.eachSourceType(state.ast, 'table', function (source) + if source.finish < start or source.start > finish then return end await.delay() if option == 'Auto' then - if not isMixedOrLargeTable(source.parent) then + if not isMixedOrLargeTable(source) then return end end - results[#results+1] = { - text = ('[%d]'):format(source.tindex), - offset = source.start, - kind = define.InlayHintKind.Other, - where = 'left', - source = source.parent, - } + local list = {} + local max = 0 + for _, field in ipairs(source) do + if field.type == 'tableexp' + and field.start < finish + and field.finish > start then + list[#list+1] = field + if field.tindex > max then + max = field.tindex + end + end + end + + if #list > 0 then + local length = #tostring(max) + local fmt = '[%0' .. length .. 'd]' + for _, field in ipairs(list) do + results[#results+1] = { + text = fmt:format(field.tindex), + offset = field.start, + kind = define.InlayHintKind.Other, + where = 'left', + source = field.parent, + } + end + end end) + end ---@async @@ -238,6 +259,72 @@ local function awaitHint(uri, results, start, finish) end) end +local blockTypes = { + 'main', + 'function', + 'for', + 'loop', + 'in', + 'do', + 'repeat', + 'while', + 'ifblock', + 'elseifblock', + 'elseblock', +} + +---@async +local function semicolonHint(uri, results, start, finish) + local state = files.getState(uri) + if not state then + return + end + local mode = config.get(uri, 'Lua.hint.semicolon') + if mode == 'Disable' then + return + end + local subber = substr(state) + ---@async + guide.eachSourceTypes(state.ast, blockTypes, function (src) + await.delay() + for i = 1, #src - 1 do + local current = src[i] + local next = src[i+1] + local left = current.finish + local right = next.start + local text = subber(left, right) + if mode == 'All' then + if not text:find '[,;]' then + results[#results+1] = { + text = ';', + offset = left, + kind = define.InlayHintKind.Other, + where = 'right', + } + end + elseif mode == 'SameLine' then + if not text:find '[,;\r\n]' then + results[#results+1] = { + text = ';', + offset = left, + kind = define.InlayHintKind.Other, + where = 'right', + } + end + end + end + if mode == 'All' then + local last = src[#src] + results[#results+1] = { + text = ';', + offset = last.range or last.finish, + kind = define.InlayHintKind.Other, + where = 'right', + } + end + end) +end + ---@async return function (uri, start, finish) local results = {} @@ -245,5 +332,6 @@ return function (uri, start, finish) paramName(uri, results, start, finish) awaitHint(uri, results, start, finish) arrayIndex(uri, results, start, finish) + semicolonHint(uri, results, start, finish) return results end diff --git a/script/core/hover/args.lua b/script/core/hover/args.lua index c485d9b9..bb4d4297 100644 --- a/script/core/hover/args.lua +++ b/script/core/hover/args.lua @@ -9,7 +9,7 @@ local function asFunction(source) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view 'any') + args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view(guide.getUri(source), 'any')) end if source.args then for i = 1, #source.args do @@ -29,15 +29,15 @@ local function asFunction(source) args[#args+1] = ('%s%s: %s'):format( name, optional and '?' or '', - vm.getInfer(argNode):view('any', guide.getUri(source)) + vm.getInfer(argNode):view(guide.getUri(source), 'any') ) elseif arg.type == '...' then - args[#args+1] = ('%s: %s'):format( + args[#args+1] = ('%s%s'):format( '...', - vm.getInfer(arg):view 'any' + vm.getInfer(arg):view(guide.getUri(source), 'any') ) else - args[#args+1] = ('%s'):format(vm.getInfer(arg):view 'any') + args[#args+1] = ('%s'):format(vm.getInfer(arg):view(guide.getUri(source), 'any')) end ::CONTINUE:: end @@ -46,17 +46,17 @@ local function asFunction(source) end local function asDocFunction(source) + local args = {} if not source.args then - return '' + return args end - local args = {} for i = 1, #source.args do local arg = source.args[i] local name = arg.name[1] args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - arg.extends and vm.getInfer(arg.extends):view 'any' or 'any' + arg.extends and vm.getInfer(arg.extends):view(guide.getUri(source), 'any') or 'any' ) end return args diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index e9267c0f..e11dd6c8 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -6,11 +6,12 @@ local lang = require 'language' local util = require 'utility' local guide = require 'parser.guide' local rpath = require 'workspace.require-path' +local furi = require 'file-uri' local function collectRequire(mode, literal, uri) local result, searchers if mode == 'require' then - result, searchers = rpath.findUrisByRequirePath(uri, literal) + result, searchers = rpath.findUrisByRequireName(uri, literal) elseif mode == 'dofile' or mode == 'loadfile' then result = ws.findUrisByFilePath(literal) @@ -82,7 +83,53 @@ local function asString(source) or asStringView(source, literal) end -local function getBindComment(source, docGroup, base) +---@param comment string +---@param suri uri +---@return string? +local function normalizeComment(comment, suri) + if not comment then + return nil + end + if comment:sub(1, 1) == '-' then + comment = comment:sub(2) + end + if comment:sub(1, 1) == '@' then + return nil + end + comment = comment:gsub('(%[.-%]%()(.-)(%))', function (left, path, right) + local scheme = furi.split(path) + if scheme + -- strange way to check `C:/xxx.lua` + and #scheme > 1 then + return + end + local absPath = ws.getAbsolutePath(suri:gsub('/[^/]+$', ''), path) + if not absPath then + return + end + local uri = furi.encode(absPath) + return left .. uri .. right + end) + return comment +end + +local function getBindComment(source) + local uri = guide.getUri(source) + local lines = {} + for _, docComment in ipairs(source.bindComments) do + lines[#lines+1] = normalizeComment(docComment.comment.text, uri) + end + if not lines or #lines == 0 then + return nil + end + return table.concat(lines, '\n') +end + +local function lookUpDocComments(source) + local docGroup = source.bindDocs + if not docGroup then + return + end if source.type == 'setlocal' or source.type == 'getlocal' then source = source.node @@ -90,34 +137,23 @@ local function getBindComment(source, docGroup, base) if source.parent.type == 'funcargs' then return end - local continue - local lines + local uri = guide.getUri(source) + local lines = {} for _, doc in ipairs(docGroup) do if doc.type == 'doc.comment' then - if not continue then - continue = true - lines = {} + lines[#lines+1] = normalizeComment(doc.comment.text, uri) + elseif doc.type == 'doc.type' then + if doc.comment then + lines[#lines+1] = normalizeComment(doc.comment.text, uri) end - if doc.comment.text:sub(1, 1) == '-' then - lines[#lines+1] = doc.comment.text:sub(2) - else - lines[#lines+1] = doc.comment.text - end - elseif doc == base then - break - else - continue = false - if doc.type == 'doc.field' - or doc.type == 'doc.class' then - lines = nil + elseif doc.type == 'doc.class' then + for _, docComment in ipairs(doc.bindComments) do + lines[#lines+1] = normalizeComment(docComment.comment.text, uri) end end end if source.comment then - if not lines then - lines = {} - end - lines[#lines+1] = source.comment.text + lines[#lines+1] = normalizeComment(source.comment.text, uri) end if not lines or #lines == 0 then return nil @@ -128,8 +164,9 @@ end local function tryDocClassComment(source) for _, def in ipairs(vm.getDefs(source)) do if def.type == 'doc.class' - or def.type == 'doc.alias' then - local comment = getBindComment(def, def.bindGroup, def) + or def.type == 'doc.alias' + or def.type == 'doc.enum' then + local comment = getBindComment(def) if comment then return comment end @@ -144,7 +181,7 @@ local function tryDocModule(source) return collectRequire('require', source.module, guide.getUri(source)) end -local function buildEnumChunk(docType, name) +local function buildEnumChunk(docType, name, uri) if not docType then return nil end @@ -152,10 +189,11 @@ local function buildEnumChunk(docType, name) local types = {} local lines = {} for _, tp in ipairs(vm.getDefs(docType)) do - types[#types+1] = vm.getInfer(tp):view() + types[#types+1] = vm.getInfer(tp):view(guide.getUri(docType)) if tp.type == 'doc.type.string' or tp.type == 'doc.type.integer' - or tp.type == 'doc.type.boolean' then + or tp.type == 'doc.type.boolean' + or tp.type == 'doc.type.code' then enums[#enums+1] = tp end local comment = tryDocClassComment(tp) @@ -174,7 +212,7 @@ local function buildEnumChunk(docType, name) (enum.default and '->') or (enum.additional and '+>') or ' |', - vm.viewObject(enum) + vm.viewObject(enum, uri) ) if enum.comment then local first = true @@ -198,26 +236,33 @@ local function getBindEnums(source, docGroup) return end + local uri = guide.getUri(source) local mark = {} local chunks = {} local returnIndex = 0 for _, doc in ipairs(docGroup) do if doc.type == 'doc.param' then local name = doc.param[1] + if name == '...' then + name = '...(param)' + end if mark[name] then goto CONTINUE end mark[name] = true - chunks[#chunks+1] = buildEnumChunk(doc.extends, name) + chunks[#chunks+1] = buildEnumChunk(doc.extends, name, uri) elseif doc.type == 'doc.return' then for _, rtn in ipairs(doc.returns) do returnIndex = returnIndex + 1 local name = rtn.name and rtn.name[1] or ('return #%d'):format(returnIndex) + if name == '...' then + name = '...(return)' + end if mark[name] then goto CONTINUE end mark[name] = true - chunks[#chunks+1] = buildEnumChunk(rtn, name) + chunks[#chunks+1] = buildEnumChunk(rtn, name, uri) end end ::CONTINUE:: @@ -228,37 +273,38 @@ local function getBindEnums(source, docGroup) return table.concat(chunks, '\n\n') end -local function tryDocFieldUpComment(source) - if source.type ~= 'doc.field.name' then +local function tryDocFieldComment(source) + if source.type ~= 'doc.field' then return end - local docField = source.parent - if not docField.bindGroup then - return + if source.comment then + return normalizeComment(source.comment.text, guide.getUri(source)) + end + if source.bindGroup then + return getBindComment(source) end - local comment = getBindComment(docField, docField.bindGroup, docField) - return comment end local function getFunctionComment(source) local docGroup = source.bindDocs + if not docGroup then + return + end local hasReturnComment = false - for _, doc in ipairs(docGroup) do + for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.return' and doc.comment then hasReturnComment = true break end end + local uri = guide.getUri(source) local md = markdown() for _, doc in ipairs(docGroup) do if doc.type == 'doc.comment' then - if doc.comment.text:sub(1, 1) == '-' then - md:add('md', doc.comment.text:sub(2)) - else - md:add('md', doc.comment.text) - end + local comment = normalizeComment(doc.comment.text, uri) + md:add('md', comment) elseif doc.type == 'doc.param' then if doc.comment then md:add('md', ('@*param* `%s` — %s'):format( @@ -295,18 +341,36 @@ local function getFunctionComment(source) local enums = getBindEnums(source, docGroup) md:add('lua', enums) - return md + + local comment = md:string() + if comment == '' then + return nil + end + return comment end local function tryDocComment(source) - if not source.bindDocs then - return + local md = markdown() + if source.type == 'function' then + local comment = getFunctionComment(source) + md:add('md', comment) + source = source.parent end - if source.type ~= 'function' then - local comment = getBindComment(source, source.bindDocs) - return comment + local comment = lookUpDocComments(source) + md:add('md', comment) + if source.type == 'doc.alias' then + local enums = buildEnumChunk(source, source.alias[1], guide.getUri(source)) + md:add('lua', enums) end - return getFunctionComment(source) + if source.type == 'doc.enum' then + local enums = buildEnumChunk(source, source.enum[1], guide.getUri(source)) + md:add('lua', enums) + end + local result = md:string() + if result == '' then + return nil + end + return result end local function tryDocOverloadToComment(source) @@ -315,14 +379,12 @@ local function tryDocOverloadToComment(source) end local doc = source.parent if doc.type ~= 'doc.overload' - or not doc.bindSources then + or not doc.bindSource then return end - for _, src in ipairs(doc.bindSources) do - local md = tryDocComment(src) - if md then - return md - end + local md = tryDocComment(doc.bindSource) + if md then + return md end end @@ -350,6 +412,45 @@ local function tyrDocParamComment(source) end end +---@param source parser.object +local function tryDocEnum(source) + if source.type ~= 'doc.enum' then + return + end + local tbl = source.bindSource + if not tbl then + return + end + local md = markdown() + md:add('lua', '{') + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + if field.value.type == 'integer' + or field.value.type == 'string' then + md:add('lua', (' %s: %s = %s,'):format(key, field.value.type, field.value[1])) + end + if field.value.type == 'binary' + or field.value.type == 'unary' then + local number = vm.getNumber(field.value) + if number then + md:add('lua', (' %s: %s = %s,'):format(key, math.tointeger(number) and 'integer' or 'number', number)) + end + end + ::CONTINUE:: + end + end + md:add('lua', '}') + return md:string() +end + return function (source) if source.type == 'string' then return asString(source) @@ -358,9 +459,10 @@ return function (source) source = source.parent end return tryDocOverloadToComment(source) - or tryDocFieldUpComment(source) + or tryDocFieldComment(source) or tyrDocParamComment(source) or tryDocComment(source) or tryDocClassComment(source) or tryDocModule(source) + or tryDocEnum(source) end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index 7231944a..5a65cbce 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -39,7 +39,7 @@ local function getHover(source) end local oop - if vm.getInfer(source):view() == 'function' then + if vm.getInfer(source):view(guide.getUri(source)) == 'function' then local defs = vm.getDefs(source) -- make sure `function` is before `doc.type.function` local orders = {} @@ -92,19 +92,21 @@ local function getHover(source) end local accept = { - ['local'] = true, - ['setlocal'] = true, - ['getlocal'] = true, - ['setglobal'] = true, - ['getglobal'] = true, - ['field'] = true, - ['method'] = true, - ['string'] = true, - ['number'] = true, - ['integer'] = true, - ['doc.type.name'] = true, - ['function'] = true, - ['doc.module'] = true, + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['field'] = true, + ['method'] = true, + ['string'] = true, + ['number'] = true, + ['integer'] = true, + ['doc.type.name'] = true, + ['doc.class.name'] = true, + ['doc.enum.name'] = true, + ['function'] = true, + ['doc.module'] = true, } ---@async diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 2bbfe806..5c502ec1 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -33,7 +33,10 @@ local function asDocTypeName(source) return '(class) ' .. doc.class[1] end if doc.type == 'doc.alias' then - return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view()) + return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view(guide.getUri(source))) + end + if doc.type == 'doc.enum' then + return '(enum) ' .. doc.enum[1] end end end @@ -42,7 +45,7 @@ end local function asValue(source, title) local name = buildName(source, false) or '' local ifr = vm.getInfer(source) - local type = ifr:view() + local type = ifr:view(guide.getUri(source)) local literal = ifr:viewLiterals() local cont = buildTable(source) local pack = {} @@ -55,10 +58,11 @@ local function asValue(source, title) and ( type == 'table' or type == 'any' or type == 'unknown' - or type == 'nil') then - type = nil + or type == 'nil' + or type:sub(1, 1) == '{') then + else + pack[#pack+1] = type end - pack[#pack+1] = type if literal then pack[#pack+1] = '=' pack[#pack+1] = literal @@ -139,7 +143,7 @@ local function asDocFieldName(source) break end end - local view = vm.getInfer(source.extends):view() + local view = vm.getInfer(source.extends):view(guide.getUri(source)) if not class then return ('(field) ?.%s: %s'):format(name, view) end @@ -212,7 +216,8 @@ return function (source, oop) elseif source.type == 'number' or source.type == 'integer' then return asNumber(source) - elseif source.type == 'doc.type.name' then + elseif source.type == 'doc.type.name' + or source.type == 'doc.enum.name' then return asDocTypeName(source) elseif source.type == 'doc.field' then return asDocFieldName(source) diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua index f8473638..3fabfb89 100644 --- a/script/core/hover/name.lua +++ b/script/core/hover/name.lua @@ -20,6 +20,9 @@ local function asField(source, oop) local class if source.node.type ~= 'getglobal' then class = vm.getInfer(source.node):viewClass() + if class == 'any' or class == 'unknown' then + class = nil + end end local node = class or buildName(source.node, false) @@ -47,14 +50,12 @@ end local function asDocFunction(source, oop) local doc = guide.getParentType(source, 'doc.type') or guide.getParentType(source, 'doc.overload') - if not doc or not doc.bindSources then + if not doc or not doc.bindSource then return '' end - for _, src in ipairs(doc.bindSources) do - local name = buildName(src, oop) - if name ~= '' then - return name - end + local name = buildName(doc.bindSource, oop) + if name ~= '' then + return name end return '' end diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index 3d8a94a5..b71b9e5d 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,34 +1,5 @@ local vm = require 'vm.vm' - ----@param source parser.object ----@return integer -local function countReturns(source) - local n = 0 - - local docs = source.bindDocs - if docs then - for _, doc in ipairs(docs) do - if doc.type == 'doc.return' then - for _, rtn in ipairs(doc.returns) do - if rtn.returnIndex and rtn.returnIndex > n then - n = rtn.returnIndex - end - end - end - end - end - - local returns = source.returns - if returns then - for _, rtn in ipairs(returns) do - if #rtn > n then - n = #rtn - end - end - end - - return n -end +local guide = require 'parser.guide' ---@param source parser.object ---@return parser.object[] @@ -50,7 +21,7 @@ local function getReturnDocs(source) end local function asFunction(source) - local num = countReturns(source) + local _, _, num = vm.countReturnsOfFunction(source) if num == 0 then return nil end @@ -62,11 +33,14 @@ local function asFunction(source) for i = 1, num do local rtn = vm.getReturnOfFunction(source, i) local doc = docs[i] - local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ') - local text = ('%s%s'):format( + local name = doc and doc.name and doc.name[1] + if name and name ~= '...' then + name = name .. ': ' + end + local text = rtn and ('%s%s'):format( name or '', - vm.getInfer(rtn):view() - ) + vm.getInfer(rtn):view(guide.getUri(source)) + ) or 'unknown' if i == 1 then returns[i] = (' -> %s'):format(text) else @@ -83,7 +57,14 @@ local function asDocFunction(source) end local returns = {} for i, rtn in ipairs(source.returns) do - local rtnText = vm.getInfer(rtn):view() + local rtnText = vm.getInfer(rtn):view(guide.getUri(source)) + if rtn.name then + if rtn.name[1] == '...' then + rtnText = rtn.name[1] .. rtnText + else + rtnText = rtn.name[1] .. ': ' .. rtnText + end + end if i == 1 then returns[#returns+1] = (' -> %s'):format(rtnText) else diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index 16874101..677fd76c 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -30,7 +30,7 @@ local function buildAsHash(uri, keys, nodeMap, reachMax) node:removeOptional() end local ifr = vm.getInfer(node) - local typeView = ifr:view('unknown', uri) + local typeView = ifr:view(uri, 'unknown') local literalView = ifr:viewLiterals() if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( @@ -75,7 +75,7 @@ local function buildAsConst(uri, keys, nodeMap, reachMax) node = node:copy() node:removeOptional() end - local typeView = vm.getInfer(node):view('unknown', uri) + local typeView = vm.getInfer(node):view(uri, 'unknown') local literalView = literalMap[key] if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( @@ -154,7 +154,7 @@ local function getNodeMap(fields, keyMap) local nodeMap = {} for _, field in ipairs(fields) do local key = vm.getKeyName(field) - if not keyMap[key] then + if not key or not keyMap[key] then goto CONTINUE end await.delay() @@ -178,9 +178,15 @@ return function (source) return nil end - for view in vm.getInfer(source):eachView() do - if view == 'string' - or vm.isSubType(uri, view, 'string') then + local node = vm.compileNode(source) + for n in node:eachObject() do + if n.type == 'global' and n.cate == 'type' then + if n.name == 'string' + or (n.name ~= 'unknown' and n.name ~= 'any' and vm.isSubType(uri, n.name, 'string')) then + return nil + end + elseif n.type == 'doc.type.string' + or n.type == 'string' then return nil end end diff --git a/script/core/jump-source.lua b/script/core/jump-source.lua new file mode 100644 index 00000000..5ce5e048 --- /dev/null +++ b/script/core/jump-source.lua @@ -0,0 +1,62 @@ +local guide = require 'parser.guide' +local furi = require 'file-uri' +local ws = require 'workspace' + +---@param doc parser.object +---@return uri +local function parseUri(doc) + local uri + local scheme = furi.split(doc.path) + if scheme and #scheme >= 2 then + uri = doc.path + else + local suri = guide.getUri(doc):gsub('[/\\][^/\\]*$', '') + local path = ws.getAbsolutePath(suri, doc.path) + if path then + uri = furi.encode(path) + else + uri = doc.path + end + end + ---@cast uri uri + return uri +end + +---@param results table +return function (results) + for _, result in ipairs(results) do + if result.target.type == 'doc.field.name' + or result.target.type == 'doc.class.name' then + local doc = result.target.parent.source + if doc then + local uri = parseUri(doc) + result.uri = uri + result.target = { + uri = uri, + start = guide.positionOf(doc.line - 1, doc.char), + finish = guide.positionOf(doc.line - 1, doc.char), + } + end + else + local target = result.target + if target.type == 'method' + or target.type == 'field' then + target = target.parent + end + if target.bindDocs then + for _, doc in ipairs(target.bindDocs) do + if doc.type == 'doc.source' + and doc.bindSource == target then + local uri = parseUri(doc) + result.uri = uri + result.target = { + uri = uri, + start = guide.positionOf(doc.line - 1, doc.char), + finish = guide.positionOf(doc.line - 1, doc.char), + } + end + end + end + end + end +end diff --git a/script/core/look-backward.lua b/script/core/look-backward.lua index eeee6017..8d3e3439 100644 --- a/script/core/look-backward.lua +++ b/script/core/look-backward.lua @@ -81,9 +81,19 @@ function m.findTargetSymbol(text, offset, symbol) return nil end -function m.findAnyOffset(text, offset) +---@param text string +---@param offset integer +---@param inline? boolean # 必须在同一行中(排除换行符) +function m.findAnyOffset(text, offset, inline) for i = offset, 1, -1 do - if not m.isSpace(text:sub(i, i)) then + local c = text:sub(i, i) + if inline then + if c == '\r' + or c == '\n' then + return nil + end + end + if not m.isSpace(c) then return i end end diff --git a/script/core/reference.lua b/script/core/reference.lua index 4c9c193d..fa838cff 100644 --- a/script/core/reference.lua +++ b/script/core/reference.lua @@ -2,6 +2,7 @@ local guide = require 'parser.guide' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' +local jumpSource = require 'core.jump-source' local function sortResults(results) -- 先按照顺序排序 @@ -49,6 +50,7 @@ local accept = { ['doc.class.name'] = true, ['doc.extends.name'] = true, ['doc.alias.name'] = true, + ['doc.enum.name'] = true, } ---@async @@ -101,12 +103,17 @@ return function (uri, position) if src.type == 'doc.alias' then src = src.alias end + if src.type == 'doc.enum' then + src = src.enum + end if src.type == 'doc.class.name' or src.type == 'doc.alias.name' + or src.type == 'doc.enum.name' or src.type == 'doc.type.name' or src.type == 'doc.extends.name' then if source.type ~= 'doc.type.name' and source.type ~= 'doc.class.name' + and source.type ~= 'doc.enum.name' and source.type ~= 'doc.extends.name' and source.type ~= 'doc.see.name' then goto CONTINUE @@ -132,6 +139,7 @@ return function (uri, position) end sortResults(results) + jumpSource(results) return results end diff --git a/script/core/rename.lua b/script/core/rename.lua index 7599fad6..90e66224 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -81,6 +81,9 @@ local function renameField(source, newname, callback) local uri = guide.getUri(source) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return false + end local func = parent.value -- function mt:name () end --> mt['newname'] = function (self) end local startOffset = guide.positionToOffset(state, parent.start) + 1 @@ -183,13 +186,16 @@ local function ofField(source, newname, callback) local key = guide.getKeyName(source) local refs = vm.getRefs(source) for _, ref in ipairs(refs) do - ofFieldThen(key, ref, newname, callback) + ofFieldThen(key, ref, newname, callback) end end ---@async local function ofGlobal(source, newname, callback) local key = guide.getKeyName(source) + if not key then + return + end local global = vm.getGlobal('variable', key) if not global then return @@ -225,6 +231,9 @@ local function ofDocTypeName(source, newname, callback) if doc.type == 'doc.alias' then callback(doc, doc.alias.start, doc.alias.finish, newname) end + if doc.type == 'doc.enum' then + callback(doc, doc.enum.start, doc.enum.finish, newname) + end end for _, doc in ipairs(global:getGets(uri)) do if doc.type == 'doc.type.name' then @@ -236,16 +245,15 @@ end local function ofDocParamName(source, newname, callback) callback(source, source.start, source.finish, newname) local doc = source.parent - if doc.bindSources then - for _, src in ipairs(doc.bindSources) do - if src.type == 'local' - and src.parent.type == 'funcargs' - and src[1] == source[1] then - renameLocal(src, newname, callback) - if src.ref then - for _, ref in ipairs(src.ref) do - renameLocal(ref, newname, callback) - end + local src = doc.bindSource + if src then + if src.type == 'local' + and src.parent.type == 'funcargs' + and src[1] == source[1] then + renameLocal(src, newname, callback) + if src.ref then + for _, ref in ipairs(src.ref) do + renameLocal(ref, newname, callback) end end end @@ -271,7 +279,8 @@ local function rename(source, newname, callback) return ofGlobal(source, newname, callback) elseif source.type == 'doc.class.name' or source.type == 'doc.type.name' - or source.type == 'doc.alias.name' then + or source.type == 'doc.alias.name' + or source.type == 'doc.enum.name' then return ofDocTypeName(source, newname, callback) elseif source.type == 'doc.param.name' then return ofDocParamName(source, newname, callback) @@ -305,6 +314,7 @@ local function prepareRename(source) or source.type == 'doc.class.name' or source.type == 'doc.type.name' or source.type == 'doc.alias.name' + or source.type == 'doc.enum.name' or source.type == 'doc.param.name' then return source, source[1] elseif source.type == 'string' @@ -345,6 +355,7 @@ local accept = { ['doc.type.name'] = true, ['doc.alias.name'] = true, ['doc.param.name'] = true, + ['doc.enum.name'] = true, } local m = {} diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 33449013..5833807b 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -32,7 +32,7 @@ local Care = util.switch() end options.libGlobals[name] = isLib end - local isFunc = vm.getInfer(source):hasFunction() + local isFunc = vm.getInfer(source):hasFunction(guide.getUri(source)) local type = isFunc and define.TokenTypes['function'] or define.TokenTypes.variable local modifier = isLib and define.TokenModifiers.defaultLibrary or define.TokenModifiers.static @@ -81,7 +81,7 @@ local Care = util.switch() return end end - if vm.getInfer(source):hasFunction() then + if vm.getInfer(source):hasFunction(guide.getUri(source)) then results[#results+1] = { start = source.start, finish = source.finish, @@ -134,19 +134,16 @@ local Care = util.switch() return end local loc = source.node or source + local uri = guide.getUri(loc) -- 1. 值为函数的局部变量 | Local variable whose value is a function - if loc.ref then - for _, ref in ipairs(loc.ref) do - if ref.value and ref.value.type == 'function' then - results[#results+1] = { - start = source.start, - finish = source.finish, - type = define.TokenTypes['function'], - modifieres = define.TokenModifiers.declaration, - } - return - end - end + if vm.getInfer(source):hasFunction(uri) then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes['function'], + modifieres = define.TokenModifiers.declaration, + } + return end -- 3. 特殊变量 | Special variableif source[1] == '_ENV' then if loc[1] == '_ENV' then @@ -196,7 +193,7 @@ local Care = util.switch() end end -- 6. References to other functions - if vm.getInfer(loc):hasFunction() then + if vm.getInfer(loc):hasFunction(guide.getUri(source)) then results[#results+1] = { start = source.start, finish = source.finish, @@ -449,6 +446,7 @@ local Care = util.switch() end end) : case 'doc.alias.name' + : case 'doc.enum.name' : call(function (source, options, results) if not options.annotation then return @@ -667,6 +665,14 @@ local Care = util.switch() type = define.TokenTypes.keyword, } end) + : case 'doc.cast.block' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.keyword, + } + end) : case 'doc.cast.name' : call(function (source, options, results) results[#results+1] = { @@ -675,6 +681,23 @@ local Care = util.switch() type = define.TokenTypes.variable, } end) + : case 'doc.type.code' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.string, + modifieres = define.TokenModifiers.abstract, + } + end) + : case 'doc.operator.name' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.operator, + } + end) local function buildTokens(uri, results) local tokens = {} @@ -811,9 +834,13 @@ return function (uri, start, finish) keyword = config.get(uri, 'Lua.semantic.keyword'), } + local n = 0 guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async Care(source.type, source, options, results) - await.delay() + n = n + 1 + if n % 100 == 0 then + await.delay() + end end) for _, comm in ipairs(state.comms) do diff --git a/script/core/signature.lua b/script/core/signature.lua index 025e70b7..21e954bf 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -8,6 +8,9 @@ local lookback = require 'core.look-backward' local function findNearCall(uri, ast, pos) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return nil + end local nearCall guide.eachSourceContain(ast.ast, pos, function (src) if src.type == 'call' @@ -65,27 +68,30 @@ local function makeOneSignature(source, oop, index) } end -- 不定参数 - if index > i and i > 0 then + if index and index > i and i > 0 then local lastLabel = params[i].label local text = label:sub(lastLabel[1] + 1, lastLabel[2]) if text:sub(1, 3) == '...' then index = i end end + if #params < (index or 0) then + return nil + end return { label = label, params = params, - index = index, + index = index or 1, description = hoverDesc(source), } end ---@async local function makeSignatures(text, call, pos) - local node = call.node - local oop = node.type == 'method' - or node.type == 'getmethod' - or node.type == 'setmethod' + local func = call.node + local oop = func.type == 'method' + or func.type == 'getmethod' + or func.type == 'setmethod' local index if call.args then local args = {} @@ -121,13 +127,13 @@ local function makeSignatures(text, call, pos) index = #args end end - else - index = 1 end local signs = {} - local defs = vm.getDefs(node) + local node = vm.compileNode(func) + ---@type vm.node + node = node:getData 'originNode' or node local mark = {} - for _, src in ipairs(defs) do + for src in node:eachObject() do if src.type == 'function' or src.type == 'doc.type.function' then if not mark[src] then @@ -142,10 +148,10 @@ end ---@async return function (uri, pos) local state = files.getState(uri) - if not state then + local text = files.getText(uri) + if not state or not text then return nil end - local text = files.getText(uri) local offset = guide.positionToOffset(state, pos) pos = guide.offsetToPosition(state, lookback.skipSpace(text, offset)) local call = findNearCall(uri, state, pos) @@ -156,5 +162,8 @@ return function (uri, pos) if not signs or #signs == 0 then return nil end + table.sort(signs, function (a, b) + return #a.params < #b.params + end) return signs end diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua index d8434c8c..a1c2b29f 100644 --- a/script/core/type-definition.lua +++ b/script/core/type-definition.lua @@ -4,6 +4,7 @@ local vm = require 'vm' local findSource = require 'core.find-source' local guide = require 'parser.guide' local rpath = require 'workspace.require-path' +local jumpSource = require 'core.jump-source' local function sortResults(results) -- 先按照顺序排序 @@ -51,6 +52,7 @@ local accept = { ['doc.class.name'] = true, ['doc.extends.name'] = true, ['doc.alias.name'] = true, + ['doc.enum.name'] = true, ['doc.see.name'] = true, ['doc.see.field'] = true, } @@ -74,7 +76,7 @@ local function checkRequire(source, offset) return nil end if libName == 'require' then - return rpath.findUrisByRequirePath(guide.getUri(source), literal) + return rpath.findUrisByRequireName(guide.getUri(source), literal) elseif libName == 'dofile' or libName == 'loadfile' then return workspace.findUrisByFilePath(literal) @@ -144,6 +146,9 @@ return function (uri, offset) if src.type == 'doc.alias' then src = src.alias end + if src.type == 'doc.enum' then + src = src.enum + end if src.type == 'doc.class.name' or src.type == 'doc.alias.name' or src.type == 'doc.type.function' @@ -164,6 +169,7 @@ return function (uri, offset) end sortResults(results) + jumpSource(results) return results end |