diff options
author | fesily <fesil@foxmail.com> | 2024-01-10 11:03:21 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-10 11:03:21 +0800 |
commit | 1e7bb72ad3ff2b75a1c55ee4bc53004cb7fe30f7 (patch) | |
tree | dd94cf09b9e3a675a73a9f1d41248d1538165997 /script | |
parent | bb6e172d6166190bd4edd3bb56230a7d60ebcb93 (diff) | |
parent | 37779f9b2493e51e59e1e4366bf7dcb8350e69bd (diff) | |
download | lua-language-server-1e7bb72ad3ff2b75a1c55ee4bc53004cb7fe30f7.zip |
Merge branch 'LuaLS:master' into plugin-add-OnTransformAst
Diffstat (limited to 'script')
52 files changed, 1354 insertions, 485 deletions
diff --git a/script/await.lua b/script/await.lua index fa2aea13..22745570 100644 --- a/script/await.lua +++ b/script/await.lua @@ -108,6 +108,11 @@ function m.hasID(id, co) return m.idMap[id] and m.idMap[id][co] ~= nil end +function m.unique(id, callback) + m.close(id) + m.setID(id, callback) +end + --- 休眠一段时间 ---@param time number ---@async diff --git a/script/brave/work.lua b/script/brave/work.lua index 9eb756eb..a6a7a41e 100644 --- a/script/brave/work.lua +++ b/script/brave/work.lua @@ -15,10 +15,7 @@ end) brave.on('loadProtoBySocket', function (param) local jsonrpc = require 'jsonrpc' - local socket = require 'bee.socket' - local util = require 'utility' - local rfd = socket.fd(param.rfd) - local wfd = socket.fd(param.wfd) + local net = require 'service.net' local buf = '' ---@async @@ -44,29 +41,24 @@ brave.on('loadProtoBySocket', function (param) end end) + local lsclient = net.connect('tcp', '127.0.0.1', param.port) + local lsmaster = net.connect('unix', param.unixPath) + + assert(lsclient) + assert(lsmaster) + + function lsclient:on_data(data) + buf = buf .. data + coroutine.resume(parser) + end + + function lsmaster:on_data(data) + lsclient:write(data) + net.update() + end + while true do - local rd = socket.select({rfd, wfd}, nil, 10) - if not rd or #rd == 0 then - goto continue - end - if util.arrayHas(rd, wfd) then - local needSend = wfd:recv() - if needSend then - rfd:send(needSend) - elseif needSend == nil then - error('socket closed!') - end - end - if util.arrayHas(rd, rfd) then - local recved = rfd:recv() - if recved then - buf = buf .. recved - elseif recved == nil then - error('socket closed!') - end - coroutine.resume(parser) - end - ::continue:: + net.update(10) end end) diff --git a/script/cli/check.lua b/script/cli/check.lua index 94f34b2a..4295fa06 100644 --- a/script/cli/check.lua +++ b/script/cli/check.lua @@ -9,6 +9,7 @@ local lang = require 'language' local define = require 'proto.define' local config = require 'config.config' local fs = require 'bee.filesystem' +local provider = require 'provider' require 'vm' @@ -52,6 +53,8 @@ lclient():start(function (client) io.write(lang.script('CLI_CHECK_INITING')) + provider.updateConfig(rootUri) + ws.awaitReady(rootUri) local disables = util.arrayToHash(config.get(rootUri, 'Lua.diagnostics.disable')) @@ -96,7 +99,10 @@ end if count == 0 then print(lang.script('CLI_CHECK_SUCCESS')) else - local outpath = LOGPATH .. '/check.json' + local outpath = CHECK_OUT_PATH + if outpath == nil then + outpath = LOGPATH .. '/check.json' + end util.saveFile(outpath, jsonb.beautify(results)) print(lang.script('CLI_CHECK_RESULTS', count, outpath)) diff --git a/script/cli/doc.lua b/script/cli/doc.lua index c643deea..9140a258 100644 --- a/script/cli/doc.lua +++ b/script/cli/doc.lua @@ -65,6 +65,7 @@ local function packObject(source, mark) end if source.type == 'function.return' then new['desc'] = source.comment and getDesc(source.comment) + new['rawdesc'] = source.comment and getDesc(source.comment, true) end if source.type == 'doc.type.table' then new['fields'] = packObject(source.fields, mark) @@ -82,6 +83,7 @@ local function packObject(source, mark) end if source.bindDocs then new['desc'] = getDesc(source) + new['rawdesc'] = getDesc(source, true) end new['view'] = new['view'] or vm.getInfer(source):view(ws.rootUri) end @@ -115,6 +117,7 @@ local function collectTypes(global, results) name = global.name, type = 'type', desc = nil, + rawdesc = nil, defines = {}, fields = {}, } @@ -131,6 +134,7 @@ local function collectTypes(global, results) extends = getExtends(set), } result.desc = result.desc or getDesc(set) + result.rawdesc = result.rawdesc or getDesc(set, true) ::CONTINUE:: end if #result.defines == 0 then @@ -163,6 +167,7 @@ local function collectTypes(global, results) field.start = source.start field.finish = source.finish field.desc = getDesc(source) + field.rawdesc = getDesc(source, true) field.extends = packObject(source.extends) return end @@ -180,6 +185,7 @@ local function collectTypes(global, results) field.start = source.start field.finish = source.finish field.desc = getDesc(source) + field.rawdesc = getDesc(source, true) field.extends = packObject(source.value) return end @@ -199,6 +205,7 @@ local function collectTypes(global, results) field.start = source.start field.finish = source.finish field.desc = getDesc(source) + field.rawdesc = getDesc(source, true) field.extends = packObject(source.value) return end @@ -237,6 +244,7 @@ local function collectVars(global, results) extends = packObject(set.value), } result.desc = result.desc or getDesc(set) + result.rawdesc = result.rawdesc or getDesc(set, true) end end if #result.defines == 0 then diff --git a/script/config/template.lua b/script/config/template.lua index 436f5e1a..ba0fd503 100644 --- a/script/config/template.lua +++ b/script/config/template.lua @@ -317,7 +317,12 @@ local template = { ['Lua.workspace.maxPreload'] = Type.Integer >> 5000, ['Lua.workspace.preloadFileSize'] = Type.Integer >> 500, ['Lua.workspace.library'] = Type.Array(Type.String), - ['Lua.workspace.checkThirdParty'] = Type.Boolean >> true, + ['Lua.workspace.checkThirdParty'] = Type.Or(Type.String >> 'Ask' << { + 'Ask', + 'Apply', + 'ApplyInMemory', + 'Disable', + }, Type.Boolean), ['Lua.workspace.userThirdParty'] = Type.Array(Type.String), ['Lua.completion.enable'] = Type.Boolean >> true, ['Lua.completion.callSnippet'] = Type.String >> 'Disable' << { diff --git a/script/core/code-action.lua b/script/core/code-action.lua index e226e03b..720cd4c4 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -4,6 +4,11 @@ local util = require 'utility' local sp = require 'bee.subprocess' local guide = require "parser.guide" local converter = require 'proto.converter' +local autoreq = require 'core.completion.auto-require' +local rpath = require 'workspace.require-path' +local furi = require 'file-uri' +local undefined = require 'core.diagnostics.undefined-global' +local vm = require 'vm' ---@param uri uri ---@param row integer @@ -676,6 +681,54 @@ local function checkJsonToLua(results, uri, start, finish) } end +local function findRequireTargets(visiblePaths) + local targets = {} + for _, visible in ipairs(visiblePaths) do + targets[#targets+1] = visible.name + end + return targets +end + +local function checkMissingRequire(results, uri, start, finish) + local state = files.getState(uri) + local text = files.getText(uri) + if not state or not text then + return + end + + local function addRequires(global, endpos) + autoreq.check(state, global, endpos, function(moduleFile, stemname, targetSource) + local visiblePaths = rpath.getVisiblePath(uri, furi.decode(moduleFile)) + if not visiblePaths or #visiblePaths == 0 then return end + + for _, target in ipairs(findRequireTargets(visiblePaths)) do + results[#results+1] = { + title = lang.script('ACTION_AUTOREQUIRE', target, global), + kind = 'refactor.rewrite', + command = { + title = 'autoRequire', + command = 'lua.autoRequire', + arguments = { + { + uri = guide.getUri(state.ast), + target = moduleFile, + name = global, + requireName = target + }, + }, + } + } + end + end) + end + + guide.eachSourceBetween(state.ast, start, finish, function (source) + if vm.isUndefinedGlobal(source) then + addRequires(source[1], source.finish) + end + end) +end + return function (uri, start, finish, diagnostics) local ast = files.getState(uri) if not ast then @@ -688,6 +741,7 @@ return function (uri, start, finish, diagnostics) checkSwapParams(results, uri, start, finish) --checkExtractAsFunction(results, uri, start, finish) checkJsonToLua(results, uri, start, finish) + checkMissingRequire(results, uri, start, finish) return results end diff --git a/script/core/code-lens.lua b/script/core/code-lens.lua index ed95ea6a..bc39ec86 100644 --- a/script/core/code-lens.lua +++ b/script/core/code-lens.lua @@ -6,7 +6,7 @@ local getRef = require 'core.reference' local lang = require 'language' ---@class parser.state ----@field package _codeLens codeLens +---@field package _codeLens? codeLens ---@class codeLens.resolving ---@field mode 'reference' @@ -14,7 +14,6 @@ local lang = require 'language' ---@class codeLens.result ---@field position integer ----@field uri uri ---@field id integer ---@class codeLens diff --git a/script/core/command/autoRequire.lua b/script/core/command/autoRequire.lua index 32911d92..a96cc918 100644 --- a/script/core/command/autoRequire.lua +++ b/script/core/command/autoRequire.lua @@ -135,6 +135,7 @@ return function (data) local uri = data.uri local target = data.target local name = data.name + local requireName = data.requireName local state = files.getState(uri) if not state then return @@ -149,11 +150,13 @@ return function (data) return #a.name < #b.name end) - local result = askAutoRequire(uri, visiblePaths) - if not result then - return + if not requireName then + requireName = askAutoRequire(uri, visiblePaths) + if not requireName then + return + end end local offset, fmt = findInsertRow(uri) - applyAutoRequire(uri, offset, name, result, fmt) + applyAutoRequire(uri, offset, name, requireName, fmt) end diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index 0ec503de..acb3adbe 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -504,7 +504,8 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent local kind = define.CompletionItemKind.Field if (value.type == 'function' and not vm.isVarargFunctionWithOverloads(value)) or value.type == 'doc.type.function' then - if oop then + local isMethod = value.parent.type == 'setmethod' + if isMethod then kind = define.CompletionItemKind.Method else kind = define.CompletionItemKind.Function @@ -512,6 +513,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent buildFunction(results, src, value, oop, { label = name, kind = kind, + isMethod = isMethod, match = name:match '^[^(]+', insertText = name:match '^[^(]+', deprecated = vm.getDeprecated(src) and true or nil, @@ -626,12 +628,43 @@ local function checkFieldOfRefs(refs, state, word, startPos, position, parent, o end ::CONTINUE:: end + + local fieldResults = {} for name, src in util.sortPairs(fields) do if src then - checkFieldThen(state, name, src, word, startPos, position, parent, oop, results) + checkFieldThen(state, name, src, word, startPos, position, parent, oop, fieldResults) await.delay() end end + + local scoreMap = {} + for i, res in ipairs(fieldResults) do + scoreMap[res] = i + end + table.sort(fieldResults, function (a, b) + local score1 = scoreMap[a] + local score2 = scoreMap[b] + if oop then + if not a.isMethod then + score1 = score1 + 10000 + end + if not b.isMethod then + score2 = score2 + 10000 + end + else + if a.isMethod then + score1 = score1 + 10000 + end + if b.isMethod then + score2 = score2 + 10000 + end + end + return score1 < score2 + end) + + for _, res in ipairs(fieldResults) do + results[#results+1] = res + end end ---@async @@ -1218,6 +1251,46 @@ local function insertDocEnum(state, pos, doc, enums) return enums end +---@param state parser.state +---@param pos integer +---@param doc vm.node.object +---@param enums table[] +---@return table[]? +local function insertDocEnumKey(state, pos, doc, enums) + local tbl = doc.bindSource + if not tbl then + return nil + end + local keyEnums = {} + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + enums[#enums+1] = { + label = ('%q'):format(key), + kind = define.CompletionItemKind.EnumMember, + id = stack(field, function (newField) ---@async + return { + detail = buildDetail(newField), + description = buildDesc(newField), + } + end), + } + ::CONTINUE:: + end + end + for _, enum in ipairs(keyEnums) do + enums[#enums+1] = enum + end + return enums +end + local function buildInsertDocFunction(doc) local args = {} for i, arg in ipairs(doc.args) do @@ -1283,7 +1356,11 @@ local function insertEnum(state, pos, src, enums, isInArray, mark) elseif src.type == 'global' and src.cate == 'type' then for _, set in ipairs(src:getSets(state.uri)) do if set.type == 'doc.enum' then - insertDocEnum(state, pos, set, enums) + if vm.docHasAttr(set, 'key') then + insertDocEnumKey(state, pos, set, enums) + else + insertDocEnum(state, pos, set, enums) + end end end end @@ -1539,14 +1616,13 @@ local function checkTableLiteralField(state, position, tbl, fields, results) end end if left then - local hasResult = false + local fieldResults = {} for _, field in ipairs(fields) do local name = guide.getKeyName(field) if name and not mark[name] and matchKey(left, tostring(name)) then - hasResult = true - results[#results+1] = { + local res = { label = guide.getKeyName(field), kind = define.CompletionItemKind.Property, id = stack(field, function (newField) ---@async @@ -1556,9 +1632,20 @@ local function checkTableLiteralField(state, position, tbl, fields, results) } end), } + if field.optional + or vm.compileNode(field):isNullable() then + res.insertText = res.label + res.label = res.label.. '?' + end + fieldResults[#fieldResults+1] = res end end - return hasResult + util.sortByScore(fieldResults, { + function (r) return r.insertText and 0 or 1 end, + util.sortCallbackOfIndex(fieldResults), + }) + util.arrayMerge(results, fieldResults) + return #fieldResults > 0 end end @@ -1571,7 +1658,8 @@ local function tryCallArg(state, position, results) if arg and arg.type == 'function' then return end - local node = vm.compileCallArg({ type = 'dummyarg' }, call, argIndex) + ---@diagnostic disable-next-line: missing-fields + local node = vm.compileCallArg({ type = 'dummyarg', uri = state.uri }, call, argIndex) if not node then return end @@ -2070,7 +2158,7 @@ local function tryluaDocByErr(state, position, err, docState, results) end end -local function buildluaDocOfFunction(func) +local function buildluaDocOfFunction(func, pad) local index = 1 local buf = {} buf[#buf+1] = '${1:comment}' @@ -2094,7 +2182,8 @@ local function buildluaDocOfFunction(func) local funcArg = func.args[n] if funcArg[1] and funcArg.type ~= 'self' then index = index + 1 - buf[#buf+1] = ('---@param %s ${%d:%s}'):format( + buf[#buf+1] = ('---%s@param %s ${%d:%s}'):format( + pad and ' ' or '', funcArg[1], index, arg @@ -2103,7 +2192,8 @@ local function buildluaDocOfFunction(func) end for _, rtn in ipairs(returns) do index = index + 1 - buf[#buf+1] = ('---@return ${%d:%s}'):format( + buf[#buf+1] = ('---%s@return ${%d:%s}'):format( + pad and ' ' or '', index, rtn ) @@ -2112,7 +2202,7 @@ local function buildluaDocOfFunction(func) return insertText end -local function tryluaDocOfFunction(doc, results) +local function tryluaDocOfFunction(doc, results, pad) if not doc.bindSource then return end @@ -2134,7 +2224,7 @@ local function tryluaDocOfFunction(doc, results) end end end - local insertText = buildluaDocOfFunction(func) + local insertText = buildluaDocOfFunction(func, pad) results[#results+1] = { label = '@param;@return', kind = define.CompletionItemKind.Snippet, @@ -2152,9 +2242,9 @@ local function tryLuaDoc(state, position, results) end if doc.type == 'doc.comment' then local line = doc.originalComment.text - -- 尝试 ---$ - if line == '-' then - tryluaDocOfFunction(doc, results) + -- 尝试 '---$' or '--- $' + if line == '-' or line == '- ' then + tryluaDocOfFunction(doc, results, line == '- ') return end -- 尝试 ---@$ diff --git a/script/core/completion/keyword.lua b/script/core/completion/keyword.lua index e6f50242..aa0e2148 100644 --- a/script/core/completion/keyword.lua +++ b/script/core/completion/keyword.lua @@ -3,6 +3,7 @@ local files = require 'files' local guide = require 'parser.guide' local config = require 'config' local util = require 'utility' +local lookback = require 'core.look-backward' local keyWordMap = { { 'do', function(info, results) @@ -372,17 +373,35 @@ end" else newText = '::continue::' end + local additional = {} + + local word = lookback.findWord(info.state.lua, guide.positionToOffset(info.state, info.start) - 1) + if word ~= 'goto' then + additional[#additional+1] = { + start = info.start, + finish = info.start, + newText = 'goto ', + } + end + + local hasContinue = guide.eachSourceType(mostInsideBlock, 'label', function (src) + if src[1] == 'continue' then + return true + end + end) + + if not hasContinue then + additional[#additional+1] = { + start = endPos, + finish = endPos, + newText = newText, + } + end results[#results+1] = { label = 'goto continue ..', kind = define.CompletionItemKind.Snippet, - insertText = "goto continue", - additionalTextEdits = { - { - start = endPos, - finish = endPos, - newText = newText, - } - } + insertText = "continue", + additionalTextEdits = additional, } return true end } diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua index 1998b915..26445374 100644 --- a/script/core/diagnostics/cast-local-type.lua +++ b/script/core/diagnostics/cast-local-type.lua @@ -16,6 +16,9 @@ return function (uri, callback) if not loc.ref then return end + if loc[1] == '_' then + return + end await.delay() local locNode = vm.compileNode(loc) if not locNode.hasDefined then diff --git a/script/core/diagnostics/helper/missing-doc-helper.lua b/script/core/diagnostics/helper/missing-doc-helper.lua index 84221693..116173f2 100644 --- a/script/core/diagnostics/helper/missing-doc-helper.lua +++ b/script/core/diagnostics/helper/missing-doc-helper.lua @@ -38,8 +38,9 @@ end local function checkFunction(source, callback, commentId, paramId, returnId) local functionName = source.parent[1] + local argCount = source.args and #source.args or 0 - if #source.args == 0 and not source.returns and not source.bindDocs then + if argCount == 0 and not source.returns and not source.bindDocs then callback { start = source.start, finish = source.finish, @@ -47,10 +48,11 @@ local function checkFunction(source, callback, commentId, paramId, returnId) } end - if #source.args > 0 then + if argCount > 0 then for _, arg in ipairs(source.args) do local argName = arg[1] - if argName ~= 'self' then + if argName ~= 'self' + and argName ~= '_' then if not findParam(source.bindDocs, argName) then callback { start = arg.start, diff --git a/script/core/diagnostics/incomplete-signature-doc.lua b/script/core/diagnostics/incomplete-signature-doc.lua index 91f2db74..1ffbb77a 100644 --- a/script/core/diagnostics/incomplete-signature-doc.lua +++ b/script/core/diagnostics/incomplete-signature-doc.lua @@ -38,6 +38,19 @@ local function findReturn(docs, index) return false end +--- check if there's any signature doc (@param or @return), or just comments, @async, ... +local function findSignatureDoc(docs) + if not docs then + return false + end + for _, doc in ipairs(docs) do + if doc.type == 'doc.return' or doc.type == 'doc.param' then + return true + end + end + return false +end + ---@async return function (uri, callback) local state = files.getState(uri) @@ -57,17 +70,22 @@ return function (uri, callback) return end - local functionName = source.parent[1] + --- don't apply rule if there is no @param or @return annotation yet + --- so comments and @async can be applied without the need for a full documentation + if(not findSignatureDoc(source.bindDocs)) then + return + end - if #source.args > 0 then + if source.args and #source.args > 0 then for _, arg in ipairs(source.args) do local argName = arg[1] - if argName ~= 'self' then + if argName ~= 'self' + and argName ~= '_' then if not findParam(source.bindDocs, argName) then callback { start = arg.start, finish = arg.finish, - message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_PARAM', argName, functionName), + message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_PARAM', argName), } end end @@ -81,7 +99,7 @@ return function (uri, callback) callback { start = expr.start, finish = expr.finish, - message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_RETURN', index, functionName), + message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_RETURN', index), } end end diff --git a/script/core/diagnostics/inject-field.lua b/script/core/diagnostics/inject-field.lua new file mode 100644 index 00000000..2866eef8 --- /dev/null +++ b/script/core/diagnostics/inject-field.lua @@ -0,0 +1,147 @@ +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local guide = require 'parser.guide' +local await = require 'await' +local hname = require 'core.hover.name' + +local skipCheckClass = { + ['unknown'] = true, + ['any'] = true, + ['table'] = true, +} + +---@async +return function (uri, callback) + local ast = files.getState(uri) + if not ast then + return + end + + ---@async + local function checkInjectField(src) + await.delay() + + local node = src.node + if not node then + return + end + local ok + for view in vm.getInfer(node):eachView(uri) do + if skipCheckClass[view] then + return + end + ok = true + end + if not ok then + return + end + + local isExact + local class = vm.getDefinedClass(uri, node) + if class then + for _, doc in ipairs(class:getSets(uri)) do + if vm.docHasAttr(doc, 'exact') then + isExact = true + break + end + end + if not isExact then + return + end + if src.type == 'setmethod' + and not guide.getSelfNode(node) then + return + end + end + + for _, def in ipairs(vm.getDefs(src)) do + local dnode = def.node + if dnode + and not isExact + and vm.getDefinedClass(uri, dnode) then + return + end + if def.type == 'doc.type.field' then + return + end + if def.type == 'doc.field' then + return + end + end + + local howToFix = '' + if not isExact then + howToFix = lang.script('DIAG_INJECT_FIELD_FIX_CLASS', { + node = hname(node), + fix = '---@class', + }) + for _, ndef in ipairs(vm.getDefs(node)) do + if ndef.type == 'doc.type.table' then + howToFix = lang.script('DIAG_INJECT_FIELD_FIX_TABLE', { + fix = '[any]: any', + }) + break + end + end + end + + local message = lang.script('DIAG_INJECT_FIELD', { + class = vm.getInfer(node):view(uri), + field = guide.getKeyName(src), + fix = howToFix, + }) + if src.type == 'setfield' and src.field then + callback { + start = src.field.start, + finish = src.field.finish, + message = message, + } + elseif src.type == 'setmethod' and src.method then + callback { + start = src.method.start, + finish = src.method.finish, + message = message, + } + end + end + guide.eachSourceType(ast.ast, 'setfield', checkInjectField) + guide.eachSourceType(ast.ast, 'setmethod', checkInjectField) + + ---@async + local function checkExtraTableField(src) + await.delay() + + if not src.bindSource then + return + end + if not vm.docHasAttr(src, 'exact') then + return + end + local value = src.bindSource.value + if not value or value.type ~= 'table' then + return + end + for _, field in ipairs(value) do + local defs = vm.getDefs(field) + for _, def in ipairs(defs) do + if def.type == 'doc.field' then + goto nextField + end + end + local message = lang.script('DIAG_INJECT_FIELD', { + class = vm.getInfer(src):view(uri), + field = guide.getKeyName(src), + fix = '', + }) + callback { + start = field.start, + finish = field.finish, + message = message, + } + ::nextField:: + end + end + + guide.eachSourceType(ast.ast, 'doc.class', checkExtraTableField) +end diff --git a/script/core/diagnostics/missing-fields.lua b/script/core/diagnostics/missing-fields.lua new file mode 100644 index 00000000..210920fd --- /dev/null +++ b/script/core/diagnostics/missing-fields.lua @@ -0,0 +1,84 @@ +local vm = require 'vm' +local files = require 'files' +local guide = require 'parser.guide' +local await = require 'await' +local lang = require 'language' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'table', function (src) + await.delay() + + local defs = vm.getDefs(src) + for _, def in ipairs(defs) do + if def.type == 'doc.class' and def.bindSource then + if guide.isInRange(def.bindSource, src.start) then + return + end + end + if def.type == 'doc.type.array' + or def.type == 'doc.type.table' then + return + end + end + local warnings = {} + for _, def in ipairs(defs) do + if def.type == 'doc.class' then + if not def.fields then + return + end + + local requiresKeys = {} + for _, field in ipairs(def.fields) do + if not field.optional + and not vm.compileNode(field):isNullable() then + local key = vm.getKeyName(field) + if key and not requiresKeys[key] then + requiresKeys[key] = true + requiresKeys[#requiresKeys+1] = key + end + end + end + + if #requiresKeys == 0 then + return + end + local myKeys = {} + for _, field in ipairs(src) do + local key = vm.getKeyName(field) + if key then + myKeys[key] = true + end + end + + local missedKeys = {} + for _, key in ipairs(requiresKeys) do + if not myKeys[key] then + missedKeys[#missedKeys+1] = ('`%s`'):format(key) + end + end + + if #missedKeys == 0 then + return + end + + warnings[#warnings+1] = lang.script('DIAG_MISSING_FIELDS', def.class[1], table.concat(missedKeys, ', ')) + end + end + + if #warnings == 0 then + return + end + callback { + start = src.start, + finish = src.finish, + message = table.concat(warnings, '\n') + } + end) +end diff --git a/script/core/diagnostics/missing-local-export-doc.lua b/script/core/diagnostics/missing-local-export-doc.lua index 5825c115..da413961 100644 --- a/script/core/diagnostics/missing-local-export-doc.lua +++ b/script/core/diagnostics/missing-local-export-doc.lua @@ -10,7 +10,13 @@ local function findSetField(ast, name, callback) await.delay() if source.node[1] == name then local funcPtr = source.value.node + if not funcPtr then + return + end local func = funcPtr.value + if not func then + return + end if funcPtr.type == 'local' and func.type == 'function' then helper.CheckFunction(func, callback, 'DIAG_MISSING_LOCAL_EXPORT_DOC_COMMENT', 'DIAG_MISSING_LOCAL_EXPORT_DOC_PARAM', 'DIAG_MISSING_LOCAL_EXPORT_DOC_RETURN') end diff --git a/script/core/diagnostics/param-type-mismatch.lua b/script/core/diagnostics/param-type-mismatch.lua index da39c5e1..acbf9c8c 100644 --- a/script/core/diagnostics/param-type-mismatch.lua +++ b/script/core/diagnostics/param-type-mismatch.lua @@ -32,26 +32,28 @@ end ---@param funcNode vm.node ---@param i integer +---@param uri uri ---@return vm.node? -local function getDefNode(funcNode, i) +local function getDefNode(funcNode, i, uri) local defNode = vm.createNode() - for f in funcNode:eachObject() do - if f.type == 'function' - or f.type == 'doc.type.function' then - local param = f.args and f.args[i] + for src in funcNode:eachObject() do + if src.type == 'function' + or src.type == 'doc.type.function' then + local param = src.args and src.args[i] if param then defNode:merge(vm.compileNode(param)) if param[1] == '...' then defNode:addOptional() end - - expandGenerics(defNode) end end end if defNode:isEmpty() then return nil end + + expandGenerics(defNode) + return defNode end @@ -91,7 +93,7 @@ return function (uri, callback) if not refNode then goto CONTINUE end - local defNode = getDefNode(funcNode, i) + local defNode = getDefNode(funcNode, i, uri) if not defNode then goto CONTINUE end diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index a83241f5..4fd55966 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -8,13 +8,6 @@ local skipCheckClass = { ['unknown'] = true, ['any'] = true, ['table'] = true, - ['nil'] = true, - ['number'] = true, - ['integer'] = true, - ['boolean'] = true, - ['function'] = true, - ['userdata'] = true, - ['lightuserdata'] = true, } ---@async @@ -61,5 +54,4 @@ return function (uri, callback) end guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField) guide.eachSourceType(ast.ast, 'getmethod', checkUndefinedField) - guide.eachSourceType(ast.ast, 'getindex', checkUndefinedField) end diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index 179c9204..d9d94959 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -20,41 +20,21 @@ return function (uri, callback) return end - local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) - local rspecial = config.get(uri, 'Lua.runtime.special') - local cache = {} - -- 遍历全局变量,检查所有没有 set 模式的全局变量 guide.eachSourceType(state.ast, 'getglobal', function (src) ---@async - local key = src[1] - if not key then - return - end - if dglobals[key] then - return - end - if rspecial[key] then - return - end - local node = src.node - if node.tag ~= '_ENV' then - return - end - if cache[key] == nil then - await.delay() - cache[key] = vm.hasGlobalSets(uri, 'variable', key) - end - if cache[key] then - return - end - local message = lang.script('DIAG_UNDEF_GLOBAL', key) - if requireLike[key:lower()] then - message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key)) + if vm.isUndefinedGlobal(src) then + local key = src[1] + local message = lang.script('DIAG_UNDEF_GLOBAL', key) + if requireLike[key:lower()] then + message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key)) + end + + callback { + start = src.start, + finish = src.finish, + message = message, + undefinedGlobal = src[1] + } end - callback { - start = src.start, - finish = src.finish, - message = message, - } end) end diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index f5890b21..75189b06 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -336,7 +336,7 @@ local function tryDocFieldComment(source) end end -local function getFunctionComment(source) +local function getFunctionComment(source, raw) local docGroup = source.bindDocs if not docGroup then return @@ -356,14 +356,14 @@ local function getFunctionComment(source) if doc.type == 'doc.comment' then local comment = normalizeComment(doc.comment.text, uri) md:add('md', comment) - elseif doc.type == 'doc.param' then + elseif doc.type == 'doc.param' and not raw then if doc.comment then md:add('md', ('@*param* `%s` — %s'):format( doc.param[1], doc.comment.text )) end - elseif doc.type == 'doc.return' then + elseif doc.type == 'doc.return' and not raw then if hasReturnComment then local name = {} for _, rtn in ipairs(doc.returns) do @@ -401,13 +401,13 @@ local function getFunctionComment(source) end ---@async -local function tryDocComment(source) +local function tryDocComment(source, raw) local md = markdown() if source.value and source.value.type == 'function' then source = source.value end if source.type == 'function' then - local comment = getFunctionComment(source) + local comment = getFunctionComment(source, raw) md:add('md', comment) source = source.parent end @@ -429,7 +429,7 @@ local function tryDocComment(source) end ---@async -local function tryDocOverloadToComment(source) +local function tryDocOverloadToComment(source, raw) if source.type ~= 'doc.type.function' then return end @@ -438,7 +438,7 @@ local function tryDocOverloadToComment(source) or not doc.bindSource then return end - local md = tryDocComment(doc.bindSource) + local md = tryDocComment(doc.bindSource, raw) if md then return md end @@ -477,38 +477,59 @@ local function tryDocEnum(source) if not tbl then return end - local md = markdown() - md:add('lua', '{') - for _, field in ipairs(tbl) do - if field.type == 'tablefield' - or field.type == 'tableindex' then - if not field.value then - goto CONTINUE - end - local key = guide.getKeyName(field) - if not key then - goto CONTINUE - end - if field.value.type == 'integer' - or field.value.type == 'string' then - md:add('lua', (' %s: %s = %s,'):format(key, field.value.type, field.value[1])) + if vm.docHasAttr(source, 'key') then + local md = markdown() + local keys = {} + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + keys[#keys+1] = ('%q'):format(key) + ::CONTINUE:: end - if field.value.type == 'binary' - or field.value.type == 'unary' then - local number = vm.getNumber(field.value) - if number then - md:add('lua', (' %s: %s = %s,'):format(key, math.tointeger(number) and 'integer' or 'number', number)) + end + md:add('lua', table.concat(keys, ' | ')) + return md:string() + else + local md = markdown() + md:add('lua', '{') + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + if field.value.type == 'integer' + or field.value.type == 'string' then + md:add('lua', (' %s: %s = %s,'):format(key, field.value.type, field.value[1])) + end + if field.value.type == 'binary' + or field.value.type == 'unary' then + local number = vm.getNumber(field.value) + if number then + md:add('lua', (' %s: %s = %s,'):format(key, math.tointeger(number) and 'integer' or 'number', number)) + end end + ::CONTINUE:: end - ::CONTINUE:: end + md:add('lua', '}') + return md:string() end - md:add('lua', '}') - return md:string() end ---@async -return function (source) +return function (source, raw) if source.type == 'string' then return asString(source) end @@ -518,10 +539,10 @@ return function (source) if source.type == 'field' then source = source.parent end - return tryDocOverloadToComment(source) + return tryDocOverloadToComment(source, raw) or tryDocFieldComment(source) or tyrDocParamComment(source) - or tryDocComment(source) + or tryDocComment(source, raw) or tryDocClassComment(source) or tryDocModule(source) or tryDocEnum(source) diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 6ce4dde9..62e51927 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -134,7 +134,7 @@ local function asField(source) end local function asDocFieldName(source) - local name = source.field[1] + local name = vm.viewKey(source, guide.getUri(source)) or '?' local class for _, doc in ipairs(source.bindGroup) do if doc.type == 'doc.class' then @@ -143,10 +143,12 @@ local function asDocFieldName(source) end end local view = vm.getInfer(source.extends):view(guide.getUri(source)) - if not class then - return ('(field) ?.%s: %s'):format(name, view) + local className = class and class.class[1] or '?' + if name:match(guide.namePatternFull) then + return ('(field) %s.%s: %s'):format(className, name, view) + else + return ('(field) %s%s: %s'):format(className, name, view) end - return ('(field) %s.%s: %s'):format(class.class[1], name, view) end local function asString(source) diff --git a/script/core/rename.lua b/script/core/rename.lua index 507def20..cc5d37f3 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -3,42 +3,59 @@ local vm = require 'vm' local util = require 'utility' local findSource = require 'core.find-source' local guide = require 'parser.guide' +local config = require 'config' local Forcing +---@param str string +---@return string local function trim(str) return str:match '^%s*(%S+)%s*$' end -local function isValidName(str) +---@param uri uri +---@param str string +---@return boolean +local function isValidName(uri, str) if not str then return false end - return str:match '^[%a_][%w_]*$' + local allowUnicode = config.get(uri, 'Lua.runtime.unicodeName') + if allowUnicode then + return str:match '^[%a_\x80-\xff][%w_\x80-\xff]*$' + else + return str:match '^[%a_][%w_]*$' + end end -local function isValidGlobal(str) +---@param uri uri +---@param str string +---@return boolean +local function isValidGlobal(uri, str) if not str then return false end for s in str:gmatch '[^%.]*' do - if not isValidName(trim(s)) then + if not isValidName(uri, trim(s)) then return false end end return true end -local function isValidFunctionName(str) - if isValidGlobal(str) then +---@param uri uri +---@param str string +---@return boolean +local function isValidFunctionName(uri, str) + if isValidGlobal(uri, str) then return true end local offset = str:find(':', 1, true) if not offset then return false end - return isValidGlobal(trim(str:sub(1, offset-1))) - and isValidName(trim(str:sub(offset+1))) + return isValidGlobal(uri, trim(str:sub(1, offset-1))) + and isValidName(uri, trim(str:sub(offset+1))) end local function isFunctionGlobalName(source) @@ -54,7 +71,7 @@ local function isFunctionGlobalName(source) end local function renameLocal(source, newname, callback) - if isValidName(newname) then + if isValidName(guide.getUri(source), newname) then callback(source, source.start, source.finish, newname) return end @@ -62,7 +79,7 @@ local function renameLocal(source, newname, callback) end local function renameField(source, newname, callback) - if isValidName(newname) then + if isValidName(guide.getUri(source), newname) then callback(source, source.start, source.finish, newname) return true end @@ -108,11 +125,11 @@ local function renameField(source, newname, callback) end local function renameGlobal(source, newname, callback) - if isValidGlobal(newname) then + if isValidGlobal(guide.getUri(source), newname) then callback(source, source.start, source.finish, newname) return true end - if isValidFunctionName(newname) then + if isValidFunctionName(guide.getUri(source), newname) then callback(source, source.start, source.finish, newname) return true end diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 4d191b69..4e1d8e00 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -138,12 +138,20 @@ local Care = util.switch() local uri = guide.getUri(loc) -- 1. 值为函数的局部变量 | Local variable whose value is a function if vm.getInfer(source):hasFunction(uri) then - results[#results+1] = { - start = source.start, - finish = source.finish, - type = define.TokenTypes['function'], - modifieres = define.TokenModifiers.declaration, - } + if source.type == 'local' then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes['function'], + modifieres = define.TokenModifiers.declaration, + } + else + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes['function'], + } + end return end -- 3. 特殊变量 | Special variableif source[1] == '_ENV' then @@ -703,6 +711,14 @@ local Care = util.switch() type = define.TokenTypes.namespace, } end) + : case 'doc.attr' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.decorator, + } + end) ---@param state table ---@param results table diff --git a/script/core/signature.lua b/script/core/signature.lua index 63b0cd0d..98018b21 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -89,6 +89,39 @@ local function makeOneSignature(source, oop, index) } end +local function isEventNotMatch(call, src) + if not call.args or not src.args then + return false + end + local literal, index + for i = 1, 2 do + if not call.args[i] then + break + end + literal = guide.getLiteral(call.args[i]) + if literal then + index = i + break + end + end + if not literal then + return false + end + local event = src.args[index] + if not event or event.type ~= 'doc.type.arg' then + return false + end + if not event.extends + or #event.extends.types ~= 1 then + return false + end + local eventLiteral = event.extends.types[1] and guide.getLiteral(event.extends.types[1]) + if eventLiteral == nil then + return false + end + return eventLiteral ~= literal +end + ---@async local function makeSignatures(text, call, pos) local func = call.node @@ -139,7 +172,8 @@ local function makeSignatures(text, call, pos) for src in node:eachObject() do if (src.type == 'function' and not vm.isVarargFunctionWithOverloads(src)) or src.type == 'doc.type.function' then - if not mark[src] then + if not mark[src] + and not isEventNotMatch(call, src) then mark[src] = true signs[#signs+1] = makeOneSignature(src, oop, index) end @@ -149,7 +183,8 @@ local function makeSignatures(text, call, pos) if set.type == 'doc.class' then for _, overload in ipairs(set.calls) do local f = overload.overload - if not mark[f] then + if not mark[f] + and not isEventNotMatch(call, src) then mark[f] = true signs[#signs+1] = makeOneSignature(f, oop, index) end diff --git a/script/file-uri.lua b/script/file-uri.lua index 8e9dd938..8a075f7e 100644 --- a/script/file-uri.lua +++ b/script/file-uri.lua @@ -49,7 +49,7 @@ function m.encode(path) --lower-case windows drive letters in /C:/fff or C:/fff local start, finish, drive = path:find '/(%u):' - if drive then + if drive and finish then path = path:sub(1, start) .. drive:lower() .. path:sub(finish, -1) end @@ -102,6 +102,9 @@ function m.isValid(uri) if path == '' then return false end + if scheme ~= 'file' then + return false + end return true end diff --git a/script/files.lua b/script/files.lua index 7998ceed..b7a20e8b 100644 --- a/script/files.lua +++ b/script/files.lua @@ -19,20 +19,20 @@ local sp = require 'bee.subprocess' local pub = require 'pub' ---@class file ----@field uri uri ----@field content string ----@field ref? integer ----@field trusted? boolean ----@field rows? integer[] ----@field originText? string ----@field text string ----@field version? integer ----@field originLines? integer[] ----@field diffInfo? table[] ----@field cache table ----@field id integer ----@field state? parser.state ----@field compileCount integer +---@field uri uri +---@field ref? integer +---@field trusted? boolean +---@field rows? integer[] +---@field originText? string +---@field text? string +---@field version? integer +---@field originLines? integer[] +---@field diffInfo? table[] +---@field cache? table +---@field id integer +---@field state? parser.state +---@field compileCount? integer +---@field words? table ---@class files ---@field lazyCache? lazy-cacher @@ -725,7 +725,8 @@ end ---@class parser.state ---@field diffInfo? table[] ---@field originLines? integer[] ----@field originText string +---@field originText? string +---@field lua? string --- 获取文件语法树 ---@param uri uri diff --git a/script/global.d.lua b/script/global.d.lua index f84ff0e4..b44d6371 100644 --- a/script/global.d.lua +++ b/script/global.d.lua @@ -55,6 +55,12 @@ DOC = '' ---@type string | '"Error"' | '"Warning"' | '"Information"' | '"Hint"' CHECKLEVEL = 'Warning' +--Where to write the check results (JSON). +-- +--If nil, use `LOGPATH/check.json`. +---@type string|nil +CHECK_OUT_PATH = '' + ---@type 'trace' | 'debug' | 'info' | 'warn' | 'error' LOGLEVEL = 'warn' @@ -77,3 +83,6 @@ jit = false -- connect to client by socket ---@type integer SOCKET = 0 + +-- Allowing the use of the root directory or home directory as the workspace +FORCE_ACCEPT_WORKSPACE = false diff --git a/script/library.lua b/script/library.lua index 3a9bbbc6..4446797a 100644 --- a/script/library.lua +++ b/script/library.lua @@ -469,34 +469,45 @@ end local hasAsked = {} ---@async -local function askFor3rd(uri, cfg) +local function askFor3rd(uri, cfg, checkThirdParty) if hasAsked[cfg.name] then return nil end - hasAsked[cfg.name] = true - local yes1 = lang.script.WINDOW_APPLY_WHIT_SETTING - local yes2 = lang.script.WINDOW_APPLY_WHITOUT_SETTING - local no = lang.script.WINDOW_DONT_SHOW_AGAIN - local result = client.awaitRequestMessage('Info' - , lang.script('WINDOW_ASK_APPLY_LIBRARY', cfg.name) - , {yes1, yes2, no} - ) - if not result then - return nil - end - if result == yes1 then + + if checkThirdParty == 'Apply' then apply3rd(uri, cfg, false) - elseif result == yes2 then + elseif checkThirdParty == 'ApplyInMemory' then apply3rd(uri, cfg, true) - else - client.setConfig({ - { - key = 'Lua.workspace.checkThirdParty', - action = 'set', - value = false, - uri = uri, - }, - }, false) + elseif checkThirdParty == 'Disable' then + return nil + elseif checkThirdParty == 'Ask' then + hasAsked[cfg.name] = true + local applyAndSetConfig = lang.script.WINDOW_APPLY_WHIT_SETTING + local applyInMemory = lang.script.WINDOW_APPLY_WHITOUT_SETTING + local dontShowAgain = lang.script.WINDOW_DONT_SHOW_AGAIN + local result = client.awaitRequestMessage('Info' + , lang.script('WINDOW_ASK_APPLY_LIBRARY', cfg.name) + , {applyAndSetConfig, applyInMemory, dontShowAgain} + ) + if not result then + -- "If none got selected" + -- See: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#window_showMessageRequest + return nil + end + if result == applyAndSetConfig then + apply3rd(uri, cfg, false) + elseif result == applyInMemory then + apply3rd(uri, cfg, true) + else + client.setConfig({ + { + key = 'Lua.workspace.checkThirdParty', + action = 'set', + value = 'Disable', + uri = uri, + }, + }, false) + end end end @@ -517,7 +528,7 @@ local function wholeMatch(a, b) return true end -local function check3rdByWords(uri, configs) +local function check3rdByWords(uri, configs, checkThirdParty) if not files.isLua(uri) then return end @@ -546,7 +557,7 @@ local function check3rdByWords(uri, configs) log.info('Found 3rd library by word: ', word, uri, library, inspect(config.get(uri, 'Lua.workspace.library'))) ---@async await.call(function () - askFor3rd(uri, cfg) + askFor3rd(uri, cfg, checkThirdParty) end) return end @@ -556,7 +567,7 @@ local function check3rdByWords(uri, configs) end, id) end -local function check3rdByFileName(uri, configs) +local function check3rdByFileName(uri, configs, checkThirdParty) local path = ws.getRelativePath(uri) if not path then return @@ -582,7 +593,7 @@ local function check3rdByFileName(uri, configs) log.info('Found 3rd library by filename: ', filename, uri, library, inspect(config.get(uri, 'Lua.workspace.library'))) ---@async await.call(function () - askFor3rd(uri, cfg) + askFor3rd(uri, cfg, checkThirdParty) end) return end @@ -597,8 +608,12 @@ local function check3rd(uri) if ws.isIgnored(uri) then return end - if not config.get(uri, 'Lua.workspace.checkThirdParty') then + local checkThirdParty = config.get(uri, 'Lua.workspace.checkThirdParty') + -- Backwards compatability: `checkThirdParty` used to be a boolean. + if not checkThirdParty or checkThirdParty == 'Disable' then return + elseif checkThirdParty == true then + checkThirdParty = 'Ask' end local scp = scope.getScope(uri) if not scp:get 'canCheckThirdParty' then @@ -608,8 +623,8 @@ local function check3rd(uri) if not thirdConfigs then return end - check3rdByWords(uri, thirdConfigs) - check3rdByFileName(uri, thirdConfigs) + check3rdByWords(uri, thirdConfigs, checkThirdParty) + check3rdByFileName(uri, thirdConfigs, checkThirdParty) end local function check3rdOfWorkspace(suri) diff --git a/script/parser/compile.lua b/script/parser/compile.lua index e09c958f..5321d9b8 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -470,7 +470,7 @@ end local function parseLongString() local start, finish, mark = sfind(Lua, '^(%[%=*%[)', Tokens[Index]) - if not mark then + if not start then return nil end fastForwardToken(finish + 1) @@ -2313,7 +2313,12 @@ local function parseFunction(isLocal, isAction) local params if func.name and func.name.type == 'getmethod' then if func.name.type == 'getmethod' then - params = {} + params = { + type = 'funcargs', + start = funcRight, + finish = funcRight, + parent = func + } params[1] = createLocal { start = funcRight, finish = funcRight, diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 0d190ed5..4e71c832 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -56,7 +56,7 @@ local type = type ---@field returnIndex integer ---@field assignIndex integer ---@field docIndex integer ----@field docs parser.object[] +---@field docs parser.object ---@field state table ---@field comment table ---@field optional boolean @@ -74,7 +74,9 @@ local type = type ---@field hasBreak? true ---@field hasExit? true ---@field [integer] parser.object|any ----@field package _root parser.object +---@field package _root parser.object +---@field package _eachCache? parser.object[] +---@field package _isGlobal? boolean ---@class guide ---@field debugMode boolean @@ -154,10 +156,10 @@ local childMap = { ['unary'] = {1}, ['doc'] = {'#'}, - ['doc.class'] = {'class', '#extends', '#signs', 'comment'}, + ['doc.class'] = {'class', '#extends', '#signs', 'docAttr', 'comment'}, ['doc.type'] = {'#types', 'name', 'comment'}, ['doc.alias'] = {'alias', 'extends', 'comment'}, - ['doc.enum'] = {'enum', 'extends', 'comment'}, + ['doc.enum'] = {'enum', 'extends', 'comment', 'docAttr'}, ['doc.param'] = {'param', 'extends', 'comment'}, ['doc.return'] = {'#returns', 'comment'}, ['doc.field'] = {'field', 'extends', 'comment'}, @@ -180,6 +182,7 @@ local childMap = { ['doc.cast.block'] = {'extends'}, ['doc.operator'] = {'op', 'exp', 'extends'}, ['doc.meta'] = {'name'}, + ['doc.attr'] = {'#names'}, } ---@type table<string, fun(obj: parser.object, list: parser.object[])> @@ -736,7 +739,7 @@ end --- 遍历所有指定类型的source ---@param ast parser.object ---@param type string ----@param callback fun(src: parser.object) +---@param callback fun(src: parser.object): any ---@return any function m.eachSourceType(ast, type, callback) local cache = getSourceTypeCache(ast) diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 9f3b8fd5..8e7d334f 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -156,6 +156,7 @@ Symbol <- ({} { ---@field calls? parser.object[] ---@field generics? parser.object[] ---@field generic? parser.object +---@field docAttr? parser.object local function parseTokens(text, offset) Ci = 0 @@ -252,6 +253,40 @@ local function nextSymbolOrError(symbol) return false end +local function parseDocAttr(parent) + if not checkToken('symbol', '(', 1) then + return nil + end + nextToken() + + local attrs = { + type = 'doc.attr', + parent = parent, + start = getStart(), + finish = getStart(), + names = {}, + } + + while true do + if checkToken('symbol', ',', 1) then + nextToken() + goto continue + end + local name = parseName('doc.attr.name', attrs) + if not name then + break + end + attrs.names[#attrs.names+1] = name + attrs.finish = name.finish + ::continue:: + end + + nextSymbolOrError(')') + attrs.finish = getFinish() + + return attrs +end + local function parseIndexField(parent) if not checkToken('symbol', '[', 1) then return nil @@ -806,6 +841,7 @@ local docSwitch = util.switch() operators = {}, calls = {}, } + result.docAttr = parseDocAttr(result) result.class = parseName('doc.class.name', result) if not result.class then pushWarning { @@ -1108,13 +1144,13 @@ local docSwitch = util.switch() end) : case 'meta' : call(function () - local requireName = parseName('doc.meta.name') - return { + local meta = { type = 'doc.meta', - name = requireName, start = getFinish(), finish = getFinish(), } + meta.name = parseName('doc.meta.name', meta) + return meta end) : case 'version' : call(function () @@ -1428,17 +1464,22 @@ local docSwitch = util.switch() end) : case 'enum' : call(function () + local attr = parseDocAttr() local name = parseName('doc.enum.name') if not name then return nil end local result = { - type = 'doc.enum', - start = name.start, - finish = name.finish, - enum = name, + type = 'doc.enum', + start = name.start, + finish = name.finish, + enum = name, + docAttr = attr, } name.parent = result + if attr then + attr.parent = result + end return result end) : case 'private' @@ -1534,7 +1575,7 @@ local function buildLuaDoc(comment) parseTokens(doc, startOffset + startPos) local result, rests = convertTokens(doc) if result then - result.range = comment.finish + result.range = math.max(comment.finish, result.finish) local finish = result.firstFinish or result.finish if rests then for _, rest in ipairs(rests) do @@ -1669,7 +1710,9 @@ local function bindDocWithSource(doc, source) if not source.bindDocs then source.bindDocs = {} end - source.bindDocs[#source.bindDocs+1] = doc + if source.bindDocs[#source.bindDocs] ~= doc then + source.bindDocs[#source.bindDocs+1] = doc + end doc.bindSource = source end @@ -1820,7 +1863,7 @@ local function bindDocsBetween(sources, binded, start, finish) or src.type == 'setindex' or src.type == 'setmethod' or src.type == 'function' - or src.type == 'table' + or src.type == 'return' or src.type == '...' then if bindDoc(src, binded) then ok = true @@ -1940,7 +1983,7 @@ local bindDocAccept = { 'local' , 'setlocal' , 'setglobal', 'setfield' , 'setmethod' , 'setindex' , 'tablefield', 'tableindex', 'self' , - 'function' , 'table' , '...' , + 'function' , 'return' , '...' , } local function bindDocs(state) diff --git a/script/plugins/ffi/c-parser/ctypes.lua b/script/plugins/ffi/c-parser/ctypes.lua index 81d0ccf6..115f78ab 100644 --- a/script/plugins/ffi/c-parser/ctypes.lua +++ b/script/plugins/ffi/c-parser/ctypes.lua @@ -149,7 +149,10 @@ end local function add_to_fields(lst, field_src, fields) if type(field_src) == "table" and not field_src.ids then assert(field_src.type.type == "union") - local subfields = get_fields(lst, field_src.type.fields) + local subfields, err = get_fields(lst, field_src.type.fields) + if not subfields then + return nil, err + end for _, subfield in ipairs(subfields) do table.insert(fields, subfield) end diff --git a/script/plugins/ffi/cdefRerence.lua b/script/plugins/ffi/cdefRerence.lua index 14643f0f..54a8c2a7 100644 --- a/script/plugins/ffi/cdefRerence.lua +++ b/script/plugins/ffi/cdefRerence.lua @@ -20,7 +20,7 @@ end return function () local ffi_state for uri in files.eachFile() do - if find(uri, "ffi.lua", 0, true) and find(uri, "lua-language-server", 0, true) then + if find(uri, "ffi.lua", 0, true) and find(uri, "meta", 0, true) then ffi_state = files.getState(uri) break end diff --git a/script/proto/converter.lua b/script/proto/converter.lua index a723face..e86e4904 100644 --- a/script/proto/converter.lua +++ b/script/proto/converter.lua @@ -207,6 +207,14 @@ function m.setOffsetEncoding(encoding) offsetEncoding = encoding:lower():gsub('%-', '') end +---@param s string +---@param i? integer +---@param j? integer +---@return integer +function m.len(s, i, j) + return encoder.len(offsetEncoding, s, i, j) +end + ---@class proto.command ---@field title string ---@field command string diff --git a/script/proto/diagnostic.lua b/script/proto/diagnostic.lua index 8175a2c5..61b8ff4b 100644 --- a/script/proto/diagnostic.lua +++ b/script/proto/diagnostic.lua @@ -62,6 +62,7 @@ m.register { 'missing-return-value', 'redundant-return-value', 'missing-return', + 'missing-fields', } { group = 'unbalanced', severity = 'Warning', @@ -76,6 +77,7 @@ m.register { 'param-type-mismatch', 'cast-type-mismatch', 'return-type-mismatch', + 'inject-field', } { group = 'type-check', severity = 'Warning', diff --git a/script/proto/proto.lua b/script/proto/proto.lua index 73544ffc..d01c8f36 100644 --- a/script/proto/proto.lua +++ b/script/proto/proto.lua @@ -8,6 +8,9 @@ local define = require 'proto.define' local json = require 'json' local inspect = require 'inspect' local thread = require 'bee.thread' +local fs = require 'bee.filesystem' +local net = require 'service.net' +local timer = require 'timer' local reqCounter = util.counter() @@ -32,8 +35,7 @@ m.ability = {} m.waiting = {} m.holdon = {} m.mode = 'stdio' ----@type bee.socket.fd -m.fd = nil +m.client = nil function m.getMethodName(proto) if proto.method:sub(1, 2) == '$/' then @@ -54,7 +56,8 @@ function m.send(data) if m.mode == 'stdio' then io.write(buf) elseif m.mode == 'socket' then - m.fd:send(buf) + m.client:write(buf) + net.update() end end @@ -237,13 +240,37 @@ function m.listen(mode, socketPort) io.stdout:setvbuf 'no' pub.task('loadProtoByStdio') elseif mode == 'socket' then - local rfd = assert(socket('tcp')) - rfd:connect('127.0.0.1', socketPort) - local wfd1, wfd2 = socket.pair() - m.fd = wfd1 + local unixFolder = LOGPATH .. '/unix' + fs.create_directories(fs.path(unixFolder)) + local unixPath = unixFolder .. '/' .. tostring(socketPort) + + local server = net.listen('unix', unixPath) + + assert(server) + + local dummyClient = { + buf = '', + write = function (self, data) + self.buf = self.buf.. data + end, + update = function () end, + } + m.client = dummyClient + + local t = timer.loop(0.1, function () + net.update() + end) + + function server:on_accept(client) + t:remove() + m.client = client + client:write(dummyClient.buf) + net.update() + end + pub.task('loadProtoBySocket', { - wfd = wfd2:detach(), - rfd = rfd:detach(), + port = socketPort, + unixPath = unixPath, }) end end diff --git a/script/provider/formatting.lua b/script/provider/formatting.lua index 73f9a534..ea94db08 100644 --- a/script/provider/formatting.lua +++ b/script/provider/formatting.lua @@ -82,16 +82,13 @@ function m.updateNonStandardSymbols(symbols) return end - local eqTokens = {} - for _, token in ipairs(symbols) do - if token:find("=") and token ~= "!=" then - table.insert(eqTokens, token) + for _, symbol in ipairs(symbols) do + if symbol == "//" then + codeFormat.set_clike_comments_symbol() end end - if #eqTokens ~= 0 then - codeFormat.set_nonstandard_symbol() - end + codeFormat.set_nonstandard_symbol() end config.watch(function(uri, key, value) diff --git a/script/provider/provider.lua b/script/provider/provider.lua index 787cfeb8..a791e980 100644 --- a/script/provider/provider.lua +++ b/script/provider/provider.lua @@ -712,11 +712,11 @@ m.register 'completionItem/resolve' { --await.setPriority(1000) local state = files.getState(uri) if not state then - return nil + return item end local resolved = core.resolve(id) if not resolved then - return nil + return item end item.detail = resolved.detail or item.detail item.documentation = resolved.description and { @@ -772,8 +772,8 @@ m.register 'textDocument/signatureHelp' { for j, param in ipairs(result.params) do parameters[j] = { label = { - param.label[1], - param.label[2], + converter.len(result.label, 1, param.label[1]), + converter.len(result.label, 1, param.label[2]), } } end @@ -904,7 +904,7 @@ m.register 'textDocument/codeLens' { resolveProvider = true, } }, - abortByFileUpdate = true, + --abortByFileUpdate = true, ---@async function (params) local uri = files.getRealUri(params.textDocument.uri) diff --git a/script/service/net.lua b/script/service/net.lua index 61603d79..2019406e 100644 --- a/script/service/net.lua +++ b/script/service/net.lua @@ -1,42 +1,39 @@ local socket = require "bee.socket" +local select = require "bee.select" +local selector = select.create() +local SELECT_READ <const> = select.SELECT_READ +local SELECT_WRITE <const> = select.SELECT_WRITE -local readfds = {} -local writefds = {} -local map = {} - -local function FD_SET(set, fd) - for i = 1, #set do - if fd == set[i] then - return - end +local function fd_set_read(s) + if s._flags & SELECT_READ ~= 0 then + return end - set[#set+1] = fd + s._flags = s._flags | SELECT_READ + selector:event_mod(s._fd, s._flags) end -local function FD_CLR(set, fd) - for i = 1, #set do - if fd == set[i] then - set[i] = set[#set] - set[#set] = nil - return - end +local function fd_clr_read(s) + if s._flags & SELECT_READ == 0 then + return end + s._flags = s._flags & (~SELECT_READ) + selector:event_mod(s._fd, s._flags) end -local function fd_set_read(fd) - FD_SET(readfds, fd) -end - -local function fd_clr_read(fd) - FD_CLR(readfds, fd) -end - -local function fd_set_write(fd) - FD_SET(writefds, fd) +local function fd_set_write(s) + if s._flags & SELECT_WRITE ~= 0 then + return + end + s._flags = s._flags | SELECT_WRITE + selector:event_mod(s._fd, s._flags) end -local function fd_clr_write(fd) - FD_CLR(writefds, fd) +local function fd_clr_write(s) + if s._flags & SELECT_WRITE == 0 then + return + end + s._flags = s._flags & (~SELECT_WRITE) + selector:event_mod(s._fd, s._flags) end local function on_event(self, name, ...) @@ -49,8 +46,8 @@ end local function close(self) local fd = self._fd on_event(self, "close") + selector:event_del(fd) fd:close() - map[fd] = nil end local stream_mt = {} @@ -69,7 +66,7 @@ function stream:write(data) return end if self._writebuf == "" then - fd_set_write(self._fd) + fd_set_write(self) end self._writebuf = self._writebuf .. data end @@ -79,35 +76,17 @@ end function stream:close() if not self.shutdown_r then self.shutdown_r = true - fd_clr_read(self._fd) + fd_clr_read(self) end if self.shutdown_w or self._writebuf == "" then self.shutdown_w = true - fd_clr_write(self._fd) + fd_clr_write(self) close(self) end end -function stream:update(timeout) - local fd = self._fd - local r = {fd} - local w = r - if self._writebuf == "" then - w = nil - end - local rd, wr = socket.select(r, w, timeout or 0) - if rd then - if #rd > 0 then - self:select_r() - end - if #wr > 0 then - self:select_w() - end - end -end local function close_write(self) - fd_clr_write(self._fd) + fd_clr_write(self) if self.shutdown_r then - fd_clr_read(self._fd) close(self) end end @@ -125,6 +104,7 @@ function stream:select_w() if n == nil then self.shutdown_w = true close_write(self) + elseif n == false then else self._writebuf = self._writebuf:sub(n + 1) if self._writebuf == "" then @@ -132,26 +112,43 @@ function stream:select_w() end end end +local function update_stream(s, event) + if event & SELECT_READ ~= 0 then + s:select_r() + end + if event & SELECT_WRITE ~= 0 then + s:select_w() + end +end local function accept_stream(fd) - local self = setmetatable({ + local s = setmetatable({ _fd = fd, + _flags = SELECT_READ, _event = {}, _writebuf = "", shutdown_r = false, shutdown_w = false, }, stream_mt) - map[fd] = self - fd_set_read(fd) - return self -end -local function connect_stream(self) - setmetatable(self, stream_mt) - fd_set_read(self._fd) - if self._writebuf ~= "" then - self:select_w() + selector:event_add(fd, SELECT_READ, function (event) + update_stream(s, event) + end) + return s +end +local function connect_stream(s) + setmetatable(s, stream_mt) + selector:event_del(s._fd) + if s._writebuf ~= "" then + s._flags = SELECT_READ | SELECT_WRITE + selector:event_add(s._fd, SELECT_READ | SELECT_WRITE, function (event) + update_stream(s, event) + end) + s:select_w() else - fd_clr_write(self._fd) + s._flags = SELECT_READ + selector:event_add(s._fd, SELECT_READ, function (event) + update_stream(s, event) + end) end end @@ -169,35 +166,32 @@ function listen:is_closed() end function listen:close() self.shutdown_r = true - fd_clr_read(self._fd) close(self) end -function listen:update(timeout) - local fd = self._fd - local r = {fd} - local rd = socket.select(r, nil, timeout or 0) - if rd then - if #rd > 0 then - self:select_r() - end - end -end -function listen:select_r() - local newfd = self._fd:accept() - if newfd:status() then - local news = accept_stream(newfd) - on_event(self, "accept", news) - end -end local function new_listen(fd) local s = { _fd = fd, + _flags = SELECT_READ, _event = {}, shutdown_r = false, shutdown_w = true, } - map[fd] = s - fd_set_read(fd) + selector:event_add(fd, SELECT_READ, function () + local newfd, err = fd:accept() + if not newfd then + on_event(s, "error", err) + return + end + local ok, err = newfd:status() + if not ok then + on_event(s, "error", err) + return + end + if newfd:status() then + local news = accept_stream(newfd) + on_event(s, "accept", news) + end + end) return setmetatable(s, listen_mt) end @@ -220,39 +214,27 @@ function connect:is_closed() end function connect:close() self.shutdown_w = true - fd_clr_write(self._fd) close(self) end -function connect:update(timeout) - local fd = self._fd - local w = {fd} - local rd, wr = socket.select(nil, w, timeout or 0) - if rd then - if #wr > 0 then - self:select_w() - end - end -end -function connect:select_w() - local ok, err = self._fd:status() - if ok then - connect_stream(self) - on_event(self, "connect") - else - on_event(self, "error", err) - self:close() - end -end local function new_connect(fd) local s = { _fd = fd, + _flags = SELECT_WRITE, _event = {}, _writebuf = "", shutdown_r = false, shutdown_w = false, } - map[fd] = s - fd_set_write(fd) + selector:event_add(fd, SELECT_WRITE, function () + local ok, err = fd:status() + if ok then + connect_stream(s) + on_event(s, "connect") + else + on_event(s, "error", err) + s:close() + end + end) return setmetatable(s, connect_mt) end @@ -292,18 +274,8 @@ function m.connect(protocol, ...) end function m.update(timeout) - local rd, wr = socket.select(readfds, writefds, timeout or 0) - if rd then - for i = 1, #rd do - local fd = rd[i] - local s = map[fd] - s:select_r() - end - for i = 1, #wr do - local fd = wr[i] - local s = map[fd] - s:select_w() - end + for func, event in selector:wait(timeout or 0) do + func(event) end end diff --git a/script/service/service.lua b/script/service/service.lua index 7011ec4f..b6056390 100644 --- a/script/service/service.lua +++ b/script/service/service.lua @@ -257,7 +257,6 @@ function m.lockCache() if err then log.error(err) end - pub.task('removeCaches', cacheDir) end function m.start() diff --git a/script/utility.lua b/script/utility.lua index be945791..936726f9 100644 --- a/script/utility.lua +++ b/script/utility.lua @@ -190,38 +190,48 @@ function m.dump(tbl, option) end --- 递归判断A与B是否相等 ----@param a any ----@param b any +---@param valueA any +---@param valueB any ---@return boolean -function m.equal(a, b) - local tp1 = type(a) - local tp2 = type(b) - if tp1 ~= tp2 then - return false - end - if tp1 == 'table' then - local mark = {} - for k, v in pairs(a) do - mark[k] = true - local res = m.equal(v, b[k]) - if not res then - return false - end +function m.equal(valueA, valueB) + local hasChecked = {} + + local function equal(a, b) + local tp1 = type(a) + local tp2 = type(b) + if tp1 ~= tp2 then + return false end - for k in pairs(b) do - if not mark[k] then - return false + if tp1 == 'table' then + if hasChecked[a] then + return true + end + hasChecked[a] = true + local mark = {} + for k, v in pairs(a) do + mark[k] = true + local res = equal(v, b[k]) + if not res then + return false + end + end + for k in pairs(b) do + if not mark[k] then + return false + end end - end - return true - elseif tp1 == 'number' then - if mathAbs(a - b) <= 1e-10 then return true + elseif tp1 == 'number' then + if mathAbs(a - b) <= 1e-10 then + return true + end + return tostring(a) == tostring(b) + else + return a == b end - return tostring(a) == tostring(b) - else - return a == b end + + return equal(valueA, valueB) end local function sortTable(tbl) @@ -283,6 +293,9 @@ end --- 读取文件 ---@param path string +---@param keepBom? boolean +---@return string? text +---@return string? errMsg function m.loadFile(path, keepBom) local f, e = ioOpen(path, 'rb') if not f then @@ -308,6 +321,8 @@ end --- 写入文件 ---@param path string ---@param content string +---@return boolean ok +---@return string? errMsg function m.saveFile(path, content) local f, e = ioOpen(path, "wb") @@ -357,6 +372,7 @@ end --- 深拷贝(不处理元表) ---@param source table ---@param target? table +---@return table function m.deepCopy(source, target) local mark = {} local function copy(a, b) @@ -379,6 +395,8 @@ function m.deepCopy(source, target) end --- 序列化 +---@param t table +---@return table function m.unpack(t) local result = {} local tid = 0 @@ -406,6 +424,8 @@ function m.unpack(t) end --- 反序列化 +---@param t table +---@return table function m.pack(t) local cache = {} local function pack(id) @@ -518,18 +538,25 @@ function m.utf8Len(str, start, finish) return len end -function m.revertTable(t) - local len = #t +-- 把数组中的元素顺序*原地*反转 +---@param arr any[] +---@return any[] +function m.revertArray(arr) + local len = #arr if len <= 1 then - return t + return arr end for x = 1, len // 2 do local y = len - x + 1 - t[x], t[y] = t[y], t[x] + arr[x], arr[y] = arr[y], arr[x] end - return t + return arr end +-- 创建一个value-key表 +---@generic K, V +---@param t table<K, V> +---@return table<V, K> function m.revertMap(t) local nt = {} for k, v in pairs(t) do @@ -624,6 +651,11 @@ function m.eachLine(text, keepNL) end end +---@alias SortByScoreCallback fun(o: any): integer + +-- 按照分数排序,分数越高越靠前 +---@param tbl any[] +---@param callbacks SortByScoreCallback | SortByScoreCallback[] function m.sortByScore(tbl, callbacks) if type(callbacks) ~= 'table' then callbacks = { callbacks } @@ -651,6 +683,16 @@ function m.sortByScore(tbl, callbacks) end) end +---@param arr any[] +---@return SortByScoreCallback +function m.sortCallbackOfIndex(arr) + ---@type table<any, integer> + local indexMap = m.revertMap(arr) + return function (v) + return - indexMap[v] + end +end + ---裁剪字符串 ---@param str string ---@param mode? '"left"'|'"right"' @@ -733,6 +775,7 @@ function switchMT:has(name) end ---@param name string +---@param ... any ---@return ... function switchMT:__call(name, ...) local callback = self.map[name] or self._default @@ -752,6 +795,8 @@ function m.switch() end ---@param f async fun() +---@param name string +---@return any, boolean function m.getUpvalue(f, name) for i = 1, 999 do local uname, value = getupvalue(f, i) @@ -819,6 +864,7 @@ end ---@param t table ---@param sorter boolean|function +---@return any[] function m.getTableKeys(t, sorter) local keys = {} for k in pairs(t) do @@ -841,6 +887,15 @@ function m.arrayHas(array, value) return false end +function m.arrayIndexOf(array, value) + for i = 1, #array do + if array[i] == value then + return i + end + end + return nil +end + function m.arrayInsert(array, value) if not m.arrayHas(array, value) then array[#array+1] = value @@ -873,4 +928,24 @@ function m.cacheReturn(func) end end +---@param a table +---@param b table +---@return table +function m.tableMerge(a, b) + for k, v in pairs(b) do + a[k] = v + end + return a +end + +---@param a any[] +---@param b any[] +---@return any[] +function m.arrayMerge(a, b) + for i = 1, #b do + a[#a+1] = b[i] + end + return a +end + return m diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 6b4636fc..8a1fa96a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -7,10 +7,14 @@ local files = require 'files' local vm = require 'vm.vm' ---@class parser.object ----@field _compiledNodes boolean ----@field _node vm.node ----@field cindex integer ----@field func parser.object +---@field _compiledNodes boolean +---@field _node vm.node +---@field cindex integer +---@field func parser.object +---@field hideView boolean +---@field package _returns? parser.object[] +---@field package _callReturns? parser.object[] +---@field package _asCache? parser.object[] -- 该函数有副作用,会给source绑定node! ---@param source parser.object @@ -28,6 +32,12 @@ function vm.bindDocs(source) end if doc.type == 'doc.class' then vm.setNode(source, vm.compileNode(doc)) + for j = i + 1, #docs do + local overload = docs[j] + if overload.type == 'doc.overload' then + overload.overload.hideView = true + end + end return true end if doc.type == 'doc.param' then @@ -55,6 +65,9 @@ function vm.bindDocs(source) vm.setNode(source, vm.compileNode(ast)) return true end + if doc.type == 'doc.overload' then + vm.setNode(source, vm.compileNode(doc)) + end end return false end @@ -473,6 +486,7 @@ function vm.getReturnOfFunction(func, index) func._returns = {} end if not func._returns[index] then + ---@diagnostic disable-next-line: missing-fields func._returns[index] = { type = 'function.return', parent = func, @@ -570,6 +584,7 @@ local function getReturn(func, index, args) end if not func._callReturns[index] then local call = func.parent + ---@diagnostic disable-next-line: missing-fields func._callReturns[index] = { type = 'call.return', parent = call, @@ -756,6 +771,11 @@ function vm.selectNode(list, index) if not exp then return vm.createNode(vm.declareGlobal('type', 'nil')), nil end + + if vm.bindDocs(list) then + return vm.compileNode(list), exp + end + ---@type vm.node? local result if exp.type == 'call' then @@ -862,52 +882,69 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) end end - for n in callNode:eachObject() do - if n.type == 'function' then - ---@cast n parser.object - local sign = vm.getSign(n) + ---@param n parser.object + local function dealDocFunc(n) + local myEvent + if n.args[eventIndex] then + local argNode = vm.compileNode(n.args[eventIndex]) + myEvent = argNode:get(1) + end + if not myEvent + or not eventMap + or myIndex <= eventIndex + or myEvent.type ~= 'doc.type.string' + or eventMap[myEvent[1]] then local farg = getFuncArg(n, myIndex) if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then - if fn.type == 'doc.type.function' then - ---@cast fn parser.object - if sign then - local generic = vm.createGeneric(fn, sign) - local args = {} - for i = fixIndex + 1, myIndex - 1 do - args[#args+1] = call.args[i] - end - local resolvedNode = generic:resolve(guide.getUri(call), args) - vm.setNode(arg, resolvedNode) - goto CONTINUE - end - end vm.setNode(arg, fn) - ::CONTINUE:: end end end end - if n.type == 'doc.type.function' then - ---@cast n parser.object - local myEvent - if n.args[eventIndex] then - local argNode = vm.compileNode(n.args[eventIndex]) - myEvent = argNode:get(1) - end - if not myEvent - or not eventMap - or myIndex <= eventIndex - or myEvent.type ~= 'doc.type.string' - or eventMap[myEvent[1]] then - local farg = getFuncArg(n, myIndex) - if farg then - for fn in vm.compileNode(farg):eachObject() do - if isValidCallArgNode(arg, fn) then - vm.setNode(arg, fn) + end + + ---@param n parser.object + local function dealFunction(n) + local sign = vm.getSign(n) + local farg = getFuncArg(n, myIndex) + if farg then + for fn in vm.compileNode(farg):eachObject() do + if isValidCallArgNode(arg, fn) then + if fn.type == 'doc.type.function' then + ---@cast fn parser.object + if sign then + local generic = vm.createGeneric(fn, sign) + local args = {} + for i = fixIndex + 1, myIndex - 1 do + args[#args+1] = call.args[i] + end + local resolvedNode = generic:resolve(guide.getUri(call), args) + vm.setNode(arg, resolvedNode) + goto CONTINUE end end + vm.setNode(arg, fn) + ::CONTINUE:: + end + end + end + end + + for n in callNode:eachObject() do + if n.type == 'function' then + ---@cast n parser.object + dealFunction(n) + elseif n.type == 'doc.type.function' then + ---@cast n parser.object + dealDocFunc(n) + elseif n.type == 'global' and n.cate == 'type' then + ---@cast n vm.global + local overloads = vm.getOverloadsByTypeName(n.name, guide.getUri(arg)) + if overloads then + for _, func in ipairs(overloads) do + dealDocFunc(func) end end end @@ -1020,6 +1057,7 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(source.value)) end end + -- function x.y(self, ...) --> function x:y(...) if source[1] == 'self' and not hasMarkDoc @@ -1031,6 +1069,7 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(setfield.node)) end end + if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then local func = source.parent.parent -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number @@ -1055,6 +1094,7 @@ local function compileLocal(source) vm.setNode(source, vm.declareGlobal('type', 'any')) end end + -- for x in ... do if source.parent.type == 'in' then compileForVars(source.parent, source) @@ -1094,6 +1134,12 @@ local function compileLocal(source) end end + if source.value + and source.value.type == 'nil' + and not myNode:hasKnownType() then + vm.setNode(source, vm.compileNode(source.value)) + end + myNode.hasDefined = hasMarkDoc or hasMarkParam or hasMarkValue end @@ -1140,6 +1186,9 @@ local compilerSwitch = util.switch() end) : case 'table' : call(function (source) + if vm.bindAs(source) then + return + end vm.setNode(source, source) if source.parent.type == 'callargs' then @@ -1147,6 +1196,16 @@ local compilerSwitch = util.switch() vm.compileCallArg(source, call) end + if source.parent.type == 'return' then + local myIndex = util.arrayIndexOf(source.parent, source) + ---@cast myIndex -? + local parentNode = vm.selectNode(source.parent, myIndex) + if not parentNode:isEmpty() then + vm.setNode(source, parentNode) + return + end + end + if source.parent.type == 'setglobal' or source.parent.type == 'local' or source.parent.type == 'setlocal' @@ -1648,7 +1707,7 @@ local compilerSwitch = util.switch() if state.type == 'doc.return' or state.type == 'doc.param' then local func = state.bindSource - if func.type == 'function' then + if func and func.type == 'function' then local node = guide.getFunctionSelfNode(func) if node then vm.setNode(source, vm.compileNode(node)) diff --git a/script/vm/doc.lua b/script/vm/doc.lua index a6ea248f..6ac39910 100644 --- a/script/vm/doc.lua +++ b/script/vm/doc.lua @@ -5,7 +5,11 @@ local vm = require 'vm.vm' local config = require 'config' ---@class parser.object ----@field package _castTargetHead parser.object | vm.global | false +---@field package _castTargetHead? parser.object | vm.global | false +---@field package _validVersions? table<string, boolean> +---@field package _deprecated? parser.object | false +---@field package _async? boolean +---@field package _nodiscard? boolean ---获取class与alias ---@param suri uri @@ -467,3 +471,18 @@ function vm.getCastTargetHead(doc) end return nil end + +---@param doc parser.object +---@param key string +---@return boolean +function vm.docHasAttr(doc, key) + if not doc.docAttr then + return false + end + for _, name in ipairs(doc.docAttr.names) do + if name[1] == key then + return true + end + end + return false +end diff --git a/script/vm/function.lua b/script/vm/function.lua index bdd7e229..c6df6349 100644 --- a/script/vm/function.lua +++ b/script/vm/function.lua @@ -372,8 +372,14 @@ function vm.isVarargFunctionWithOverloads(func) if not func.args then return false end - if not func.args[1] or func.args[1].type ~= '...' then - return false + if func.args[1] and func.args[1].type == 'self' then + if not func.args[2] or func.args[2].type ~= '...' then + return false + end + else + if not func.args[1] or func.args[1].type ~= '...' then + return false + end end if not func.bindDocs then return false diff --git a/script/vm/global.lua b/script/vm/global.lua index c1b5f320..e830f6d8 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -1,6 +1,7 @@ local util = require 'utility' local scope = require 'workspace.scope' local guide = require 'parser.guide' +local config = require 'config' ---@class vm local vm = require 'vm.vm' @@ -371,16 +372,36 @@ local compilerGlobalSwitch = util.switch() return end source._enums = {} - for _, field in ipairs(tbl) do - if field.type == 'tablefield' then - source._enums[#source._enums+1] = field - local subType = vm.declareGlobal('type', name .. '.' .. field.field[1], uri) - subType:addSet(uri, field) - elseif field.type == 'tableindex' then - source._enums[#source._enums+1] = field - if field.index.type == 'string' then - local subType = vm.declareGlobal('type', name .. '.' .. field.index[1], uri) + if vm.docHasAttr(source, 'key') then + for _, field in ipairs(tbl) do + if field.type == 'tablefield' then + source._enums[#source._enums+1] = { + type = 'doc.type.string', + start = field.field.start, + finish = field.field.finish, + [1] = field.field[1], + } + elseif field.type == 'tableindex' then + source._enums[#source._enums+1] = { + type = 'doc.type.string', + start = field.index.start, + finish = field.index.finish, + [1] = field.index[1], + } + end + end + else + for _, field in ipairs(tbl) do + if field.type == 'tablefield' then + source._enums[#source._enums+1] = field + local subType = vm.declareGlobal('type', name .. '.' .. field.field[1], uri) subType:addSet(uri, field) + elseif field.type == 'tableindex' then + source._enums[#source._enums+1] = field + if field.index.type == 'string' then + local subType = vm.declareGlobal('type', name .. '.' .. field.index[1], uri) + subType:addSet(uri, field) + end end end end @@ -518,6 +539,33 @@ function vm.hasGlobalSets(suri, cate, name) return true end +---@param src parser.object +local function checkIsUndefinedGlobal(src) + local key = src[1] + + local uri = guide.getUri(src) + local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) + local rspecial = config.get(uri, 'Lua.runtime.special') + + local node = src.node + return src.type == 'getglobal' and key and not ( + dglobals[key] or + rspecial[key] or + node.tag ~= '_ENV' or + vm.hasGlobalSets(uri, 'variable', key) + ) +end + +---@param src parser.object +---@return boolean +function vm.isUndefinedGlobal(src) + local node = vm.compileNode(src) + if node.undefinedGlobal == nil then + node.undefinedGlobal = checkIsUndefinedGlobal(src) + end + return node.undefinedGlobal +end + ---@param source parser.object function compileObject(source) if source._globalNode ~= nil then @@ -593,6 +641,7 @@ function vm.getGlobalBase(source) end local name = global:asKeyName() if not root._globalBaseMap[name] then + ---@diagnostic disable-next-line: missing-fields root._globalBaseMap[name] = { type = 'globalbase', parent = root, diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 94fdfd88..f2673ed3 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -386,9 +386,11 @@ function mt:_computeViews(uri) self.views = {} for n in self.node:eachObject() do - local view = viewNodeSwitch(n.type, n, self, uri) - if view then - self.views[view] = true + if not n.hideView then + local view = viewNodeSwitch(n.type, n, self, uri) + if view then + self.views[view] = true + end end end @@ -565,11 +567,12 @@ function vm.viewKey(source, uri) return vm.viewKey(source.types[1], uri) else local key = vm.getInfer(source):view(uri) - return '[' .. key .. ']' + return '[' .. key .. ']', key end end if source.type == 'tableindex' - or source.type == 'setindex' then + or source.type == 'setindex' + or source.type == 'getindex' then local index = source.index local name = vm.getInfer(index):viewLiterals() if not name then @@ -587,7 +590,11 @@ function vm.viewKey(source, uri) return vm.viewKey(source.name, uri) end if source.type == 'doc.type.name' then - return '[' .. source[1] .. ']' + return '[' .. source[1] .. ']', source[1] + end + if source.type == 'doc.type.string' then + local name = util.viewString(source[1], source[2]) + return ('[%s]'):format(name), name end local key = vm.getKeyName(source) if key == nil then diff --git a/script/vm/node.lua b/script/vm/node.lua index 0ffd8c70..bc1dfcb1 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -15,6 +15,7 @@ vm.nodeCache = setmetatable({}, util.MODE_K) ---@field [integer] vm.node.object ---@field [vm.node.object] true ---@field fields? table<vm.node|string, vm.node> +---@field undefinedGlobal boolean? local mt = {} mt.__index = mt mt.id = 0 diff --git a/script/vm/operator.lua b/script/vm/operator.lua index bc8703c6..7ce2b30d 100644 --- a/script/vm/operator.lua +++ b/script/vm/operator.lua @@ -126,6 +126,7 @@ vm.unarySwich = util.switch() if result == nil then vm.setNode(source, vm.declareGlobal('type', 'boolean')) else + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = 'boolean', start = source.start, @@ -155,6 +156,7 @@ vm.unarySwich = util.switch() vm.setNode(source, node or vm.declareGlobal('type', 'number')) end else + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = 'number', start = source.start, @@ -171,6 +173,7 @@ vm.unarySwich = util.switch() local node = vm.runOperator('bnot', source[1]) vm.setNode(source, node or vm.declareGlobal('type', 'integer')) else + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = 'integer', start = source.start, @@ -223,6 +226,7 @@ vm.binarySwitch = util.switch() if source.op.type == '~=' then result = not result end + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = 'boolean', start = source.start, @@ -247,6 +251,7 @@ vm.binarySwitch = util.switch() or op == '&' and a & b or op == '|' and a | b or op == '~' and a ~ b + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = 'integer', start = source.start, @@ -285,6 +290,7 @@ vm.binarySwitch = util.switch() or op == '%' and a % b or op == '//' and a // b or op == '^' and a ^ b + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = (op == '//' or math.type(result) == 'integer') and 'integer' or 'number', start = source.start, @@ -364,6 +370,7 @@ vm.binarySwitch = util.switch() end end end + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = 'string', start = source.start, @@ -407,6 +414,7 @@ vm.binarySwitch = util.switch() or op == '<' and a < b or op == '>=' and a >= b or op == '<=' and a <= b + ---@diagnostic disable-next-line: missing-fields vm.setNode(source, { type = 'boolean', start = source.start, diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 1f434475..38cb2242 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -142,13 +142,15 @@ function mt:resolve(uri, args) end if object.type == 'doc.type.function' then for i, arg in ipairs(object.args) do - for n in node:eachObject() do - if n.type == 'function' - or n.type == 'doc.type.function' then - ---@cast n parser.object - local farg = n.args and n.args[i] - if farg then - resolve(arg.extends, vm.compileNode(farg)) + if arg.extends then + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local farg = n.args and n.args[i] + if farg then + resolve(arg.extends, vm.compileNode(farg)) + end end end end @@ -254,7 +256,7 @@ function mt:resolve(uri, args) local argNode = vm.compileNode(arg) local knownTypes, genericNames = getSignInfo(sign) if not isAllResolved(genericNames) then - local newArgNode = buildArgNode(argNode,sign, knownTypes) + local newArgNode = buildArgNode(argNode, sign, knownTypes) resolve(sign, newArgNode) end end diff --git a/script/vm/type.lua b/script/vm/type.lua index 8382eb86..545d2de5 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -143,6 +143,9 @@ end ---@param errs? typecheck.err[] ---@return boolean? local function checkChildEnum(childName, parent , uri, mark, errs) + if mark[childName] then + return + end local childClass = vm.getGlobal('type', childName) if not childClass then return nil @@ -157,11 +160,14 @@ local function checkChildEnum(childName, parent , uri, mark, errs) if not enums then return nil end + mark[childName] = true for _, enum in ipairs(enums) do if not vm.isSubType(uri, vm.compileNode(enum), parent, mark ,errs) then + mark[childName] = nil return false end end + mark[childName] = nil return true end @@ -752,7 +758,7 @@ function vm.viewTypeErrorMessage(uri, errs) lines[#lines+1] = '- ' .. line end end - util.revertTable(lines) + util.revertArray(lines) if #lines > 15 then lines[13] = ('...(+%d)'):format(#lines - 15) table.move(lines, #lines - 2, #lines, 14) @@ -761,3 +767,25 @@ function vm.viewTypeErrorMessage(uri, errs) return table.concat(lines, '\n') end end + +---@param name string +---@param uri uri +---@return parser.object[]? +function vm.getOverloadsByTypeName(name, uri) + local global = vm.getGlobal('type', name) + if not global then + return nil + end + local results + for _, set in ipairs(global:getSets(uri)) do + for _, doc in ipairs(set.bindGroup) do + if doc.type == 'doc.overload' then + if not results then + results = {} + end + results[#results+1] = doc.overload + end + end + end + return results +end diff --git a/script/vm/visible.lua b/script/vm/visible.lua index e550280f..d13ecf1f 100644 --- a/script/vm/visible.lua +++ b/script/vm/visible.lua @@ -7,9 +7,10 @@ local glob = require 'glob' ---@class parser.object ---@field package _visibleType? parser.visibleType ----@param source parser.object ----@return parser.visibleType -function vm.getVisibleType(source) +local function getVisibleType(source) + if guide.isLiteral(source) then + return 'public' + end if source._visibleType then return source._visibleType end @@ -55,6 +56,27 @@ function vm.getVisibleType(source) return 'public' end +---@class vm.node +---@field package _visibleType parser.visibleType + +---@param source parser.object +---@return parser.visibleType +function vm.getVisibleType(source) + local node = vm.compileNode(source) + if node._visibleType then + return node._visibleType + end + for _, def in ipairs(vm.getDefs(source)) do + local visible = getVisibleType(def) + if visible ~= 'public' then + node._visibleType = visible + return visible + end + end + node._visibleType = 'public' + return 'public' +end + ---@param source parser.object ---@return vm.global? function vm.getParentClass(source) diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua index 3e85e0fc..97518e84 100644 --- a/script/workspace/workspace.lua +++ b/script/workspace/workspace.lua @@ -50,8 +50,10 @@ function m.create(uri) m.folders[#m.folders+1] = scp if uri == furi.encode '/' or uri == furi.encode(os.getenv 'HOME' or '') then - client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri))) - scp:set('bad root', true) + if not FORCE_ACCEPT_WORKSPACE then + client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri))) + scp:set('bad root', true) + end end end @@ -469,10 +471,6 @@ function m.flushFiles(scp) for uri in pairs(cachedUris) do files.delRef(uri) end - collectgarbage() - collectgarbage() - -- TODO: wait maillist - collectgarbage 'restart' end ---@param scp scope @@ -493,6 +491,8 @@ end ---@async ---@param scp scope function m.awaitReload(scp) + await.unique('workspace reload:' .. scp:getName()) + await.sleep(0.1) scp:set('ready', false) scp:set('nativeMatcher', nil) scp:set('libraryMatcher', nil) |