diff options
Diffstat (limited to 'script')
73 files changed, 3315 insertions, 1701 deletions
diff --git a/script/client.lua b/script/client.lua index daa9bc52..d86fb4f2 100644 --- a/script/client.lua +++ b/script/client.lua @@ -248,6 +248,7 @@ local function tryModifyRC(uri, finalChanges, create) end local workspace = require 'workspace' local path = workspace.getAbsolutePath(uri, '.luarc.json') + or workspace.getAbsolutePath(uri, '.luarc.jsonc') if not path then return false end @@ -318,7 +319,7 @@ local function tryModifyClientGlobal(finalChanges) end ---@param changes config.change[] ----@param onlyMemory boolean +---@param onlyMemory? boolean function m.setConfig(changes, onlyMemory) local finalChanges = {} for _, change in ipairs(changes) do diff --git a/script/config/loader.lua b/script/config/loader.lua index c53f9399..30711dde 100644 --- a/script/config/loader.lua +++ b/script/config/loader.lua @@ -1,10 +1,10 @@ -local json = require 'json' local proto = require 'proto' local lang = require 'language' local util = require 'utility' local workspace = require 'workspace' local scope = require 'workspace.scope' local inspect = require 'inspect' +local jsonc = require 'jsonc' local function errorMessage(msg) proto.notify('window/showMessage', { @@ -29,7 +29,7 @@ function m.loadRCConfig(uri, filename) scp:set('lastRCConfig', nil) return nil end - local suc, res = pcall(json.decode, buf) + local suc, res = pcall(jsonc.decode, buf) if not suc then errorMessage(lang.script('CONFIG_LOAD_ERROR', res)) return scp:get('lastRCConfig') @@ -55,7 +55,7 @@ function m.loadLocalConfig(uri, filename) end local firstChar = buf:match '%S' if firstChar == '{' then - local suc, res = pcall(json.decode, buf) + local suc, res = pcall(jsonc.decode, buf) if not suc then errorMessage(lang.script('CONFIG_LOAD_ERROR', res)) return scp:get('lastLocalConfig') diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index beff594c..d4c20c60 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -16,10 +16,8 @@ local rpath = require 'workspace.require-path' local lang = require 'language' local lookBackward = require 'core.look-backward' local guide = require 'parser.guide' -local infer = require 'vm.infer' local await = require 'await' local postfix = require 'core.completion.postfix' -local globalMgr = require 'vm.global-manager' local diagnosticModes = { 'disable-next-line', @@ -186,11 +184,8 @@ local function buildFunctionSnip(source, value, oop) end local function buildDetail(source) - if source.type == 'dummy' then - return - end - local types = infer.getInfer(source):view() - local literals = infer.getInfer(source):viewLiterals() + local types = vm.getInfer(source):view() + local literals = vm.getInfer(source):viewLiterals() if literals then return types .. ' = ' .. literals else @@ -228,9 +223,6 @@ end ---@async local function buildDesc(source) - if source.type == 'dummy' then - return - end local desc = markdown() local hover = getHover.get(source) desc:add('md', hover) @@ -310,8 +302,23 @@ local function checkLocal(state, word, position, results) if name:sub(1, 1) == '@' then goto CONTINUE end - if infer.getInfer(source):hasFunction() then - for _, def in ipairs(vm.getDefs(source)) do + if vm.getInfer(source):hasFunction() then + local defs = vm.getDefs(source) + -- make sure `function` is before `doc.type.function` + local orders = {} + for i, def in ipairs(defs) do + if def.type == 'function' then + orders[def] = i - 20000 + elseif def.type == 'doc.type.function' then + orders[def] = i - 10000 + else + orders[def] = i + end + end + table.sort(defs, function (a, b) + return orders[a] < orders[b] + end) + for _, def in ipairs(defs) do if def.type == 'function' or def.type == 'doc.type.function' then local funcLabel = name .. getParams(def, false) @@ -358,7 +365,7 @@ local function checkModule(state, word, position, results) local fileName = path:match '[^/\\]*$' local stemName = fileName:gsub('%..+', '') if not locals[stemName] - and not globalMgr.hasGlobalSets(state.uri, 'variable', stemName) + and not vm.hasGlobalSets(state.uri, 'variable', stemName) and not config.get(state.uri, 'Lua.diagnostics.globals')[stemName] and stemName:match '^[%a_][%w_]*$' and matchKey(word, stemName) then @@ -505,7 +512,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent }) return end - if oop and not infer.getInfer(src):hasFunction() then + if oop and not vm.getInfer(src):hasFunction() then return end local literal = guide.getLiteral(value) @@ -608,14 +615,14 @@ end ---@async local function checkGlobal(state, word, startPos, position, parent, oop, results) local locals = guide.getVisibleLocals(state.ast, position) - local globals = globalMgr.getGlobalSets(state.uri, 'variable') + local globals = vm.getGlobalSets(state.uri, 'variable') checkFieldOfRefs(globals, state, word, startPos, position, parent, oop, results, locals, 'global') end ---@async local function checkField(state, word, start, position, parent, oop, results) if parent.tag == '_ENV' or parent.special == '_G' then - local globals = globalMgr.getGlobalSets(state.uri, 'variable') + local globals = vm.getGlobalSets(state.uri, 'variable') checkFieldOfRefs(globals, state, word, start, position, parent, oop, results) else local refs = vm.getFields(parent) @@ -1124,7 +1131,7 @@ local function checkTypingEnum(state, position, defs, str, results) or def.type == 'doc.type.integer' or def.type == 'doc.type.boolean' then enums[#enums+1] = { - label = infer.viewObject(def), + label = vm.viewObject(def), description = def.comment and def.comment.text, kind = define.CompletionItemKind.EnumMember, } @@ -1413,7 +1420,7 @@ local function tryCallArg(state, position, results) or src.type == 'doc.type.integer' or src.type == 'doc.type.boolean' then enums[#enums+1] = { - label = infer.viewObject(src), + label = vm.viewObject(src), description = src.comment, kind = define.CompletionItemKind.EnumMember, } @@ -1432,7 +1439,7 @@ local function tryCallArg(state, position, results) : string() end enums[#enums+1] = { - label = infer.getInfer(src):view(), + label = vm.getInfer(src):view(), description = description, kind = define.CompletionItemKind.Function, insertText = insertText, @@ -1520,6 +1527,7 @@ local function tryluaDocCate(word, results) 'module', 'async', 'nodiscard', + 'cast', } do if matchKey(word, docType) then results[#results+1] = { @@ -1668,8 +1676,27 @@ local function tryluaDocBySource(state, position, source, results) } end end + return true elseif source.type == 'doc.module' then collectRequireNames('require', state.uri, source.module or '', source, source.smark, position, results) + return true + elseif source.type == 'doc.cast.name' then + local locals = guide.getVisibleLocals(state.ast, position) + for name, loc in util.sortPairs(locals) do + if matchKey(source[1], name) then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Variable, + id = stack(function () ---@async + return { + detail = buildDetail(loc), + description = buildDesc(loc), + } + end), + } + end + end + return true end return false end @@ -1764,6 +1791,22 @@ local function tryluaDocByErr(state, position, err, docState, results) end elseif err.type == 'LUADOC_MISS_MODULE_NAME' then collectRequireNames('require', state.uri, '', docState, nil, position, results) + elseif err.type == 'LUADOC_MISS_LOCAL_NAME' then + local locals = guide.getVisibleLocals(state.ast, position) + for name, loc in util.sortPairs(locals) do + if name ~= '_ENV' then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Variable, + id = stack(function () ---@async + return { + detail = buildDetail(loc), + description = buildDesc(loc), + } + end), + } + end + end end end @@ -1775,14 +1818,14 @@ local function buildluaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = infer.getInfer(arg):view() + args[#args+1] = vm.getInfer(arg):view() end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = infer.getInfer(rtns[n]):view() + returns[n] = vm.getInfer(rtns[n]):view() end end end @@ -1882,6 +1925,9 @@ local function tryComment(state, position, results) local doc = getluaDoc(state, position) if not word then local comment = getComment(state, position) + if not comment then + return + end if comment.type == 'comment.short' or comment.type == 'comment.cshort' then if comment.text == '' then diff --git a/script/core/definition.lua b/script/core/definition.lua index b89aa751..e4868532 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -53,6 +53,7 @@ local accept = { ['doc.alias.name'] = true, ['doc.see.name'] = true, ['doc.see.field'] = true, + ['doc.cast.name'] = true, } local function checkRequire(source, offset) @@ -133,6 +134,9 @@ return function (uri, offset) local defs = vm.getDefs(source) for _, src in ipairs(defs) do + if src.type == 'global' then + goto CONTINUE + end local root = guide.getRoot(src) if not root then goto CONTINUE diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index b9d3c485..c97014fa 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -1,6 +1,7 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' +local vm = require 'vm' return function (uri, callback) local state = files.getState(uri) @@ -23,18 +24,16 @@ return function (uri, callback) } return end - if source.value.type == 'nil' - or source.value.type == 'number' - or source.value.type == 'integer' - or source.value.type == 'boolean' - or source.value.type == 'table' - or source.value.type == 'function' then + local infer = vm.getInfer(source.value) + if not infer:hasClass() + and not infer:hasType 'nil' + and not infer:hasType 'table' + and infer:view('any', uri) ~= 'any' then callback { start = source.value.start, finish = source.value.finish, message = lang.script.DIAG_COSE_NON_OBJECT, } - return end end) end diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua index 8d355aac..d4116b9b 100644 --- a/script/core/diagnostics/duplicate-doc-field.lua +++ b/script/core/diagnostics/duplicate-doc-field.lua @@ -1,6 +1,5 @@ local files = require 'files' local lang = require 'language' -local infer = require 'vm.infer' local function getFieldEventName(doc) if not doc.extends then diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index d95963e4..334fd81a 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -16,7 +16,7 @@ return function (uri, callback) local env = guide.getENV(root) local nilDefs = {} - if not env.ref then + if not env or not env.ref then return end for _, ref in ipairs(env.ref) do diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua index 369a6ba2..b4ae3715 100644 --- a/script/core/diagnostics/init.lua +++ b/script/core/diagnostics/init.lua @@ -105,7 +105,7 @@ end ---@param uri uri ---@param isScopeDiag boolean ---@param response async fun(result: any) ----@param checked async fun(name: string) +---@param checked? async fun(name: string) return function (uri, isScopeDiag, response, checked) local ast = files.getState(uri) if not ast then diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua index d7032c13..d03e8c70 100644 --- a/script/core/diagnostics/lowercase-global.lua +++ b/script/core/diagnostics/lowercase-global.lua @@ -1,8 +1,8 @@ -local files = require 'files' -local guide = require 'parser.guide' -local lang = require 'language' -local config = require 'config' -local vm = require 'vm' +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local config = require 'config' +local vm = require 'vm' local function isDocClass(source) if not source.bindDocs then @@ -30,7 +30,7 @@ return function (uri, callback) guide.eachSourceType(ast.ast, 'setglobal', function (source) local name = guide.getKeyName(source) - if definedGlobal[name] then + if not name or definedGlobal[name] then return end local first = name:match '%w' @@ -44,8 +44,17 @@ return function (uri, callback) if isDocClass(source) then return end - if vm.isGlobalLibraryName(name) then - return + if definedGlobal[name] == nil then + definedGlobal[name] = false + local global = vm.getGlobal('variable', name) + if global then + for _, set in ipairs(global:getSets(uri)) do + if vm.isMetaFile(guide.getUri(set)) then + definedGlobal[name] = true + return + end + end + end end callback { start = source.start, diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua new file mode 100644 index 00000000..698680ca --- /dev/null +++ b/script/core/diagnostics/missing-parameter.lua @@ -0,0 +1,73 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' + +local function countCallArgs(source) + local result = 0 + if not source.args then + return 0 + end + result = result + #source.args + return result +end + +---@return integer +local function countFuncArgs(source) + if not source.args or #source.args == 0 then + return 0 + end + local count = 0 + for i = #source.args, 1, -1 do + local arg = source.args[i] + if arg.type ~= '...' + and not (arg.name and arg.name[1] =='...') + and not vm.compileNode(arg):isNullable() then + return i + end + end + return count +end + +local function getFuncArgs(func) + local funcArgs + local defs = vm.getDefs(func) + for _, def in ipairs(defs) do + if def.type == 'function' + or def.type == 'doc.type.function' then + local args = countFuncArgs(def) + if not funcArgs or args < funcArgs then + funcArgs = args + end + end + end + return funcArgs +end + +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + guide.eachSourceType(state.ast, 'call', function (source) + local callArgs = countCallArgs(source) + + local func = source.node + local funcArgs = getFuncArgs(func) + + if not funcArgs then + return + end + + local delta = callArgs - funcArgs + if delta >= 0 then + return + end + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_MISS_ARGS', funcArgs, callArgs), + } + end) +end diff --git a/script/core/diagnostics/need-check-nil.lua b/script/core/diagnostics/need-check-nil.lua new file mode 100644 index 00000000..98fdfd08 --- /dev/null +++ b/script/core/diagnostics/need-check-nil.lua @@ -0,0 +1,39 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' + +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + guide.eachSourceType(state.ast, 'getlocal', function (src) + local checkNil + local nxt = src.next + if nxt then + if nxt.type == 'getfield' + or nxt.type == 'getmethod' + or nxt.type == 'getindex' + or nxt.type == 'call' then + checkNil = true + end + end + local call = src.parent + if call and call.type == 'call' and call.node == src then + checkNil = true + end + if not checkNil then + return + end + local node = vm.compileNode(src) + if node:hasFalsy() then + callback { + start = src.start, + finish = src.finish, + message = lang.script('DIAG_NEED_CHECK_NIL'), + } + end + end) +end diff --git a/script/core/diagnostics/no-unknown.lua b/script/core/diagnostics/no-unknown.lua index 2199b6a8..48aab5da 100644 --- a/script/core/diagnostics/no-unknown.lua +++ b/script/core/diagnostics/no-unknown.lua @@ -1,7 +1,7 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' -local infer = require 'vm.infer' +local vm = require 'vm' return function (uri, callback) local ast = files.getState(uri) @@ -20,7 +20,7 @@ return function (uri, callback) and source.type ~= 'tableindex' then return end - if infer.getInfer(source):view() == 'unknown' then + if vm.getInfer(source):view() == 'unknown' then callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/not-yieldable.lua b/script/core/diagnostics/not-yieldable.lua index 0588bbde..a1c84276 100644 --- a/script/core/diagnostics/not-yieldable.lua +++ b/script/core/diagnostics/not-yieldable.lua @@ -3,7 +3,6 @@ local await = require 'await' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' -local infer = require 'vm.infer' local function isYieldAble(defs, i) local hasFuncDef @@ -12,7 +11,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if infer.getInfer(arg):hasType 'any' + if vm.getInfer(arg):hasType 'any' or vm.isAsync(arg, true) or arg.type == '...' then return true @@ -23,7 +22,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if infer.getInfer(arg.extends):hasType 'any' + if vm.getInfer(arg.extends):hasType 'any' or vm.isAsync(arg.extends, true) then return true end diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index 4adf169e..41781df8 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -2,7 +2,6 @@ local files = require 'files' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' -local define = require 'proto.define' local function countCallArgs(source) local result = 0 @@ -14,64 +13,40 @@ local function countCallArgs(source) end local function countFuncArgs(source) - local result = 0 if not source.args or #source.args == 0 then - return result - end - if source.args[#source.args].type == '...' then - return math.maxinteger - end - result = result + #source.args - return result -end - -local function countOverLoadArgs(source, doc) - local result = 0 - local func = doc.overload - if not func.args or #func.args == 0 then - return result + return 0 end - if func.args[#func.args].type == '...' then + local lastArg = source.args[#source.args] + if lastArg.type == '...' + or (lastArg.name and lastArg.name[1] == '...') then return math.maxinteger + else + return #source.args end - result = result + #func.args - return result end local function getFuncArgs(func) local funcArgs local defs = vm.getDefs(func) for _, def in ipairs(defs) do - if def.value then - def = def.value - end - if def.type == 'function' then + if def.type == 'function' + or def.type == 'doc.type.function' then local args = countFuncArgs(def) if not funcArgs or args > funcArgs then funcArgs = args end - if def.bindDocs then - for _, doc in ipairs(def.bindDocs) do - if doc.type == 'doc.overload' then - args = countOverLoadArgs(def, doc) - if not funcArgs or args > funcArgs then - funcArgs = args - end - end - end - end end end return funcArgs end return function (uri, callback) - local ast = files.getState(uri) - if not ast then + local state = files.getState(uri) + if not state then return end - guide.eachSourceType(ast.ast, 'call', function (source) + guide.eachSourceType(state.ast, 'call', function (source) local callArgs = countCallArgs(source) if callArgs == 0 then return @@ -97,7 +72,6 @@ return function (uri, callback) callback { start = arg.start, finish = arg.finish, - tags = { define.DiagnosticTag.Unnecessary }, message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs) } end diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index 025c217a..41fcda48 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -3,7 +3,6 @@ local vm = require 'vm' local lang = require 'language' local guide = require 'parser.guide' local await = require 'await' -local infer = require 'vm.infer' local skipCheckClass = { ['unknown'] = true, @@ -35,7 +34,7 @@ return function (uri, callback) local node = src.node if node then local ok - for view in infer.getInfer(node):eachView() do + for view in vm.getInfer(node):eachView() do if not skipCheckClass[view] then ok = true break diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index 139fa74f..bd0aae69 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -4,7 +4,6 @@ local lang = require 'language' local config = require 'config' local guide = require 'parser.guide' local await = require 'await' -local globalMgr = require 'vm.global-manager' local requireLike = { ['include'] = true, @@ -41,7 +40,7 @@ return function (uri, callback) return end if cache[key] == nil then - cache[key] = globalMgr.hasGlobalSets(uri, 'variable', key) + cache[key] = vm.hasGlobalSets(uri, 'variable', key) end if cache[key] then return diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua index 79cb16e2..813ac804 100644 --- a/script/core/diagnostics/unused-function.lua +++ b/script/core/diagnostics/unused-function.lua @@ -18,75 +18,107 @@ local function isToBeClosed(source) return false end ----@async -return function (uri, callback) - local ast = files.getState(uri) - if not ast then - return +---@param source parser.object +local function isValidFunction(source) + if not source then + return false + end + if source.type == 'main' then + return false + end + local parent = source.parent + if not parent then + return false + end + if parent.type ~= 'local' + and parent.type ~= 'setlocal' then + return false + end + if isToBeClosed(parent) then + return false end + return true +end - local cache = {} +---@async +local function collect(ast, white, roots, links) ---@async - local function checkFunction(source) - if not source then + guide.eachSourceType(ast, 'function', function (src) + await.delay() + if not isValidFunction(src) then return end - if cache[source] ~= nil then - return cache[source] - end - cache[source] = false - local parent = source.parent - if not parent then - return false - end - if parent.type ~= 'local' - and parent.type ~= 'setlocal' then - return false - end - if isToBeClosed(parent) then - return false + local loc = src.parent + if loc.type == 'setlocal' then + loc = loc.node end - await.delay() - if parent.type == 'setlocal' then - parent = parent.node - end - local refs = parent.ref - local hasGet - if refs then - for _, src in ipairs(refs) do - if guide.isGet(src) then - local func = guide.getParentFunction(src) - if not checkFunction(func) then - hasGet = true - break - end + for _, ref in ipairs(loc.ref or {}) do + if ref.type == 'getlocal' then + local func = guide.getParentFunction(ref) + if not isValidFunction(func) or roots[func] then + roots[src] = true + return end + if not links[func] then + links[func] = {} + end + links[func][#links[func]+1] = src end end - if not hasGet then - if client.isVSCode() then - callback { - start = source.start, - finish = source.finish, - tags = { define.DiagnosticTag.Unnecessary }, - message = lang.script.DIAG_UNUSED_FUNCTION, - } - else - callback { - start = source.keyword[1], - finish = source.keyword[2], - tags = { define.DiagnosticTag.Unnecessary }, - message = lang.script.DIAG_UNUSED_FUNCTION, - } - end - cache[source] = true - return true - end - return false + white[src] = true + end) + + return white, roots, links +end + +local function turnBlack(source, black, white, links) + if black[source] then + return end + black[source] = true + white[source] = nil + for _, link in ipairs(links[source] or {}) do + turnBlack(link, black, white, links) + end +end - -- 只检查局部函数 - guide.eachSourceType(ast.ast, 'function', function (source) ---@async - checkFunction(source) - end) +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + if vm.isMetaFile(uri) then + return + end + + local black = {} + local white = {} + local roots = {} + local links = {} + + collect(state.ast, white, roots, links) + + for source in pairs(roots) do + turnBlack(source, black, white, links) + end + + for source in pairs(white) do + if client.isVSCode() then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_FUNCTION, + } + else + callback { + start = source.keyword[1], + finish = source.keyword[2], + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_FUNCTION, + } + end + end end diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua index 2e07e1ee..ce033cf3 100644 --- a/script/core/diagnostics/unused-vararg.lua +++ b/script/core/diagnostics/unused-vararg.lua @@ -2,6 +2,7 @@ local files = require 'files' local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' +local vm = require 'vm' return function (uri, callback) local ast = files.getState(uri) @@ -9,6 +10,10 @@ return function (uri, callback) return end + if vm.isMetaFile(uri) then + return + end + guide.eachSourceType(ast.ast, 'function', function (source) local args = source.args if not args then diff --git a/script/core/formatting.lua b/script/core/formatting.lua index 49da6861..b52854a4 100644 --- a/script/core/formatting.lua +++ b/script/core/formatting.lua @@ -3,7 +3,7 @@ local files = require("files") local log = require("log") return function(uri, options) - local text = files.getText(uri) + local text = files.getOriginText(uri) local ast = files.getState(uri) local status, formattedText = codeFormat.format(uri, text, options) diff --git a/script/core/hint.lua b/script/core/hint.lua index f6774d2a..f97cdcec 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -1,5 +1,4 @@ local files = require 'files' -local infer = require 'vm.infer' local vm = require 'vm' local config = require 'config' local guide = require 'parser.guide' @@ -39,7 +38,7 @@ local function typeHint(uri, results, start, finish) end end await.delay() - local view = infer.getInfer(source):view() + local view = vm.getInfer(source):view() if view == 'any' or view == 'unknown' or view == 'nil' then diff --git a/script/core/hover/args.lua b/script/core/hover/args.lua index a53136b0..c485d9b9 100644 --- a/script/core/hover/args.lua +++ b/script/core/hover/args.lua @@ -1,17 +1,5 @@ local guide = require 'parser.guide' -local infer = require 'vm.infer' - -local function optionalArg(arg) - if not arg.bindDocs then - return false - end - local name = arg[1] - for _, doc in ipairs(arg.bindDocs) do - if doc.type == 'doc.param' and doc.param[1] == name then - return doc.optional - end - end -end +local vm = require 'vm' local function asFunction(source) local args = {} @@ -21,7 +9,7 @@ local function asFunction(source) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(infer.getInfer(parent.node):view 'any') + args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view 'any') end if source.args then for i = 1, #source.args do @@ -31,18 +19,25 @@ local function asFunction(source) end local name = arg.name or guide.getKeyName(arg) if name then + local argNode = vm.compileNode(arg) + local optional + if argNode:isOptional() then + optional = true + argNode = argNode:copy() + argNode:removeOptional() + end args[#args+1] = ('%s%s: %s'):format( name, - optionalArg(arg) and '?' or '', - infer.getInfer(arg):view 'any' + optional and '?' or '', + vm.getInfer(argNode):view('any', guide.getUri(source)) ) elseif arg.type == '...' then args[#args+1] = ('%s: %s'):format( '...', - infer.getInfer(arg):view 'any' + vm.getInfer(arg):view 'any' ) else - args[#args+1] = ('%s'):format(infer.getInfer(arg):view 'any') + args[#args+1] = ('%s'):format(vm.getInfer(arg):view 'any') end ::CONTINUE:: end @@ -61,7 +56,7 @@ local function asDocFunction(source) args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - arg.extends and infer.getInfer(arg.extends):view 'any' or 'any' + arg.extends and vm.getInfer(arg.extends):view 'any' or 'any' ) end return args diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index 03f6128a..e9267c0f 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -6,7 +6,6 @@ local lang = require 'language' local util = require 'utility' local guide = require 'parser.guide' local rpath = require 'workspace.require-path' -local infer = require 'vm.infer' local function collectRequire(mode, literal, uri) local result, searchers @@ -153,7 +152,7 @@ local function buildEnumChunk(docType, name) local types = {} local lines = {} for _, tp in ipairs(vm.getDefs(docType)) do - types[#types+1] = infer.getInfer(tp):view() + types[#types+1] = vm.getInfer(tp):view() if tp.type == 'doc.type.string' or tp.type == 'doc.type.integer' or tp.type == 'doc.type.boolean' then @@ -175,7 +174,7 @@ local function buildEnumChunk(docType, name) (enum.default and '->') or (enum.additional and '+>') or ' |', - infer.viewObject(enum) + vm.viewObject(enum) ) if enum.comment then local first = true diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index bc2f40eb..7231944a 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -5,7 +5,6 @@ local getDesc = require 'core.hover.description' local util = require 'utility' local findSource = require 'core.find-source' local markdown = require 'provider.markdown' -local infer = require 'vm.infer' local guide = require 'parser.guide' ---@async @@ -40,9 +39,24 @@ local function getHover(source) end local oop - if infer.getInfer(source):view() == 'function' then + if vm.getInfer(source):view() == 'function' then + local defs = vm.getDefs(source) + -- make sure `function` is before `doc.type.function` + local orders = {} + for i, def in ipairs(defs) do + if def.type == 'function' then + orders[def] = i - 20000 + elseif def.type == 'doc.type.function' then + orders[def] = i - 10000 + else + orders[def] = i + end + end + table.sort(defs, function (a, b) + return orders[a] < orders[b] + end) local hasFunc - for _, def in ipairs(vm.getDefs(source)) do + for _, def in ipairs(defs) do if guide.isOOP(def) then oop = true end @@ -58,6 +72,9 @@ local function getHover(source) else addHover(source, true, oop) for _, def in ipairs(vm.getDefs(source)) do + if def.type == 'global' then + goto CONTINUE + end if guide.isOOP(def) then oop = true end @@ -67,6 +84,7 @@ local function getHover(source) isFunction = true end addHover(def, isFunction, oop) + ::CONTINUE:: end end diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 8224e9d3..2bbfe806 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -2,7 +2,6 @@ local buildName = require 'core.hover.name' local buildArgs = require 'core.hover.args' local buildReturn = require 'core.hover.return' local buildTable = require 'core.hover.table' -local infer = require 'vm.infer' local vm = require 'vm' local util = require 'utility' local lang = require 'language' @@ -34,7 +33,7 @@ local function asDocTypeName(source) return '(class) ' .. doc.class[1] end if doc.type == 'doc.alias' then - return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', infer.getInfer(doc.extends):view()) + return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view()) end end end @@ -42,7 +41,7 @@ end ---@async local function asValue(source, title) local name = buildName(source, false) or '' - local ifr = infer.getInfer(source) + local ifr = vm.getInfer(source) local type = ifr:view() local literal = ifr:viewLiterals() local cont = buildTable(source) @@ -140,7 +139,7 @@ local function asDocFieldName(source) break end end - local view = infer.getInfer(source.extends):view() + local view = vm.getInfer(source.extends):view() if not class then return ('(field) ?.%s: %s'):format(name, view) end @@ -180,7 +179,7 @@ local function asNumber(source) if not text then return nil end - local raw = text:sub(source.start, source.finish) + local raw = text:sub(source.start + 1, source.finish) if not raw or not raw:find '[^%-%d%.]' then return nil end diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua index 905c5ec7..f8473638 100644 --- a/script/core/hover/name.lua +++ b/script/core/hover/name.lua @@ -1,5 +1,5 @@ -local infer = require 'vm.infer' local guide = require 'parser.guide' +local vm = require 'vm' local buildName @@ -19,7 +19,7 @@ end local function asField(source, oop) local class if source.node.type ~= 'getglobal' then - class = infer.getInfer(source.node):viewClass() + class = vm.getInfer(source.node):viewClass() end local node = class or buildName(source.node, false) diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index 77710148..3d8a94a5 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,5 +1,3 @@ -local infer = require 'vm.infer' -local guide = require 'parser.guide' local vm = require 'vm.vm' ---@param source parser.object @@ -65,10 +63,9 @@ local function asFunction(source) local rtn = vm.getReturnOfFunction(source, i) local doc = docs[i] local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ') - local text = ('%s%s%s'):format( + local text = ('%s%s'):format( name or '', - infer.getInfer(rtn):view(), - doc and doc.optional and '?' or '' + vm.getInfer(rtn):view() ) if i == 1 then returns[i] = (' -> %s'):format(text) @@ -86,10 +83,7 @@ local function asDocFunction(source) end local returns = {} for i, rtn in ipairs(source.returns) do - local rtnText = ('%s%s'):format( - infer.getInfer(rtn):view(), - rtn.optional and '?' or '' - ) + local rtnText = vm.getInfer(rtn):view() if i == 1 then returns[#returns+1] = (' -> %s'):format(rtnText) else diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index 31036edd..16874101 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -1,7 +1,6 @@ local vm = require 'vm' local util = require 'utility' local config = require 'config' -local infer = require 'vm.infer' local await = require 'await' local guide = require 'parser.guide' @@ -16,22 +15,34 @@ local function formatKey(key) return ('[%s]'):format(key) end -local function buildAsHash(keys, typeMap, literalMap, optMap, reachMax) +---@param uri uri +---@param keys string[] +---@param nodeMap table<string, vm.node> +---@param reachMax integer +local function buildAsHash(uri, keys, nodeMap, reachMax) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local typeView = typeMap[key] - local literalView = literalMap[key] + local node = nodeMap[key] + local isOptional = node:isOptional() + if isOptional then + node = node:copy() + node:removeOptional() + end + local ifr = vm.getInfer(node) + local typeView = ifr:view('unknown', uri) + local literalView = ifr:viewLiterals() if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( formatKey(key), - optMap[key] and '?' or '', + isOptional and '?' or '', typeView, - literalView) + literalView + ) else lines[#lines+1] = (' %s%s: %s,'):format( formatKey(key), - optMap[key] and '?' or '', + isOptional and '?' or '', typeView ) end @@ -43,26 +54,40 @@ local function buildAsHash(keys, typeMap, literalMap, optMap, reachMax) return table.concat(lines, '\n') end -local function buildAsConst(keys, typeMap, literalMap, optMap, reachMax) +---@param uri uri +---@param keys string[] +---@param nodeMap table<string, vm.node> +---@param reachMax integer +local function buildAsConst(uri, keys, nodeMap, reachMax) + local literalMap = {} + for _, key in ipairs(keys) do + literalMap[key] = vm.getInfer(nodeMap[key]):viewLiterals() + end table.sort(keys, function (a, b) return tonumber(literalMap[a]) < tonumber(literalMap[b]) end) local lines = {} lines[#lines+1] = '{' for _, key in ipairs(keys) do - local typeView = typeMap[key] + local node = nodeMap[key] + local isOptional = node:isOptional() + if isOptional then + node = node:copy() + node:removeOptional() + end + local typeView = vm.getInfer(node):view('unknown', uri) local literalView = literalMap[key] if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( formatKey(key), - optMap[key] and '?' or '', + isOptional and '?' or '', typeView, literalView ) else lines[#lines+1] = (' %s%s: %s,'):format( formatKey(key), - optMap[key] and '?' or '', + isOptional and '?' or '', typeView ) end @@ -102,6 +127,19 @@ local function getKeyMap(fields) if ta == 'boolean' then return a == true end + if ta == 'string' then + if a:sub(1, 1) == '_' then + if b:sub(1, 1) == '_' then + return a < b + else + return false + end + elseif b:sub(1, 1) == '_' then + return true + else + return a < b + end + end return a < b else return tsa < tsb @@ -110,48 +148,25 @@ local function getKeyMap(fields) return keys, map end -local function getOptMap(fields, keyMap) - local optMap = {} - for _, field in ipairs(fields) do - if field.type == 'doc.field' then - if field.optional then - local key = vm.getKeyName(field) - if keyMap[key] then - optMap[key] = true - end - end - end - if field.type == 'doc.type.field' then - if field.optional then - local key = vm.getKeyName(field) - if keyMap[key] then - optMap[key] = true - end - end - end - end - return optMap -end - ---@async -local function getInferMap(fields, keyMap) - ---@type table<string, vm.infer> - local inferMap = {} +local function getNodeMap(fields, keyMap) + ---@type table<string, vm.node> + local nodeMap = {} for _, field in ipairs(fields) do local key = vm.getKeyName(field) if not keyMap[key] then goto CONTINUE end await.delay() - local ifr = infer.getInfer(field) - if inferMap[key] then - inferMap[key] = inferMap[key]:merge(ifr) + local node = vm.compileNode(field) + if nodeMap[key] then + nodeMap[key]:merge(node) else - inferMap[key] = ifr + nodeMap[key] = node:copy() end ::CONTINUE:: end - return inferMap + return nodeMap end ---@async @@ -163,7 +178,7 @@ return function (source) return nil end - for view in infer.getInfer(source):eachView() do + for view in vm.getInfer(source):eachView() do if view == 'string' or vm.isSubType(uri, view, 'string') then return nil @@ -184,19 +199,14 @@ return function (source) end end - local optMap = getOptMap(fields, map) - local inferMap = getInferMap(fields, map) + local nodeMap = getNodeMap(fields, map) - local typeMap = {} - local literalMap = {} local isConsts = true for i = 1, #keys do await.delay() local key = keys[i] - - typeMap[key] = inferMap[key]:view('unknown', uri) - literalMap[key] = inferMap[key]:viewLiterals() - if not tonumber(literalMap[key]) then + local literal = vm.getInfer(nodeMap[key]):viewLiterals() + if not tonumber(literal) then isConsts = false end end @@ -204,9 +214,9 @@ return function (source) local result if isConsts then - result = buildAsConst(keys, typeMap, literalMap, optMap, reachMax) + result = buildAsConst(uri, keys, nodeMap, reachMax) else - result = buildAsHash(keys, typeMap, literalMap, optMap, reachMax) + result = buildAsHash(uri, keys, nodeMap, reachMax) end --if timeUp then diff --git a/script/core/look-backward.lua b/script/core/look-backward.lua index eea089bc..eeee6017 100644 --- a/script/core/look-backward.lua +++ b/script/core/look-backward.lua @@ -2,7 +2,8 @@ local m = {} --- 是否是空白符 ----@param inline boolean # 必须在同一行中(排除换行符) +---@param char string +---@param inline? boolean # 必须在同一行中(排除换行符) function m.isSpace(char, inline) if inline then if char == ' ' @@ -21,7 +22,9 @@ function m.isSpace(char, inline) end --- 跳过空白符 ----@param inline boolean # 必须在同一行中(排除换行符) +---@param text string +---@param offset integer +---@param inline? boolean # 必须在同一行中(排除换行符) function m.skipSpace(text, offset, inline) for i = offset, 1, -1 do local char = text:sub(i, i) diff --git a/script/core/matchkey.lua b/script/core/matchkey.lua index 3c6a54a8..4db9d764 100644 --- a/script/core/matchkey.lua +++ b/script/core/matchkey.lua @@ -59,7 +59,7 @@ end ---@param input string ---@param other string ----@param fast boolean +---@param fast? boolean ---@return boolean isMatch ---@return number deviation return function (input, other, fast) diff --git a/script/core/rangeformatting.lua b/script/core/rangeformatting.lua index ccf2d21f..f64e9cda 100644 --- a/script/core/rangeformatting.lua +++ b/script/core/rangeformatting.lua @@ -4,7 +4,7 @@ local log = require("log") local converter = require("proto.converter") return function(uri, range, options) - local text = files.getText(uri) + local text = files.getOriginText(uri) local status, formattedText, startLine, endLine = codeFormat.range_format( uri, text, range.start.line, range["end"].line, options) diff --git a/script/core/rename.lua b/script/core/rename.lua index ec21e87c..7599fad6 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -3,7 +3,6 @@ local vm = require 'vm' local util = require 'utility' local findSource = require 'core.find-source' local guide = require 'parser.guide' -local globalMgr = require 'vm.global-manager' local Forcing @@ -191,7 +190,7 @@ end ---@async local function ofGlobal(source, newname, callback) local key = guide.getKeyName(source) - local global = globalMgr.getGlobal('variable', key) + local global = vm.getGlobal('variable', key) if not global then return end @@ -214,7 +213,7 @@ end ---@async local function ofDocTypeName(source, newname, callback) local oldname = source[1] - local global = globalMgr.getGlobal('type', oldname) + local global = vm.getGlobal('type', oldname) if not global then return end diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 568bb222..33449013 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -5,7 +5,6 @@ local vm = require 'vm' local util = require 'utility' local guide = require 'parser.guide' local converter = require 'proto.converter' -local infer = require 'vm.infer' local config = require 'config' local linkedTable = require 'linked-table' @@ -16,8 +15,24 @@ local Care = util.switch() if not options.variable then return end - local isLib = vm.isGlobalLibraryName(source[1]) - local isFunc = infer.getInfer(source):hasFunction() + + local name = source[1] + local isLib = options.libGlobals[name] + if isLib == nil then + isLib = false + local global = vm.getGlobal('variable', name) + if global then + local uri = guide.getUri(source) + for _, set in ipairs(global:getSets(uri)) do + if vm.isMetaFile(guide.getUri(set)) then + isLib = true + break + end + end + end + options.libGlobals[name] = isLib + end + local isFunc = vm.getInfer(source):hasFunction() local type = isFunc and define.TokenTypes['function'] or define.TokenTypes.variable local modifier = isLib and define.TokenModifiers.defaultLibrary or define.TokenModifiers.static @@ -66,7 +81,7 @@ local Care = util.switch() return end end - if infer.getInfer(source):hasFunction() then + if vm.getInfer(source):hasFunction() then results[#results+1] = { start = source.start, finish = source.finish, @@ -165,27 +180,23 @@ local Care = util.switch() -- 5. Class declaration -- only search this local if loc.bindDocs then - for i = #loc.bindDocs, 1, -1 do - local doc = loc.bindDocs[i] - if doc.type == 'doc.type' then - break - end - if doc.type == "doc.class" and doc.bindSources then - for _, src in ipairs(doc.bindSources) do - if src == loc then - results[#results+1] = { - start = source.start, - finish = source.finish, - type = define.TokenTypes.class, - } - return - end + local isParam = source.parent.type == 'funcargs' + or source.parent.type == 'in' + if not isParam then + for _, doc in ipairs(loc.bindDocs) do + if doc.type == 'doc.class' then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.class, + } + return end end end end -- 6. References to other functions - if infer.getInfer(loc):hasFunction() then + if vm.getInfer(loc):hasFunction() then results[#results+1] = { start = source.start, finish = source.finish, @@ -656,6 +667,14 @@ local Care = util.switch() type = define.TokenTypes.keyword, } end) + : case 'doc.cast.name' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.variable, + } + end) local function buildTokens(uri, results) local tokens = {} @@ -773,24 +792,25 @@ end ---@async return function (uri, start, finish) + local results = {} if not config.get(uri, 'Lua.semantic.enable') then - return nil + return results end local state = files.getState(uri) if not state then - return nil + return results end local options = { uri = uri, state = state, text = files.getText(uri), + libGlobals = {}, variable = config.get(uri, 'Lua.semantic.variable'), annotation = config.get(uri, 'Lua.semantic.annotation'), keyword = config.get(uri, 'Lua.semantic.keyword'), } - local results = {} guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async Care(source.type, source, options, results) await.delay() @@ -798,27 +818,26 @@ return function (uri, start, finish) for _, comm in ipairs(state.comms) do if start <= comm.start and comm.finish <= finish then - if comm.type == 'comment.short' then - local head = comm.text:match '^%-%s*[@|]' - if head then - results[#results+1] = { - start = comm.start, - finish = comm.start + #head + 1, - type = define.TokenTypes.comment, - } - results[#results+1] = { - start = comm.start + #head + 1, - finish = comm.start + #head + 2 + #comm.text:match('%S*', #head + 1), - type = define.TokenTypes.keyword, - modifieres = define.TokenModifiers.documentation, - } + local headPos = (comm.type == 'comment.short' and comm.text:match '^%-%s*[@|]()') + or (comm.type == 'comment.long' and comm.text:match '^@()') + if headPos then + local atPos + if comm.type == 'comment.short' then + atPos = headPos + 2 else - results[#results+1] = { - start = comm.start, - finish = comm.finish, - type = define.TokenTypes.comment, - } + atPos = headPos + #comm.mark end + results[#results+1] = { + start = comm.start, + finish = comm.start + atPos - 2, + type = define.TokenTypes.comment, + } + results[#results+1] = { + start = comm.start + atPos - 2, + finish = comm.start + atPos - 1 + #comm.text:match('%S*', headPos), + type = define.TokenTypes.keyword, + modifieres = define.TokenModifiers.documentation, + } else results[#results+1] = { start = comm.start, @@ -830,7 +849,7 @@ return function (uri, start, finish) end if #results == 0 then - return {} + return results end results = solveMultilineAndOverlapping(state, results) diff --git a/script/core/signature.lua b/script/core/signature.lua index ab7268dd..025e70b7 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -41,6 +41,9 @@ end ---@async local function makeOneSignature(source, oop, index) local label = hoverLabel(source, oop) + if not label then + return nil + end -- 去掉返回值 label = label:gsub('%s*->.+', '') local params = {} diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua index 92f81997..d8434c8c 100644 --- a/script/core/type-definition.lua +++ b/script/core/type-definition.lua @@ -3,7 +3,6 @@ local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' local guide = require 'parser.guide' -local infer = require 'vm.infer' local rpath = require 'workspace.require-path' local function sortResults(results) diff --git a/script/doctor.lua b/script/doctor.lua index 91a7e4b8..87cdcfcb 100644 --- a/script/doctor.lua +++ b/script/doctor.lua @@ -175,6 +175,9 @@ m.snapshot = private(function () exclude[o] = true end end + ---@generic T + ---@param o T + ---@return T local function private(o) if not o then return nil diff --git a/script/encoder/init.lua b/script/encoder/init.lua index 0011265a..3c8a58e0 100644 --- a/script/encoder/init.lua +++ b/script/encoder/init.lua @@ -10,9 +10,9 @@ local utf16be = utf16('be', utf8.codepoint '�') local m = {} ---@param encoding encoder.encoding ----@param s string ----@param i integer ----@param j integer +---@param s string +---@param i? integer +---@param j? integer function m.len(encoding, s, i, j) i = i or 1 j = j or #s @@ -33,9 +33,9 @@ function m.len(encoding, s, i, j) end ---@param encoding encoder.encoding ----@param s string ----@param n integer ----@param i integer +---@param s string +---@param n integer +---@param i? integer function m.offset(encoding, s, n, i) i = i or 1 if encoding == 'utf16' diff --git a/script/files.lua b/script/files.lua index d16474fd..22c9ae31 100644 --- a/script/files.lua +++ b/script/files.lua @@ -165,8 +165,8 @@ end --- 设置文件文本 ---@param uri uri ---@param text string ----@param isTrust boolean ----@param callback function +---@param isTrust? boolean +---@param callback? function function m.setText(uri, text, isTrust, callback) if not text then return diff --git a/script/fs-utility.lua b/script/fs-utility.lua index c845c769..08aae98a 100644 --- a/script/fs-utility.lua +++ b/script/fs-utility.lua @@ -281,12 +281,8 @@ local function fsIsDirectory(path, option) if path.type == 'dummy' then return path:isDirectory() end - local suc, res = pcall(fs.is_directory, path) - if not suc then - option.err[#option.err+1] = res - return false - end - return res + local status = fs.symlink_status(path):type() + return status == 'directory' end local function fsPairs(path, option) @@ -616,9 +612,10 @@ end function m.scanDirectory(dir, callback) for fullpath in fs.pairs(dir) do - if fs.is_directory(fullpath) then + local status = fs.symlink_status(fullpath):type() + if status == 'directory' then m.scanDirectory(fullpath, callback) - else + elseif status == 'regular' then callback(fullpath) end end diff --git a/script/glob/gitignore.lua b/script/glob/gitignore.lua index 09be1415..4dad2747 100644 --- a/script/glob/gitignore.lua +++ b/script/glob/gitignore.lua @@ -163,7 +163,7 @@ function mt:getRelativePath(path) return path end ----@param callback async fun() +---@param callback async fun(path: string) ---@async function mt:scan(path, callback) local files = {} diff --git a/script/jsonc.lua b/script/jsonc.lua new file mode 100644 index 00000000..0361d99b --- /dev/null +++ b/script/jsonc.lua @@ -0,0 +1,603 @@ +local type = type +local next = next +local error = error +local tonumber = tonumber +local tostring = tostring +local table_concat = table.concat +local table_sort = table.sort +local string_char = string.char +local string_byte = string.byte +local string_find = string.find +local string_match = string.match +local string_gsub = string.gsub +local string_sub = string.sub +local string_rep = string.rep +local string_format = string.format +local setmetatable = setmetatable +local getmetatable = getmetatable +local huge = math.huge +local tiny = -huge + +local utf8_char +local math_type + +if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then + local math_floor = math.floor + function utf8_char(c) + if c <= 0x7f then + return string_char(c) + elseif c <= 0x7ff then + return string_char(math_floor(c / 64) + 192, c % 64 + 128) + elseif c <= 0xffff then + return string_char( + math_floor(c / 4096) + 224, + math_floor(c % 4096 / 64) + 128, + c % 64 + 128 + ) + elseif c <= 0x10ffff then + return string_char( + math_floor(c / 262144) + 240, + math_floor(c % 262144 / 4096) + 128, + math_floor(c % 4096 / 64) + 128, + c % 64 + 128 + ) + end + error(string.format("invalid UTF-8 code '%x'", c)) + end + function math_type(v) + if v >= -2147483648 and v <= 2147483647 and math_floor(v) == v then + return "integer" + end + return "float" + end +else + utf8_char = utf8.char + math_type = math.type +end + +local json = {} + +json.supportSparseArray = true + +local objectMt = {} + +function json.createEmptyObject() + return setmetatable({}, objectMt) +end + +function json.isObject(t) + if t[1] ~= nil then + return false + end + return next(t) ~= nil or getmetatable(t) == objectMt +end + +if debug and debug.upvalueid then + -- Generate a lightuserdata + json.null = debug.upvalueid(json.createEmptyObject, 1) +else + json.null = function() end +end + +-- json.encode -- + +local statusVisited +local statusBuilder +local statusDep +local statusOpt + +local defaultOpt = { + newline = "", + indent = "", +} +defaultOpt.__index = defaultOpt + +local encode_map = {} + +local encode_escape_map = { + [ "\"" ] = "\\\"", + [ "\\" ] = "\\\\", + [ "/" ] = "\\/", + [ "\b" ] = "\\b", + [ "\f" ] = "\\f", + [ "\n" ] = "\\n", + [ "\r" ] = "\\r", + [ "\t" ] = "\\t", +} + +local decode_escape_set = {} +local decode_escape_map = {} +for k, v in next, encode_escape_map do + decode_escape_map[v] = k + decode_escape_set[string_byte(v, 2)] = true +end + +for i = 0, 31 do + local c = string_char(i) + if not encode_escape_map[c] then + encode_escape_map[c] = string_format("\\u%04x", i) + end +end + +encode_map["nil"] = function () + return "null" +end + +local function encode_string(v) + return string_gsub(v, '[%z\1-\31\\"]', encode_escape_map) +end + +local function convertreal(v) + local g = string_format('%.16g', v) + if tonumber(g) == v then + return g + end + return string_format('%.17g', v) +end + +if string_match(tostring(1/2), "%p") == "," then + local _convertreal = convertreal + function convertreal(v) + return string_gsub(_convertreal(v), ',', '.') + end +end + +function encode_map.number(v) + if v ~= v or v <= tiny or v >= huge then + error("unexpected number value '" .. tostring(v) .. "'") + end + if math_type(v) == "integer" then + return string_format('%d', v) + end + return convertreal(v) +end + +function encode_map.boolean(v) + if v then + return "true" + else + return "false" + end +end + +local function encode_unexpected(v) + if v == json.null then + return "null" + else + error("unexpected type '"..type(v).."'") + end +end +encode_map[ "function" ] = encode_unexpected +encode_map[ "userdata" ] = encode_unexpected +encode_map[ "thread" ] = encode_unexpected + +local function encode_newline() + statusBuilder[#statusBuilder+1] = statusOpt.newline..string_rep(statusOpt.indent, statusDep) +end + +local function encode(v) + local res = encode_map[type(v)](v) + statusBuilder[#statusBuilder+1] = res +end + +function encode_map.string(v) + statusBuilder[#statusBuilder+1] = '"' + statusBuilder[#statusBuilder+1] = encode_string(v) + return '"' +end + +function encode_map.table(t) + local first_val = next(t) + if first_val == nil then + if getmetatable(t) == objectMt then + return "{}" + else + return "[]" + end + end + if statusVisited[t] then + error("circular reference") + end + statusVisited[t] = true + if type(first_val) == 'string' then + local key = {} + for k in next, t do + if type(k) ~= "string" then + error("invalid table: mixed or invalid key types") + end + key[#key+1] = k + end + table_sort(key) + statusBuilder[#statusBuilder+1] = "{" + statusDep = statusDep + 1 + encode_newline() + local k = key[1] + statusBuilder[#statusBuilder+1] = '"' + statusBuilder[#statusBuilder+1] = encode_string(k) + statusBuilder[#statusBuilder+1] = '": ' + encode(t[k]) + for i = 2, #key do + local k = key[i] + statusBuilder[#statusBuilder+1] = "," + encode_newline() + statusBuilder[#statusBuilder+1] = '"' + statusBuilder[#statusBuilder+1] = encode_string(k) + statusBuilder[#statusBuilder+1] = '": ' + encode(t[k]) + end + statusDep = statusDep - 1 + encode_newline() + statusVisited[t] = nil + return "}" + elseif json.supportSparseArray then + local max = 0 + for k in next, t do + if math_type(k) ~= "integer" or k <= 0 then + error("invalid table: mixed or invalid key types") + end + if max < k then + max = k + end + end + statusBuilder[#statusBuilder+1] = "[" + statusDep = statusDep + 1 + encode_newline() + encode(t[1]) + for i = 2, max do + statusBuilder[#statusBuilder+1] = "," + encode_newline() + encode(t[i]) + end + statusDep = statusDep - 1 + encode_newline() + statusVisited[t] = nil + return "]" + else + if t[1] == nil then + error("invalid table: mixed or invalid key types") + end + statusBuilder[#statusBuilder+1] = "[" + statusDep = statusDep + 1 + encode_newline() + encode(t[1]) + local count = 2 + while t[count] ~= nil do + statusBuilder[#statusBuilder+1] = "," + encode_newline() + encode(t[count]) + count = count + 1 + end + if next(t, count-1) ~= nil then + error("invalid table: mixed or invalid key types") + end + statusDep = statusDep - 1 + encode_newline() + statusVisited[t] = nil + return "]" + end +end + +function json.encode(v, option) + statusVisited = {} + statusBuilder = {} + statusDep = 0 + statusOpt = option and setmetatable(option, defaultOpt) or defaultOpt + encode(v) + return table_concat(statusBuilder) +end + +-- json.decode -- + +local statusBuf +local statusPos +local statusTop +local statusAry = {} +local statusRef = {} + +local function find_line() + local line = 1 + local pos = 1 + while true do + local f, _, nl1, nl2 = string_find(statusBuf, '([\n\r])([\n\r]?)', pos) + if not f then + return line, statusPos - pos + 1 + end + local newpos = f + ((nl1 == nl2 or nl2 == '') and 1 or 2) + if newpos > statusPos then + return line, statusPos - pos + 1 + end + pos = newpos + line = line + 1 + end +end + +local function decode_error(msg) + error(string_format("ERROR: %s at line %d col %d", msg, find_line()), 2) +end + +local function get_word() + return string_match(statusBuf, "^[^ \t\r\n%]},]*", statusPos) +end + +local function skip_comment(b) + if b ~= 47 --[[ '/' ]] then + return + end + local c = string_byte(statusBuf, statusPos+1) + if c == 42 --[[ '*' ]] then + -- block comment + local pos = string_find(statusBuf, "*/", statusPos) + if pos then + statusPos = pos + 2 + else + statusPos = #statusBuf + 1 + end + return true + elseif c == 47 --[[ '/' ]] then + -- line comment + local pos = string_find(statusBuf, "[\r\n]", statusPos) + if pos then + statusPos = pos + else + statusPos = #statusBuf + 1 + end + return true + end +end + +local function next_byte() + local pos = string_find(statusBuf, "[^ \t\r\n]", statusPos) + if pos then + statusPos = pos + local b = string_byte(statusBuf, pos) + if not skip_comment(b) then + return b + end + return next_byte() + end + return -1 +end + +local function decode_unicode_surrogate(s1, s2) + return utf8_char(0x10000 + (tonumber(s1, 16) - 0xd800) * 0x400 + (tonumber(s2, 16) - 0xdc00)) +end + +local function decode_unicode_escape(s) + return utf8_char(tonumber(s, 16)) +end + +local function decode_string() + local has_unicode_escape = false + local has_escape = false + local i = statusPos + 1 + while true do + i = string_find(statusBuf, '[%z\1-\31\\"]', i) + if not i then + decode_error "expected closing quote for string" + end + local x = string_byte(statusBuf, i) + if x < 32 then + statusPos = i + decode_error "control character in string" + end + if x == 34 --[[ '"' ]] then + local s = string_sub(statusBuf, statusPos + 1, i - 1) + if has_unicode_escape then + s = string_gsub(string_gsub(s + , "\\u([dD][89aAbB]%x%x)\\u([dD][c-fC-F]%x%x)", decode_unicode_surrogate) + , "\\u(%x%x%x%x)", decode_unicode_escape) + end + if has_escape then + s = string_gsub(s, "\\.", decode_escape_map) + end + statusPos = i + 1 + return s + end + --assert(x == 92 --[[ "\\" ]]) + local nx = string_byte(statusBuf, i+1) + if nx == 117 --[[ "u" ]] then + if not string_match(statusBuf, "^%x%x%x%x", i+2) then + statusPos = i + decode_error "invalid unicode escape in string" + end + has_unicode_escape = true + i = i + 6 + else + if not decode_escape_set[nx] then + statusPos = i + decode_error("invalid escape char '" .. (nx and string_char(nx) or "<eol>") .. "' in string") + end + has_escape = true + i = i + 2 + end + end +end + +local function decode_number() + local num, c = string_match(statusBuf, '^([0-9]+%.?[0-9]*)([eE]?)', statusPos) + if not num or string_byte(num, -1) == 0x2E --[[ "." ]] then + decode_error("invalid number '" .. get_word() .. "'") + end + if c ~= '' then + num = string_match(statusBuf, '^([^eE]*[eE][-+]?[0-9]+)[ \t\r\n%]},/]', statusPos) + if not num then + decode_error("invalid number '" .. get_word() .. "'") + end + end + statusPos = statusPos + #num + return tonumber(num) +end + +local function decode_number_zero() + local num, c = string_match(statusBuf, '^(.%.?[0-9]*)([eE]?)', statusPos) + if not num or string_byte(num, -1) == 0x2E --[[ "." ]] or string_match(statusBuf, '^.[0-9]+', statusPos) then + decode_error("invalid number '" .. get_word() .. "'") + end + if c ~= '' then + num = string_match(statusBuf, '^([^eE]*[eE][-+]?[0-9]+)[ \t\r\n%]},/]', statusPos) + if not num then + decode_error("invalid number '" .. get_word() .. "'") + end + end + statusPos = statusPos + #num + return tonumber(num) +end + +local function decode_number_negative() + statusPos = statusPos + 1 + local c = string_byte(statusBuf, statusPos) + if c then + if c == 0x30 then + return -decode_number_zero() + elseif c > 0x30 and c < 0x3A then + return -decode_number() + end + end + decode_error("invalid number '" .. get_word() .. "'") +end + +local function decode_true() + if string_sub(statusBuf, statusPos, statusPos+3) ~= "true" then + decode_error("invalid literal '" .. get_word() .. "'") + end + statusPos = statusPos + 4 + return true +end + +local function decode_false() + if string_sub(statusBuf, statusPos, statusPos+4) ~= "false" then + decode_error("invalid literal '" .. get_word() .. "'") + end + statusPos = statusPos + 5 + return false +end + +local function decode_null() + if string_sub(statusBuf, statusPos, statusPos+3) ~= "null" then + decode_error("invalid literal '" .. get_word() .. "'") + end + statusPos = statusPos + 4 + return json.null +end + +local function decode_array() + statusPos = statusPos + 1 + local res = {} + local chr = next_byte() + if chr == 93 --[[ ']' ]] then + statusPos = statusPos + 1 + return res + end + statusTop = statusTop + 1 + statusAry[statusTop] = true + statusRef[statusTop] = res + return res +end + +local function decode_object() + statusPos = statusPos + 1 + local res = {} + local chr = next_byte() + if chr == 125 --[[ ']' ]] then + statusPos = statusPos + 1 + return json.createEmptyObject() + end + statusTop = statusTop + 1 + statusAry[statusTop] = false + statusRef[statusTop] = res + return res +end + +local decode_uncompleted_map = { + [ string_byte '"' ] = decode_string, + [ string_byte "0" ] = decode_number_zero, + [ string_byte "1" ] = decode_number, + [ string_byte "2" ] = decode_number, + [ string_byte "3" ] = decode_number, + [ string_byte "4" ] = decode_number, + [ string_byte "5" ] = decode_number, + [ string_byte "6" ] = decode_number, + [ string_byte "7" ] = decode_number, + [ string_byte "8" ] = decode_number, + [ string_byte "9" ] = decode_number, + [ string_byte "-" ] = decode_number_negative, + [ string_byte "t" ] = decode_true, + [ string_byte "f" ] = decode_false, + [ string_byte "n" ] = decode_null, + [ string_byte "[" ] = decode_array, + [ string_byte "{" ] = decode_object, +} +local function unexpected_character() + decode_error("unexpected character '" .. string_sub(statusBuf, statusPos, statusPos) .. "'") +end +local function unexpected_eol() + decode_error("unexpected character '<eol>'") +end + +local decode_map = {} +for i = 0, 255 do + decode_map[i] = decode_uncompleted_map[i] or unexpected_character +end +decode_map[-1] = unexpected_eol + +local function decode() + return decode_map[next_byte()]() +end + +local function decode_item() + local top = statusTop + local ref = statusRef[top] + if statusAry[top] then + ref[#ref+1] = decode() + else + local key = decode_string() + if next_byte() ~= 58 --[[ ':' ]] then + decode_error "expected ':'" + end + statusPos = statusPos + 1 + ref[key] = decode() + end + if top == statusTop then + repeat + local chr = next_byte(); statusPos = statusPos + 1 + if chr == 44 --[[ "," ]] then + local c = next_byte() + if statusAry[statusTop] then + if c ~= 93 --[[ "]" ]] then return end + else + if c ~= 125 --[[ "}" ]] then return end + end + statusPos = statusPos + 1 + else + if statusAry[statusTop] then + if chr ~= 93 --[[ "]" ]] then decode_error "expected ']' or ','" end + else + if chr ~= 125 --[[ "}" ]] then decode_error "expected '}' or ','" end + end + end + statusTop = statusTop - 1 + until statusTop == 0 + end +end + +function json.decode(str) + if type(str) ~= "string" then + error("expected argument of type string, got " .. type(str)) + end + statusBuf = str + statusPos = 1 + statusTop = 0 + if next_byte() == -1 then + return json.null + end + local res = decode() + while statusTop > 0 do + decode_item() + end + if string_find(statusBuf, "[^ \t\r\n]", statusPos) then + decode_error "trailing garbage" + end + return res +end + +return json diff --git a/script/language.lua b/script/language.lua index 771dc948..22546fb8 100644 --- a/script/language.lua +++ b/script/language.lua @@ -6,7 +6,9 @@ local function supportLanguage() local list = {} for path in fs.pairs(ROOT / 'locale') do if fs.is_directory(path) then - list[#list+1] = path:filename():string():lower() + local id = path:filename():string():lower() + list[#list+1] = id + list[id] = true end end return list diff --git a/script/log.lua b/script/log.lua index 597bdc4e..6cb865c3 100644 --- a/script/log.lua +++ b/script/log.lua @@ -85,7 +85,10 @@ function m.warn(...) end function m.error(...) - return pushLog('error', ...) + -- Don't use tail calls, + -- Otherwise, the count of `debug.getinfo` will be wrong + local msg = pushLog('error', ...) + return msg end function m.raw(thd, level, msg, source, currentline, clock) diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 0ece65fc..06169b09 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -16,6 +16,7 @@ local type = type ---@field uri uri ---@field start integer ---@field finish integer +---@field range integer ---@field effect integer ---@field attrs string[] ---@field specials parser.object[] @@ -56,6 +57,14 @@ local type = type ---@field init parser.object ---@field step parser.object ---@field redundant { max: integer, passed: integer } +---@field filter parser.object +---@field loc parser.object +---@field keyword integer[] +---@field casts parser.object[] +---@field mode? '+' | '-' +---@field hasGoTo? true +---@field hasReturn? true +---@field hasBreak? true ---@field _root parser.object ---@class guide @@ -71,6 +80,7 @@ local blockTypes = { ['repeat'] = true, ['do'] = true, ['function'] = true, + ['if'] = true, ['ifblock'] = true, ['elseblock'] = true, ['elseifblock'] = true, @@ -141,6 +151,9 @@ local childMap = { ['doc.see'] = {'name', 'field'}, ['doc.version'] = {'#versions'}, ['doc.diagnostic'] = {'#names'}, + ['doc.as'] = {'as'}, + ['doc.cast'] = {'loc', '#casts'}, + ['doc.cast.block'] = {'extends'}, } ---@type table<string, fun(obj: parser.object, list: parser.object[])> @@ -393,6 +406,7 @@ function m.getRoot(obj) end local parent = obj.parent if not parent then + log.error('Can not find out root:', obj.type) return nil end obj = parent @@ -413,6 +427,7 @@ function m.getUri(obj) return '' end +---@return parser.object? function m.getENV(source, start) if not start then start = 1 @@ -446,19 +461,17 @@ function m.getFunctionVarArgs(func) end --- 获取指定区块中可见的局部变量 ----@param block table ----@param name string {comment = '变量名'} ----@param pos integer {comment = '可见位置'} -function m.getLocal(block, name, pos) - block = m.getBlock(block) - for _ = 1, 10000 do - if not block then - return nil - end - local locals = block.locals - local res +---@param source parser.object +---@param name string # 变量名 +---@param pos integer # 可见位置 +---@return parser.object? +function m.getLocal(source, name, pos) + local root = m.getRoot(source) + local res + m.eachSourceContain(root, pos, function (src) + local locals = src.locals if not locals then - goto CONTINUE + return end for i = 1, #locals do local loc = locals[i] @@ -471,13 +484,8 @@ function m.getLocal(block, name, pos) end end end - if res then - return res, res - end - ::CONTINUE:: - block = m.getParentBlock(block) - end - error('guide.getLocal overstack') + end) + return res end --- 获取指定区块中所有的可见局部变量名称 @@ -602,6 +610,9 @@ local function addChilds(list, obj) end --- 遍历所有包含position的source +---@param ast parser.object +---@param position integer +---@param callback fun(src: parser.object) function m.eachSourceContain(ast, position, callback) local list = { ast } local mark = {} @@ -922,6 +933,7 @@ function m.getKeyNameOfLiteral(obj) end end +---@return string? function m.getKeyName(obj) if not obj then return nil @@ -1027,8 +1039,6 @@ function m.getKeyType(obj) return type(obj.field[1]) elseif tp == 'doc.type.field' then return type(obj.name[1]) - elseif tp == 'dummy' then - return 'string' end if tp == 'doc.field.name' then return type(obj[1]) diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 5a2e1d09..d8e31950 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -2,10 +2,11 @@ local m = require 'lpeglabel' local re = require 'parser.relabel' local guide = require 'parser.guide' local parser = require 'parser.newparser' +local util = require 'utility' local TokenTypes, TokenStarts, TokenFinishs, TokenContents, TokenMarks local Ci, Offset, pushWarning, NextComment, Lines -local parseType +local parseType, parseTypeUnit ---@type any local Parser = re.compile([[ Main <- (Token / Sp)* @@ -52,6 +53,7 @@ Symbol <- ({} { / '...' / '[' / ']' + / '-' !'-' } {}) -> Symbol ]], { @@ -124,6 +126,8 @@ Symbol <- ({} { ---@class parser.object ---@field literal boolean ---@field signs parser.object[] +---@field originalComment parser.object +---@field as? parser.object local function trim(str) return str:match '^%s*(%S+)%s*$' @@ -336,104 +340,6 @@ local function parseSigns(parent) return signs end -local function parseClass(parent) - local result = { - type = 'doc.class', - parent = parent, - fields = {}, - } - result.class = parseName('doc.class.name', result) - if not result.class then - pushWarning { - type = 'LUADOC_MISS_CLASS_NAME', - start = getFinish(), - finish = getFinish(), - } - return nil - end - result.start = getStart() - result.finish = getFinish() - result.signs = parseSigns(result) - if not checkToken('symbol', ':', 1) then - return result - end - nextToken() - - result.extends = {} - - while true do - local extend = parseName('doc.extends.name', result) - or parseTable(result) - if not extend then - pushWarning { - type = 'LUADOC_MISS_CLASS_EXTENDS_NAME', - start = getFinish(), - finish = getFinish(), - } - return result - end - result.extends[#result.extends+1] = extend - result.finish = getFinish() - if not checkToken('symbol', ',', 1) then - break - end - nextToken() - end - return result -end - -local function parseTypeUnitArray(parent, node) - if not checkToken('symbol', '[]', 1) then - return nil - end - nextToken() - local result = { - type = 'doc.type.array', - start = node.start, - finish = getFinish(), - node = node, - parent = parent, - } - node.parent = result - return result -end - -local function parseTypeUnitSign(parent, node) - if not checkToken('symbol', '<', 1) then - return nil - end - nextToken() - local result = { - type = 'doc.type.sign', - start = node.start, - finish = getFinish(), - node = node, - parent = parent, - signs = {}, - } - node.parent = result - while true do - local sign = parseType(result) - if not sign then - pushWarning { - type = 'LUA_DOC_MISS_SIGN', - start = getFinish(), - finish = getFinish(), - } - break - end - result.signs[#result.signs+1] = sign - if checkToken('symbol', ',', 1) then - nextToken() - else - break - end - end - nextSymbolOrError '>' - result.finish = getFinish() - return result -end - local function parseDots(tp, parent) if not checkToken('symbol', '...', 1) then return @@ -527,8 +433,6 @@ local function parseTypeUnitFunction(parent) return typeUnit end -local parseTypeUnit - local function parseFunction(parent) local _, content = peekToken() if content == 'async' then @@ -551,6 +455,58 @@ local function parseFunction(parent) end end +local function parseTypeUnitArray(parent, node) + if not checkToken('symbol', '[]', 1) then + return nil + end + nextToken() + local result = { + type = 'doc.type.array', + start = node.start, + finish = getFinish(), + node = node, + parent = parent, + } + node.parent = result + return result +end + +local function parseTypeUnitSign(parent, node) + if not checkToken('symbol', '<', 1) then + return nil + end + nextToken() + local result = { + type = 'doc.type.sign', + start = node.start, + finish = getFinish(), + node = node, + parent = parent, + signs = {}, + } + node.parent = result + while true do + local sign = parseType(result) + if not sign then + pushWarning { + type = 'LUA_DOC_MISS_SIGN', + start = getFinish(), + finish = getFinish(), + } + break + end + result.signs[#result.signs+1] = sign + if checkToken('symbol', ',', 1) then + nextToken() + else + break + end + end + nextSymbolOrError '>' + result.finish = getFinish() + return result +end + local function parseString(parent) local tp, content = peekToken() if not tp or tp ~= 'string' then @@ -709,6 +665,10 @@ function parseType(parent) if not result.start then result.start = getFinish() end + if checkToken('symbol', '?', 1) then + nextToken() + result.optional = true + end result.finish = getFinish() result.firstFinish = result.finish @@ -785,405 +745,534 @@ function parseType(parent) return result end -local function parseAlias() - local result = { - type = 'doc.alias', - } - result.alias = parseName('doc.alias.name', result) - if not result.alias then - pushWarning { - type = 'LUADOC_MISS_ALIAS_NAME', - start = getFinish(), - finish = getFinish(), - } - return nil - end - result.start = getStart() - result.signs = parseSigns(result) - result.extends = parseType(result) - if not result.extends then - pushWarning { - type = 'LUADOC_MISS_ALIAS_EXTENDS', - start = getFinish(), - finish = getFinish(), - } - return nil - end - result.finish = getFinish() - return result -end - -local function parseParam() - local result = { - type = 'doc.param', - } - result.param = parseName('doc.param.name', result) - or parseDots('doc.param.name', result) - if not result.param then - pushWarning { - type = 'LUADOC_MISS_PARAM_NAME', - start = getFinish(), - finish = getFinish(), +local docSwitch = util.switch() + : case 'class' + : call(function () + local result = { + type = 'doc.class', + fields = {}, } - return nil - end - if checkToken('symbol', '?', 1) then - nextToken() - result.optional = true - end - result.start = result.param.start - result.finish = getFinish() - result.extends = parseType(result) - if not result.extends then - pushWarning { - type = 'LUADOC_MISS_PARAM_EXTENDS', - start = getFinish(), - finish = getFinish(), - } - return result - end - result.finish = getFinish() - result.firstFinish = result.extends.firstFinish - return result -end - -local function parseReturn() - local result = { - type = 'doc.return', - returns = {}, - } - while true do - local docType = parseType(result) - if not docType then - break - end - if not result.start then - result.start = docType.start - end - if checkToken('symbol', '?', 1) then - nextToken() - docType.optional = true + result.class = parseName('doc.class.name', result) + if not result.class then + pushWarning { + type = 'LUADOC_MISS_CLASS_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil end - docType.name = parseName('doc.return.name', docType) - result.returns[#result.returns+1] = docType - if not checkToken('symbol', ',', 1) then - break + result.start = getStart() + result.finish = getFinish() + result.signs = parseSigns(result) + if not checkToken('symbol', ':', 1) then + return result end nextToken() - end - if #result.returns == 0 then - return nil - end - result.finish = getFinish() - return result -end -local function parseField() - local result = { - type = 'doc.field', - } - try(function () - local tp, value = nextToken() - if tp == 'name' then - if value == 'public' - or value == 'protected' - or value == 'private' then - result.visible = value - result.start = getStart() - return true + result.extends = {} + + while true do + local extend = parseName('doc.extends.name', result) + or parseTable(result) + if not extend then + pushWarning { + type = 'LUADOC_MISS_CLASS_EXTENDS_NAME', + start = getFinish(), + finish = getFinish(), + } + return result end + result.extends[#result.extends+1] = extend + result.finish = getFinish() + if not checkToken('symbol', ',', 1) then + break + end + nextToken() end - return false + return result end) - result.field = parseName('doc.field.name', result) - or parseIndexField('doc.field.name', result) - if not result.field then - pushWarning { - type = 'LUADOC_MISS_FIELD_NAME', - start = getFinish(), - finish = getFinish(), - } - return nil - end - if not result.start then - result.start = result.field.start - end - if checkToken('symbol', '?', 1) then - nextToken() - result.optional = true - end - result.extends = parseType(result) - if not result.extends then - pushWarning { - type = 'LUADOC_MISS_FIELD_EXTENDS', - start = getFinish(), - finish = getFinish(), - } - return nil - end - result.finish = getFinish() - return result -end - -local function parseGeneric() - local result = { - type = 'doc.generic', - generics = {}, - } - while true do - local object = { - type = 'doc.generic.object', - parent = result, + : case 'type' + : call(function () + return parseType() + end) + : case 'alias' + : call(function () + local result = { + type = 'doc.alias', } - object.generic = parseName('doc.generic.name', object) - if not object.generic then + result.alias = parseName('doc.alias.name', result) + if not result.alias then pushWarning { - type = 'LUADOC_MISS_GENERIC_NAME', + type = 'LUADOC_MISS_ALIAS_NAME', start = getFinish(), finish = getFinish(), } return nil end - object.start = object.generic.start - if not result.start then - result.start = object.start + result.start = getStart() + result.signs = parseSigns(result) + result.extends = parseType(result) + if not result.extends then + pushWarning { + type = 'LUADOC_MISS_ALIAS_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return nil end - if checkToken('symbol', ':', 1) then + result.finish = getFinish() + return result + end) + : case 'param' + : call(function () + local result = { + type = 'doc.param', + } + result.param = parseName('doc.param.name', result) + or parseDots('doc.param.name', result) + if not result.param then + pushWarning { + type = 'LUADOC_MISS_PARAM_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + if checkToken('symbol', '?', 1) then nextToken() - object.extends = parseType(object) + result.optional = true end - object.finish = getFinish() - result.generics[#result.generics+1] = object - if not checkToken('symbol', ',', 1) then - break + result.start = result.param.start + result.finish = getFinish() + result.extends = parseType(result) + if not result.extends then + pushWarning { + type = 'LUADOC_MISS_PARAM_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return result end - nextToken() - end - result.finish = getFinish() - return result -end - -local function parseVararg() - local result = { - type = 'doc.vararg', - } - result.vararg = parseType(result) - if not result.vararg then - pushWarning { - type = 'LUADOC_MISS_VARARG_TYPE', - start = getFinish(), - finish = getFinish(), + result.finish = getFinish() + result.firstFinish = result.extends.firstFinish + return result + end) + : case 'return' + : call(function () + local result = { + type = 'doc.return', + returns = {}, } - return - end - result.start = result.vararg.start - result.finish = result.vararg.finish - return result -end - -local function parseOverload() - local tp, name = peekToken() - if tp ~= 'name' - or (name ~= 'fun' and name ~= 'async') then - pushWarning { - type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD', - start = getFinish(), - finish = getFinish(), + while true do + local docType = parseType(result) + if not docType then + break + end + if not result.start then + result.start = docType.start + end + if checkToken('symbol', '?', 1) then + nextToken() + docType.optional = true + end + docType.name = parseName('doc.return.name', docType) + result.returns[#result.returns+1] = docType + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + if #result.returns == 0 then + return nil + end + result.finish = getFinish() + return result + end) + : case 'field' + : call(function () + local result = { + type = 'doc.field', } - return nil - end - local result = { - type = 'doc.overload', - } - result.overload = parseFunction(result) - if not result.overload then - return nil - end - result.overload.parent = result - result.start = result.overload.start - result.finish = result.overload.finish - return result -end - -local function parseDeprecated() - return { - type = 'doc.deprecated', - start = getFinish(), - finish = getFinish(), - } -end - -local function parseMeta() - return { - type = 'doc.meta', - start = getFinish(), - finish = getFinish(), - } -end - -local function parseVersion() - local result = { - type = 'doc.version', - versions = {}, - } - while true do - local tp, text = nextToken() - if not tp then + try(function () + local tp, value = nextToken() + if tp == 'name' then + if value == 'public' + or value == 'protected' + or value == 'private' then + result.visible = value + result.start = getStart() + return true + end + end + return false + end) + result.field = parseName('doc.field.name', result) + or parseIndexField('doc.field.name', result) + if not result.field then pushWarning { - type = 'LUADOC_MISS_VERSION', + type = 'LUADOC_MISS_FIELD_NAME', start = getFinish(), finish = getFinish(), } - break + return nil end if not result.start then - result.start = getStart() + result.start = result.field.start end - local version = { - type = 'doc.version.unit', - parent = result, - start = getStart(), + if checkToken('symbol', '?', 1) then + nextToken() + result.optional = true + end + result.extends = parseType(result) + if not result.extends then + pushWarning { + type = 'LUADOC_MISS_FIELD_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.finish = getFinish() + return result + end) + : case 'generic' + : call(function () + local result = { + type = 'doc.generic', + generics = {}, } - if tp == 'symbol' then - if text == '>' then - version.ge = true - elseif text == '<' then - version.le = true + while true do + local object = { + type = 'doc.generic.object', + parent = result, + } + object.generic = parseName('doc.generic.name', object) + if not object.generic then + pushWarning { + type = 'LUADOC_MISS_GENERIC_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + object.start = object.generic.start + if not result.start then + result.start = object.start end - tp, text = nextToken() + if checkToken('symbol', ':', 1) then + nextToken() + object.extends = parseType(object) + end + object.finish = getFinish() + result.generics[#result.generics+1] = object + if not checkToken('symbol', ',', 1) then + break + end + nextToken() end - if tp ~= 'name' then + result.finish = getFinish() + return result + end) + : case 'vararg' + : call(function () + local result = { + type = 'doc.vararg', + } + result.vararg = parseType(result) + if not result.vararg then pushWarning { - type = 'LUADOC_MISS_VERSION', - start = getStart(), + type = 'LUADOC_MISS_VARARG_TYPE', + start = getFinish(), finish = getFinish(), } - break + return end - version.version = tonumber(text) or text - version.finish = getFinish() - result.versions[#result.versions+1] = version - if not checkToken('symbol', ',', 1) then - break + result.start = result.vararg.start + result.finish = result.vararg.finish + return result + end) + : case 'overload' + : call(function () + local tp, name = peekToken() + if tp ~= 'name' + or (name ~= 'fun' and name ~= 'async') then + pushWarning { + type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD', + start = getFinish(), + finish = getFinish(), + } + return nil end - nextToken() - end - if #result.versions == 0 then - return nil - end - result.finish = getFinish() - return result -end - -local function parseSee() - local result = { - type = 'doc.see', - } - result.name = parseName('doc.see.name', result) - if not result.name then - return nil - end - result.start = result.name.start - result.finish = result.name.finish - if checkToken('symbol', '#', 1) then - nextToken() - result.field = parseName('doc.see.field', result) - result.finish = getFinish() - end - return result -end - -local function parseDiagnostic() - local result = { - type = 'doc.diagnostic', - } - local nextTP, mode = nextToken() - if nextTP ~= 'name' then - pushWarning { - type = 'LUADOC_MISS_DIAG_MODE', + local result = { + type = 'doc.overload', + } + result.overload = parseFunction(result) + if not result.overload then + return nil + end + result.overload.parent = result + result.start = result.overload.start + result.finish = result.overload.finish + return result + end) + : case 'deprecated' + : call(function () + return { + type = 'doc.deprecated', start = getFinish(), finish = getFinish(), } - return nil - end - result.mode = mode - result.start = getStart() - result.finish = getFinish() - if mode ~= 'disable-next-line' - and mode ~= 'disable-line' - and mode ~= 'disable' - and mode ~= 'enable' then - pushWarning { - type = 'LUADOC_ERROR_DIAG_MODE', - start = result.start, - finish = result.finish, + end) + : case 'meta' + : call(function () + return { + type = 'doc.meta', + start = getFinish(), + finish = getFinish(), + } + end) + : case 'version' + : call(function () + local result = { + type = 'doc.version', + versions = {}, } - end - - if checkToken('symbol', ':', 1) then - nextToken() - result.names = {} while true do - local name = parseName('doc.diagnostic.name', result) - if not name then + local tp, text = nextToken() + if not tp then pushWarning { - type = 'LUADOC_MISS_DIAG_NAME', + type = 'LUADOC_MISS_VERSION', start = getFinish(), finish = getFinish(), } - return result + break + end + if not result.start then + result.start = getStart() end - result.names[#result.names+1] = name + local version = { + type = 'doc.version.unit', + parent = result, + start = getStart(), + } + if tp == 'symbol' then + if text == '>' then + version.ge = true + elseif text == '<' then + version.le = true + end + tp, text = nextToken() + end + if tp ~= 'name' then + pushWarning { + type = 'LUADOC_MISS_VERSION', + start = getStart(), + finish = getFinish(), + } + break + end + version.version = tonumber(text) or text + version.finish = getFinish() + result.versions[#result.versions+1] = version if not checkToken('symbol', ',', 1) then break end nextToken() end - end + if #result.versions == 0 then + return nil + end + result.finish = getFinish() + return result + end) + : case 'see' + : call(function () + local result = { + type = 'doc.see', + } + result.name = parseName('doc.see.name', result) + if not result.name then + return nil + end + result.start = result.name.start + result.finish = result.name.finish + if checkToken('symbol', '#', 1) then + nextToken() + result.field = parseName('doc.see.field', result) + result.finish = getFinish() + end + return result + end) + : case 'diagnostic' + : call(function () + local result = { + type = 'doc.diagnostic', + } + local nextTP, mode = nextToken() + if nextTP ~= 'name' then + pushWarning { + type = 'LUADOC_MISS_DIAG_MODE', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.mode = mode + result.start = getStart() + result.finish = getFinish() + if mode ~= 'disable-next-line' + and mode ~= 'disable-line' + and mode ~= 'disable' + and mode ~= 'enable' then + pushWarning { + type = 'LUADOC_ERROR_DIAG_MODE', + start = result.start, + finish = result.finish, + } + end - result.finish = getFinish() + if checkToken('symbol', ':', 1) then + nextToken() + result.names = {} + while true do + local name = parseName('doc.diagnostic.name', result) + if not name then + pushWarning { + type = 'LUADOC_MISS_DIAG_NAME', + start = getFinish(), + finish = getFinish(), + } + return result + end + result.names[#result.names+1] = name + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + end - return result -end + result.finish = getFinish() -local function parseModule() - local result = { - type = 'doc.module', - start = getFinish(), - finish = getFinish(), - } - local tp, content = peekToken() - if tp == 'string' then - result.module = content - nextToken() - result.start = getStart() + return result + end) + : case 'module' + : call(function () + local result = { + type = 'doc.module', + start = getFinish(), + finish = getFinish(), + } + local tp, content = peekToken() + if tp == 'string' then + result.module = content + nextToken() + result.start = getStart() + result.finish = getFinish() + result.smark = getMark() + else + pushWarning { + type = 'LUADOC_MISS_MODULE_NAME', + start = getFinish(), + finish = getFinish(), + } + end + return result + end) + : case 'async' + : call(function () + return { + type = 'doc.async', + start = getFinish(), + finish = getFinish(), + } + end) + : case 'nodiscard' + : call(function () + return { + type = 'doc.nodiscard', + start = getFinish(), + finish = getFinish(), + } + end) + : case 'as' + : call(function () + local result = { + type = 'doc.as', + start = getFinish(), + finish = getFinish(), + } + result.as = parseType(result) result.finish = getFinish() - result.smark = getMark() - else - pushWarning { - type = 'LUADOC_MISS_MODULE_NAME', + return result + end) + : case 'cast' + : call(function () + local result = { + type = 'doc.cast', start = getFinish(), finish = getFinish(), + casts = {}, } - end - return result -end -local function parseAsync() - return { - type = 'doc.async', - start = getFinish(), - finish = getFinish(), - } -end + local loc = parseName('doc.cast.name', result) + if not loc then + pushWarning { + type = 'LUADOC_MISS_LOCAL_NAME', + start = getFinish(), + finish = getFinish(), + } + return result + end -local function parseNoDiscard() - return { - type = 'doc.nodiscard', - start = getFinish(), - finish = getFinish(), - } -end + result.loc = loc + result.finish = loc.finish + + while true do + local block = { + type = 'doc.cast.block', + parent = result, + start = getFinish(), + finish = getFinish(), + } + if checkToken('symbol', '+', 1) then + block.mode = '+' + nextToken() + block.start = getStart() + block.finish = getFinish() + elseif checkToken('symbol', '-', 1) then + block.mode = '-' + nextToken() + block.start = getStart() + block.finish = getFinish() + end + + if checkToken('symbol', '?', 1) then + block.optional = true + nextToken() + block.start = block.start or getStart() + block.finish = block.finish + else + block.extends = parseType(block) + if block.extends then + block.start = block.start or block.extends.start + block.finish = block.extends.finish + end + end + + if block.optional or block.extends then + result.casts[#result.casts+1] = block + end + + if checkToken('symbol', ',', 1) then + nextToken() + else + break + end + end + + return result + end) local function convertTokens() local tp, text = nextToken() @@ -1198,41 +1287,7 @@ local function convertTokens() } return nil end - if text == 'class' then - return parseClass() - elseif text == 'type' then - return parseType() - elseif text == 'alias' then - return parseAlias() - elseif text == 'param' then - return parseParam() - elseif text == 'return' then - return parseReturn() - elseif text == 'field' then - return parseField() - elseif text == 'generic' then - return parseGeneric() - elseif text == 'vararg' then - return parseVararg() - elseif text == 'overload' then - return parseOverload() - elseif text == 'deprecated' then - return parseDeprecated() - elseif text == 'meta' then - return parseMeta() - elseif text == 'version' then - return parseVersion() - elseif text == 'see' then - return parseSee() - elseif text == 'diagnostic' then - return parseDiagnostic() - elseif text == 'module' then - return parseModule() - elseif text == 'async' then - return parseAsync() - elseif text == 'nodiscard' then - return parseNoDiscard() - end + return docSwitch(text) end local function trimTailComment(text) @@ -1257,7 +1312,8 @@ end local function buildLuaDoc(comment) local text = comment.text - local _, startPos = text:find('^%-%s*@') + local startPos = (comment.type == 'comment.short' and text:match '^%-%s*@()') + or (comment.type == 'comment.long' and text:match '^@()') if not startPos then return { type = 'doc.comment', @@ -1268,9 +1324,9 @@ local function buildLuaDoc(comment) } end - local doc = text:sub(startPos + 1) + local doc = text:sub(startPos) - parseTokens(doc, comment.start + startPos + 1) + parseTokens(doc, comment.start + startPos) local result = convertTokens() if result then result.range = comment.finish @@ -1313,16 +1369,21 @@ local function isNextLine(binded, doc) return false end local lastDoc = binded[#binded] - if lastDoc.type == 'doc.type' then + if lastDoc.type == 'doc.type' + or lastDoc.type == 'doc.module' then return false end if lastDoc.type == 'doc.class' or lastDoc.type == 'doc.field' then if doc.type ~= 'doc.field' - and doc.type ~= 'doc.comment' then + and doc.type ~= 'doc.comment' + and doc.type ~= 'doc.overload' then return false end end + if doc.type == 'doc.cast' then + return false + end local lastRow = guide.rowColOf(lastDoc.finish) local newRow = guide.rowColOf(doc.start) return newRow - lastRow == 1 @@ -1400,11 +1461,13 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish) if src.start >= start then if src.type == 'local' or src.type == 'self' + or src.type == 'setlocal' or src.type == 'setglobal' or src.type == 'tablefield' or src.type == 'tableindex' or src.type == 'setfield' or src.type == 'setindex' + or src.type == 'setmethod' or src.type == 'function' then src.bindDocs = binded bindSources[#bindSources+1] = src diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua index e226417f..630c12c2 100644 --- a/script/parser/newparser.lua +++ b/script/parser/newparser.lua @@ -117,6 +117,7 @@ local Specials = { ['xpcall'] = true, ['pairs'] = true, ['ipairs'] = true, + ['assert'] = true, } local UnarySymbol = { @@ -537,6 +538,7 @@ local function skipComment(isAction) if longComment then longComment.type = 'comment.long' longComment.text = longComment[1] + longComment.mark = longComment[2] longComment[1] = nil longComment[2] = nil State.comms[#State.comms+1] = longComment @@ -689,9 +691,6 @@ local function parseLocalAttrs() end local function createLocal(obj, attrs) - if not obj then - return nil - end obj.type = 'local' obj.effect = obj.finish @@ -2891,7 +2890,11 @@ local function parseLocal() pushActionIntoCurrentChunk(loc) skipSpace() parseMultiVars(loc, parseName, true) - loc.effect = lastRightPosition() + if loc.value then + loc.effect = loc.value.finish + else + loc.effect = loc.finish + end return loc end @@ -2946,13 +2949,22 @@ local function parseReturn() end pushActionIntoCurrentChunk(rtn) for i = #Chunk, 1, -1 do - local func = Chunk[i] - if func.type == 'function' - or func.type == 'main' then - if not func.returns then - func.returns = {} + local block = Chunk[i] + if block.type == 'function' + or block.type == 'main' then + if not block.returns then + block.returns = {} end - func.returns[#func.returns+1] = rtn + block.returns[#block.returns+1] = rtn + break + end + end + for i = #Chunk, 1, -1 do + local block = Chunk[i] + if block.type == 'ifblock' + or block.type == 'elseifblock' + or block.type == 'else' then + block.hasReturn = true break end end @@ -3052,6 +3064,15 @@ local function parseGoTo() break end end + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + if chunk.type == 'ifblock' + or chunk.type == 'elseifblock' + or chunk.type == 'elseblock' then + chunk.hasGoTo = true + break + end + end pushActionIntoCurrentChunk(action) return action @@ -3586,6 +3607,15 @@ local function parseBreak() break end end + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + if chunk.type == 'ifblock' + or chunk.type == 'elseifblock' + or chunk.type == 'elseblock' then + chunk.hasBreak = true + break + end + end if not ok and Mode == 'Lua' then pushError { type = 'BREAK_OUTSIDE', diff --git a/script/proto/define.lua b/script/proto/define.lua index 389cdf88..fb60c56c 100644 --- a/script/proto/define.lua +++ b/script/proto/define.lua @@ -9,10 +9,10 @@ m.DiagnosticSeverity = { } ---@alias DiagnosticDefaultSeverity ----| '"Hint"' ----| '"Information"' ----| '"Warning"' ----| '"Error"' +---| 'Hint' +---| 'Information' +---| 'Warning' +---| 'Error' --- 诊断类型与默认等级 ---@type table<string, DiagnosticDefaultSeverity> @@ -29,6 +29,7 @@ m.DiagnosticDefaultSeverity = { ['newline-call'] = 'Information', ['newfield-call'] = 'Warning', ['redundant-parameter'] = 'Warning', + ['missing-parameter'] = 'Warning', ['redundant-return'] = 'Warning', ['ambiguity-1'] = 'Warning', ['lowercase-global'] = 'Information', @@ -47,6 +48,7 @@ m.DiagnosticDefaultSeverity = { ['await-in-sync'] = 'Warning', ['not-yieldable'] = 'Warning', ['discard-returns'] = 'Warning', + ['need-check-nil'] = 'Warning', ['type-check'] = 'Warning', ['duplicate-doc-alias'] = 'Warning', @@ -63,9 +65,9 @@ m.DiagnosticDefaultSeverity = { } ---@alias DiagnosticDefaultNeededFileStatus ----| '"Any"' ----| '"Opened"' ----| '"None"' +---| 'Any' +---| 'Opened' +---| 'None' -- 文件状态 m.FileStatus = { @@ -88,6 +90,7 @@ m.DiagnosticDefaultNeededFileStatus = { ['newline-call'] = 'Any', ['newfield-call'] = 'Any', ['redundant-parameter'] = 'Opened', + ['missing-parameter'] = 'Opened', ['redundant-return'] = 'Opened', ['ambiguity-1'] = 'Any', ['lowercase-global'] = 'Any', @@ -106,6 +109,7 @@ m.DiagnosticDefaultNeededFileStatus = { ['await-in-sync'] = 'None', ['not-yieldable'] = 'None', ['discard-returns'] = 'Opened', + ['need-check-nil'] = 'Opened', ['type-check'] = 'None', ['duplicate-doc-alias'] = 'Any', diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua index b359c21c..15b08d49 100644 --- a/script/provider/diagnostic.lua +++ b/script/provider/diagnostic.lua @@ -128,12 +128,17 @@ local function mergeDiags(a, b, c) merge(b) merge(c) + if #t == 0 then + return nil + end + return t end +-- enable `push`, disable `clear` function m.clear(uri) await.close('diag:' .. uri) - if not m.cache[uri] then + if m.cache[uri] == nil then return end m.cache[uri] = nil @@ -144,6 +149,7 @@ function m.clear(uri) log.info('clearDiagnostics', uri) end +-- enable `push` and `send` function m.clearCache(uri) m.cache[uri] = false end @@ -251,14 +257,7 @@ function m.doDiagnostic(uri, isScopeDiag) version = version, diagnostics = full, }) - if #full > 0 then - log.debug('publishDiagnostics', uri, #full) - end - end - - -- always re-sent diagnostics of current file - if not isScopeDiag then - m.cache[uri] = nil + log.debug('publishDiagnostics', uri, #full) end pushResult() @@ -435,6 +434,7 @@ files.watch(function (ev, uri) ---@async m.refresh(uri) elseif ev == 'open' then if ws.isReady(uri) then + m.clearCache(uri) xpcall(m.doDiagnostic, log.error, uri) end elseif ev == 'close' then diff --git a/script/provider/provider.lua b/script/provider/provider.lua index b8b101ed..08b6ca93 100644 --- a/script/provider/provider.lua +++ b/script/provider/provider.lua @@ -42,8 +42,9 @@ local function updateConfig(uri) end local rc = cfgLoader.loadRCConfig(folder.uri, '.luarc.json') + or cfgLoader.loadRCConfig(folder.uri, '.luarc.jsonc') if rc then - log.info('Load config from luarc.json', folder.uri) + log.info('Load config from .luarc.json/.luarc.jsonc', folder.uri) log.debug(inspect(rc)) end @@ -91,6 +92,14 @@ filewatch.event(function (ev, path) ---@async end end end + if util.stringEndWith(path, '.luarc.jsonc') then + for _, scp in ipairs(workspace.folders) do + local rcPath = workspace.getAbsolutePath(scp.uri, '.luarc.jsonc') + if path == rcPath then + updateConfig(scp.uri) + end + end + end end) m.register 'initialize' { @@ -226,7 +235,6 @@ m.register 'workspace/didRenameFiles' { } m.register 'textDocument/didOpen' { - ---@async function (params) local doc = params.textDocument local scheme = furi.split(doc.uri) @@ -235,7 +243,6 @@ m.register 'textDocument/didOpen' { end local uri = files.getRealUri(doc.uri) log.debug('didOpen', uri) - workspace.awaitReady(uri) local text = doc.text files.setText(uri, text, true, function (file) file.version = doc.version @@ -257,13 +264,14 @@ m.register 'textDocument/didClose' { } m.register 'textDocument/didChange' { - ---@async function (params) local doc = params.textDocument + local scheme = furi.split(doc.uri) + if scheme ~= 'file' then + return + end local changes = params.contentChanges local uri = files.getRealUri(doc.uri) - workspace.awaitReady(uri) - --log.debug('changes', util.dump(changes)) local text = files.getOriginText(uri) or '' local rows = files.getCachedRows(uri) text, rows = tm(text, rows, changes) @@ -521,7 +529,8 @@ m.register 'textDocument/completion' { local count, max = workspace.getLoadingProcess(uri) return { { - label = lang.script('HOVER_WS_LOADING', count, max),textEdit = { + label = lang.script('HOVER_WS_LOADING', count, max), + textEdit = { range = { start = params.position, ['end'] = params.position, diff --git a/script/pub/pub.lua b/script/pub/pub.lua index e73aea51..47591ee6 100644 --- a/script/pub/pub.lua +++ b/script/pub/pub.lua @@ -124,7 +124,7 @@ end --- 通过 jumpQueue 可以插队 ---@param name string ---@param params any ----@param callback function +---@param callback? function function m.task(name, params, callback) local info = { id = counter(), diff --git a/script/service/telemetry.lua b/script/service/telemetry.lua index 50af39b1..2e52def2 100644 --- a/script/service/telemetry.lua +++ b/script/service/telemetry.lua @@ -99,7 +99,7 @@ timer.wait(5, function () end local suc, link = pcall(net.connect, 'tcp', 'moe-moe.love', 11577) if not suc then - suc, link = pcall(net.connect, 'tcp', '154.23.191.94', 11577) + suc, link = pcall(net.connect, 'tcp', '154.23.191.39', 11577) end if not suc or not link then return diff --git a/script/utility.lua b/script/utility.lua index 5a52e417..47b0c8d8 100644 --- a/script/utility.lua +++ b/script/utility.lua @@ -83,7 +83,7 @@ local m = {} --- 打印表的结构 ---@param tbl table ----@param option table {optional = 'self'} +---@param option? table ---@return string function m.dump(tbl, option) if not option then @@ -315,8 +315,8 @@ function m.saveFile(path, content) end --- 计数器 ----@param init integer {optional = 'after'} ----@param step integer {optional = 'after'} +---@param init? integer +---@param step? integer ---@return fun():integer function m.counter(init, step) if not step then @@ -346,8 +346,8 @@ function m.sortPairs(t, sorter) end --- 深拷贝(不处理元表) ----@param source table ----@param target table {optional = 'self'} +---@param source table +---@param target? table function m.deepCopy(source, target) local mark = {} local function copy(a, b) @@ -566,7 +566,7 @@ end ---遍历文本的每一行 ---@param text string ----@param keepNL boolean # 保留换行符 +---@param keepNL? boolean # 保留换行符 ---@return fun(text:string):string, integer function m.eachLine(text, keepNL) local offset = 1 diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 8126f393..75620d19 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1,10 +1,6 @@ local guide = require 'parser.guide' local util = require 'utility' -local localID = require 'vm.local-id' -local globalMgr = require 'vm.global-manager' -local signMgr = require 'vm.sign' local config = require 'config' -local genericMgr = require 'vm.generic' local rpath = require 'workspace.require-path' local files = require 'files' ---@class vm @@ -13,7 +9,6 @@ local vm = require 'vm.vm' ---@class parser.object ---@field _compiledNodes boolean ---@field _node vm.node ----@field _localBase table ---@field _globalBase table local searchFieldSwitch = util.switch() @@ -54,7 +49,7 @@ local searchFieldSwitch = util.switch() : case 'string' : call(function (suri, source, key, ref, pushResult) -- change to `string: stringlib` ? - local stringlib = globalMgr.getGlobal('type', 'stringlib') + local stringlib = vm.getGlobal('type', 'stringlib') if stringlib then vm.getClassFields(suri, stringlib, key, ref, pushResult) end @@ -64,9 +59,9 @@ local searchFieldSwitch = util.switch() : call(function (suri, node, key, ref, pushResult) local fields if key then - fields = localID.getSources(node, key) + fields = vm.getLocalSources(node, key) else - fields = localID.getFields(node) + fields = vm.getLocalFields(node) end if fields then for _, src in ipairs(fields) do @@ -119,7 +114,7 @@ local searchFieldSwitch = util.switch() if type(key) ~= 'string' then return end - local global = globalMgr.getGlobal('variable', node.name, key) + local global = vm.getGlobal('variable', node.name, key) if global then for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -131,7 +126,7 @@ local searchFieldSwitch = util.switch() end end else - local globals = globalMgr.getFields('variable', node.name) + local globals = vm.getGlobalFields('variable', node.name) for _, global in ipairs(globals) do for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -158,7 +153,7 @@ local searchFieldSwitch = util.switch() if type(key) ~= 'string' then return end - local global = globalMgr.getGlobal('variable', node.name, key) + local global = vm.getGlobal('variable', node.name, key) if global then for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -168,7 +163,7 @@ local searchFieldSwitch = util.switch() end end else - local globals = globalMgr.getFields('variable', node.name) + local globals = vm.getGlobalFields('variable', node.name) for _, global in ipairs(globals) do for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -185,7 +180,7 @@ local searchFieldSwitch = util.switch() end) -function vm.getClassFields(suri, node, key, ref, pushResult) +function vm.getClassFields(suri, object, key, ref, pushResult) local mark = {} local function searchClass(class, searchedFields) @@ -201,11 +196,51 @@ function vm.getClassFields(suri, node, key, ref, pushResult) local hasFounded = {} for _, field in ipairs(set.fields) do local fieldKey = guide.getKeyName(field) - if key == nil - or fieldKey == key then - if not searchedFields[fieldKey] then - pushResult(field) - hasFounded[fieldKey] = true + if fieldKey then + -- ---@field x boolean -> class.x + if key == nil + or fieldKey == key then + if not searchedFields[fieldKey] then + pushResult(field) + hasFounded[fieldKey] = true + end + end + end + if not hasFounded[fieldKey] then + local keyType = type(key) + if keyType == 'table' then + -- ---@field [integer] boolean -> class[integer] + local fieldNode = vm.compileNode(field.field) + if vm.isSubType(suri, key.name, fieldNode) then + local nkey = '|' .. key.name + if not searchedFields[nkey] then + pushResult(field) + hasFounded[nkey] = true + end + end + else + local typeName + if keyType == 'number' then + if math.tointeger(key) then + typeName = 'integer' + else + typeName = 'number' + end + elseif keyType == 'boolean' + or keyType == 'string' then + typeName = keyType + end + if typeName then + -- ---@field [integer] boolean -> class[1] + local fieldNode = vm.compileNode(field.field) + if vm.isSubType(suri, typeName, fieldNode) then + local nkey = '|' .. typeName + if not searchedFields[nkey] then + pushResult(field) + hasFounded[nkey] = true + end + end + end end end end @@ -214,19 +249,23 @@ function vm.getClassFields(suri, node, key, ref, pushResult) for _, src in ipairs(set.bindSources) do searchFieldSwitch(src.type, suri, src, key, ref, function (field) local fieldKey = guide.getKeyName(field) - if not searchedFields[fieldKey] - and guide.isSet(field) then - hasFounded[fieldKey] = true - pushResult(field) + if fieldKey then + if not searchedFields[fieldKey] + and guide.isSet(field) then + hasFounded[fieldKey] = true + pushResult(field) + end end end) if src.value and src.value.type == 'table' then searchFieldSwitch('table', suri, src.value, key, ref, function (field) local fieldKey = guide.getKeyName(field) - if not searchedFields[fieldKey] - and guide.isSet(field) then - hasFounded[fieldKey] = true - pushResult(field) + if fieldKey then + if not searchedFields[fieldKey] + and guide.isSet(field) then + hasFounded[fieldKey] = true + pushResult(field) + end end end) end @@ -239,7 +278,7 @@ function vm.getClassFields(suri, node, key, ref, pushResult) end for _, extend in ipairs(set.extends) do if extend.type == 'doc.extends.name' then - local extendType = globalMgr.getGlobal('type', extend[1]) + local extendType = vm.getGlobal('type', extend[1]) if extendType then searchClass(extendType, searchedFields) end @@ -253,12 +292,12 @@ function vm.getClassFields(suri, node, key, ref, pushResult) local function searchGlobal(class) if class.cate == 'type' and class.name == '_G' then if key == nil then - local sets = globalMgr.getGlobalSets(suri, 'variable') + local sets = vm.getGlobalSets(suri, 'variable') for _, set in ipairs(sets) do pushResult(set) end else - local global = globalMgr.getGlobal('variable', key) + local global = vm.getGlobal('variable', key) if global then for _, set in ipairs(global:getSets(suri)) do pushResult(set) @@ -268,8 +307,8 @@ function vm.getClassFields(suri, node, key, ref, pushResult) end end - searchClass(node) - searchGlobal(node) + searchClass(object) + searchGlobal(object) end ---@class parser.object @@ -283,10 +322,13 @@ local function getObjectSign(source) end source._sign = false if source.type == 'function' then + if not source.bindDocs then + return false + end for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.generic' then if not source._sign then - source._sign = signMgr() + source._sign = vm.createSign() break end end @@ -314,14 +356,18 @@ local function getObjectSign(source) if not hasGeneric then return false end - source._sign = signMgr() + source._sign = vm.createSign() if source.type == 'doc.type.function' then for _, arg in ipairs(source.args) do - local argNode = vm.compileNode(arg.extends) - if arg.optional then - argNode:addOptional() + if arg.extends then + local argNode = vm.compileNode(arg.extends) + if arg.optional then + argNode:addOptional() + end + source._sign:addSign(argNode) + else + source._sign:addSign(vm.createNode()) end - source._sign:addSign(argNode) end end end @@ -354,7 +400,7 @@ function vm.getReturnOfFunction(func, index) if not sign then return rtn end - return genericMgr(rtn, sign) + return vm.createGeneric(rtn, sign) end end @@ -455,6 +501,9 @@ local function getReturn(func, index, args) result:merge(rnode) end end + if result and returnNode:isOptional() then + result:addOptional() + end end end end @@ -462,6 +511,25 @@ local function getReturn(func, index, args) return result end +---@param source parser.object +---@return boolean +local function bindAs(source) + local root = guide.getRoot(source) + local docs = root.docs + if not docs then + return + end + for _, doc in ipairs(docs) do + if doc.type == 'doc.as' and doc.originalComment.start == source.finish + 2 then + if doc.as then + vm.setNode(source, vm.compileNode(doc.as), true) + end + return true + end + end + return false +end + local function bindDocs(source) local isParam = source.parent.type == 'funcargs' or source.parent.type == 'in' @@ -485,7 +553,11 @@ local function bindDocs(source) end if doc.type == 'doc.param' then if isParam and source[1] == doc.param[1] then - vm.setNode(source, vm.compileNode(doc)) + local node = vm.compileNode(doc) + if doc.optional then + node:addOptional() + end + vm.setNode(source, node) return true end end @@ -503,12 +575,17 @@ local function bindDocs(source) vm.setNode(source, vm.compileNode(ast)) return true end + if doc.type == 'doc.overload' then + if not isParam then + vm.setNode(source, vm.compileNode(doc)) + end + end end return false end local function compileByLocalID(source) - local sources = localID.getSources(source) + local sources = vm.getLocalSources(source) if not sources then return end @@ -571,7 +648,7 @@ local function selectNode(source, list, index) if exp.type == 'call' then result = getReturn(exp.node, index, exp.args) if not result then - vm.setNode(source, globalMgr.getGlobal('type', 'unknown')) + vm.setNode(source, vm.declareGlobal('type', 'unknown')) return vm.getNode(source) end else @@ -597,7 +674,7 @@ local function selectNode(source, list, index) end end if not hasKnownType then - rtnNode:merge(globalMgr.getGlobal('type', 'unknown')) + rtnNode:merge(vm.declareGlobal('type', 'unknown')) end vm.setNode(source, rtnNode) return rtnNode @@ -664,10 +741,21 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) for n in callNode:eachObject() do if n.type == 'function' then + local sign = getObjectSign(n) local farg = getFuncArg(n, myIndex) if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then + if fn.type == 'doc.type.function' then + if sign then + local generic = vm.createGeneric(fn, sign) + local args = {} + for i = fixIndex + 1, myIndex - 1 do + args[#args+1] = call.args[i] + end + fn = generic:resolve(guide.getUri(call), args) + end + end vm.setNode(arg, fn) end end @@ -716,29 +804,19 @@ function vm.compileCallArg(arg, call, index) if call.node.special == 'pcall' or call.node.special == 'xpcall' then local fixIndex = call.node.special == 'pcall' and 1 or 2 - callNode = vm.compileNode(call.args[1]) - compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex) + if call.args and call.args[1] then + callNode = vm.compileNode(call.args[1]) + compileCallArgNode(arg, call, callNode, fixIndex, index - fixIndex) + end end return vm.getNode(arg) end ---@param source parser.object ---@return vm.node -local function compileLocalBase(source) - if not source._localBase then - source._localBase = { - type = 'localbase', - parent = source, - } - end - local baseNode = vm.getNode(source._localBase) - if baseNode then - return baseNode - end - baseNode = vm.createNode() - vm.setNode(source._localBase, baseNode, true) - +local function compileLocal(source) vm.setNode(source, source) + local hasMarkDoc if source.bindDocs then hasMarkDoc = bindDocs(source) @@ -788,14 +866,19 @@ local function compileLocalBase(source) if n.type == 'doc.type.function' then for index, arg in ipairs(n.args) do if func.args[index] == source then - vm.setNode(source, vm.compileNode(arg)) + local argNode = vm.compileNode(arg) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end + end hasDocArg = true end end end end if not hasDocArg then - vm.setNode(source, globalMgr.getGlobal('type', 'any')) + vm.setNode(source, vm.declareGlobal('type', 'any')) end end -- for x in ... do @@ -805,15 +888,10 @@ local function compileLocalBase(source) -- for x = ... do if source.parent.type == 'loop' then - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.compileNode(source.parent) end - baseNode:merge(vm.getNode(source)) - vm.removeNode(source) - - baseNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) - - return baseNode + vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) end local compilerSwitch = util.switch() @@ -867,41 +945,79 @@ local compilerSwitch = util.switch() end) : case 'paren' : call(function (source) + if bindAs(source) then + return + end if source.exp then vm.setNode(source, vm.compileNode(source.exp)) end end) : case 'local' : case 'self' + ---@param source parser.object : call(function (source) - local baseNode = compileLocalBase(source) - vm.setNode(source, baseNode, true) - if not baseNode:getData 'hasDefined' and source.ref then + compileLocal(source) + local refs = source.ref + if not refs then + return + end + + local hasMark = vm.getNode(source):getData 'hasDefined' + + local runner = vm.createRunner(source) + runner:launch(function (src, node) + if src.type == 'setlocal' then + if src.bindDocs then + for _, doc in ipairs(src.bindDocs) do + if doc.type == 'doc.type' then + vm.setNode(src, vm.compileNode(doc), true) + return vm.getNode(src) + end + end + end + if src.value and guide.isLiteral(src.value) then + if src.value.type == 'table' then + vm.setNode(src, vm.createNode(src.value), true) + else + vm.setNode(src, vm.compileNode(src.value), true) + end + elseif src.value + and src.value.type == 'binary' + and src.value.op and src.value.op.type == 'or' + and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then + -- x = x or 1 + vm.setNode(src, vm.compileNode(src.value)) + else + vm.setNode(src, node, true) + end + return vm.getNode(src) + elseif src.type == 'getlocal' then + if bindAs(src) then + return + end + vm.setNode(src, node, true) + end + end) + + if not hasMark then + local parentFunc = guide.getParentFunction(source) for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' then - vm.setNode(source, vm.compileNode(ref)) + if ref.type == 'setlocal' + and guide.getParentFunction(ref) == parentFunc then + vm.setNode(source, vm.getNode(ref)) end end end end) : case 'setlocal' : call(function (source) - local baseNode = compileLocalBase(source.node) - if not baseNode:getData 'hasDefined' and source.value then - if source.value.type == 'table' then - vm.setNode(source, source.value) - else - vm.setNode(source, vm.compileNode(source.value)) - end - end - baseNode:merge(vm.getNode(source)) - vm.setNode(source, baseNode, true) vm.compileNode(source.node) end) : case 'getlocal' : call(function (source) - local baseNode = compileLocalBase(source.node) - vm.setNode(source, baseNode, true) + if bindAs(source) then + return + end vm.compileNode(source.node) end) : case 'setfield' @@ -924,6 +1040,9 @@ local compilerSwitch = util.switch() : case 'getmethod' : case 'getindex' : call(function (source) + if bindAs(source) then + return + end compileByLocalID(source) local key = guide.getKeyName(source) if key == nil and source.index then @@ -959,6 +1078,9 @@ local compilerSwitch = util.switch() end) : case 'getglobal' : call(function (source) + if bindAs(source) then + return + end if source.node[1] ~= '_ENV' then return end @@ -1019,7 +1141,7 @@ local compilerSwitch = util.switch() end) end if hasGeneric then - vm.setNode(source, genericMgr(rtn, sign)) + vm.setNode(source, vm.createGeneric(rtn, sign)) else vm.setNode(source, vm.compileNode(rtn)) end @@ -1092,29 +1214,44 @@ local compilerSwitch = util.switch() -- for k, v in pairs(t) do --> for k, v in iterator, status, initValue do --> local k, v = iterator(status, initValue) - source._iterator = {} - source._iterArgs = {{}, {}} - -- iterator - selectNode(source._iterator, source.exps, 1) - -- status - selectNode(source._iterArgs[1], source.exps, 2) - -- initValue - selectNode(source._iterArgs[2], source.exps, 3) - end + source._iterator = { + type = 'dummyfunc', + parent = source, + } + source._iterArgs = {{},{}} + end + -- iterator + selectNode(source._iterator, source.exps, 1) + -- status + selectNode(source._iterArgs[1], source.exps, 2) + -- initValue + selectNode(source._iterArgs[2], source.exps, 3) if source.keys then for i, loc in ipairs(source.keys) do local node = getReturn(source._iterator, i, source._iterArgs) if node then + if i == 1 then + node:removeOptional() + end vm.setNode(loc, node) end end end end) + : case 'loop' + : call(function (source) + if source.loc then + vm.setNode(source.loc, vm.declareGlobal('type', 'integer')) + end + end) : case 'doc.type' : call(function (source) for _, typeUnit in ipairs(source.types) do vm.setNode(source, vm.compileNode(typeUnit)) end + if source.optional then + vm.getNode(source):addOptional() + end end) : case 'doc.type.integer' : case 'doc.type.string' @@ -1130,7 +1267,13 @@ local compilerSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) vm.setNode(source, source) - local global = globalMgr.getGlobal('type', source.node[1]) + if not source.node[1] then + return + end + local global = vm.getGlobal('type', source.node[1]) + if not global then + return + end for _, set in ipairs(global:getSets(uri)) do if set.type == 'doc.class' then if set.extends then @@ -1161,14 +1304,22 @@ local compilerSwitch = util.switch() if not source.extends then return end - vm.setNode(source, vm.compileNode(source.extends)) + local fieldNode = vm.compileNode(source.extends) + if source.optional then + fieldNode:addOptional() + end + vm.setNode(source, fieldNode) end) : case 'doc.type.field' : call(function (source) if not source.extends then return end - vm.setNode(source, vm.compileNode(source.extends)) + local fieldNode = vm.compileNode(source.extends) + if source.optional then + fieldNode:addOptional() + end + vm.setNode(source, fieldNode) end) : case 'doc.param' : call(function (source) @@ -1208,7 +1359,7 @@ local compilerSwitch = util.switch() end) : case 'doc.see.name' : call(function (source) - local type = globalMgr.getGlobal('type', source[1]) + local type = vm.getGlobal('type', source[1]) if type then vm.setNode(source, vm.compileNode(type)) end @@ -1218,7 +1369,10 @@ local compilerSwitch = util.switch() if source.extends then vm.setNode(source, vm.compileNode(source.extends)) else - vm.setNode(source, globalMgr.getGlobal('type', 'any')) + vm.setNode(source, vm.declareGlobal('type', 'any')) + end + if source.optional then + vm.getNode(source):addOptional() end end) : case 'generic' @@ -1227,10 +1381,16 @@ local compilerSwitch = util.switch() end) : case 'unary' : call(function (source) + if bindAs(source) then + return + end + if not source[1] then + return + end if source.op.type == 'not' then local result = vm.test(source[1]) if result == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'boolean')) + vm.setNode(source, vm.declareGlobal('type', 'boolean')) return else vm.setNode(source, { @@ -1244,13 +1404,13 @@ local compilerSwitch = util.switch() end end if source.op.type == '#' then - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end if source.op.type == '-' then local v = vm.getNumber(source[1]) if v == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return else vm.setNode(source, { @@ -1266,7 +1426,7 @@ local compilerSwitch = util.switch() if source.op.type == '~' then local v = vm.getInteger(source[1]) if v == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return else vm.setNode(source, { @@ -1282,34 +1442,42 @@ local compilerSwitch = util.switch() end) : case 'binary' : call(function (source) + if bindAs(source) then + return + end + if not source[1] or not source[2] then + return + end if source.op.type == 'and' then + local node1 = vm.compileNode(source[1]) + local node2 = vm.compileNode(source[2]) local r1 = vm.test(source[1]) if r1 == true then - vm.setNode(source, vm.compileNode(source[2])) - return - end - if r1 == false then - vm.setNode(source, vm.compileNode(source[1])) - return + vm.setNode(source, node2) + elseif r1 == false then + vm.setNode(source, node1) + else + vm.setNode(source, node2) end - return end if source.op.type == 'or' then + local node1 = vm.compileNode(source[1]) + local node2 = vm.compileNode(source[2]) local r1 = vm.test(source[1]) if r1 == true then - vm.setNode(source, vm.compileNode(source[1])) - return - end - if r1 == false then - vm.setNode(source, vm.compileNode(source[2])) - return + vm.setNode(source, node1) + elseif r1 == false then + vm.setNode(source, node2) + else + vm.getNode(source):merge(node1) + vm.getNode(source):setTruthy() + vm.getNode(source):merge(node2) end - return end if source.op.type == '==' then local result = vm.equal(source[1], source[2]) if result == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'boolean')) + vm.setNode(source, vm.declareGlobal('type', 'boolean')) return else vm.setNode(source, { @@ -1325,7 +1493,7 @@ local compilerSwitch = util.switch() if source.op.type == '~=' then local result = vm.equal(source[1], source[2]) if result == nil then - vm.setNode(source, globalMgr.getGlobal('type', 'boolean')) + vm.setNode(source, vm.declareGlobal('type', 'boolean')) return else vm.setNode(source, { @@ -1351,7 +1519,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1368,7 +1536,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1385,7 +1553,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1402,7 +1570,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1419,7 +1587,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'integer')) + vm.setNode(source, vm.declareGlobal('type', 'integer')) return end end @@ -1437,7 +1605,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1455,7 +1623,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1473,7 +1641,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1490,14 +1658,14 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end if source.op.type == '%' then local a = vm.getNumber(source[1]) local b = vm.getNumber(source[2]) - if a and b then + if a and b and b ~= 0 then local result = a % b vm.setNode(source, { type = math.type(result) == 'integer' and 'integer' or 'number', @@ -1508,7 +1676,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1525,7 +1693,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1543,7 +1711,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'number')) + vm.setNode(source, vm.declareGlobal('type', 'number')) return end end @@ -1580,7 +1748,7 @@ local compilerSwitch = util.switch() }) return else - vm.setNode(source, globalMgr.getGlobal('type', 'string')) + vm.setNode(source, vm.declareGlobal('type', 'string')) return end end @@ -1614,17 +1782,20 @@ local function compileByGlobal(source) vm.setNode(source, globalNode, true) return end + ---@type vm.node globalNode = vm.createNode(global) vm.setNode(root._globalBase[name], globalNode, true) + vm.setNode(source, globalNode, true) - local sets = global.links[uri].sets or {} - local gets = global.links[uri].gets or {} - for _, set in ipairs(sets) do - vm.setNode(set, globalNode, true) - end - for _, get in ipairs(gets) do - vm.setNode(get, globalNode, true) - end + -- TODO:don't mix + --local sets = global.links[uri].sets or {} + --local gets = global.links[uri].gets or {} + --for _, set in ipairs(sets) do + -- vm.setNode(set, globalNode, true) + --end + --for _, get in ipairs(gets) do + -- vm.setNode(get, globalNode, true) + --end if global.cate == 'variable' then local hasMarkDoc @@ -1672,7 +1843,11 @@ end ---@return vm.node function vm.compileNode(source) if not source then - error('Can not compile nil node') + if TEST then + error('Can not compile nil source') + else + log.error('Can not compile nil source') + end end if source.type == 'global' then diff --git a/script/vm/def.lua b/script/vm/def.lua index b66e8fda..83e92686 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -2,8 +2,6 @@ local vm = require 'vm.vm' local util = require 'utility' local guide = require 'parser.guide' -local localID = require 'vm.local-id' -local globalMgr = require 'vm.global-manager' local simpleSwitch @@ -79,6 +77,13 @@ simpleSwitch = util.switch() pushResult(source.node) end end) + : case 'doc.cast.name' + : call(function (source, pushResult) + local loc = guide.getLocal(source, source[1], source.start) + if loc then + pushResult(loc) + end + end) local searchFieldSwitch = util.switch() : case 'table' @@ -97,7 +102,7 @@ local searchFieldSwitch = util.switch() ---@param key string : call(function (suri, obj, key, pushResult) if obj.cate == 'variable' then - local newGlobal = globalMgr.getGlobal('variable', obj.name, key) + local newGlobal = vm.getGlobal('variable', obj.name, key) if newGlobal then for _, set in ipairs(newGlobal:getSets(suri)) do pushResult(set) @@ -110,7 +115,7 @@ local searchFieldSwitch = util.switch() end) : case 'local' : call(function (suri, obj, key, pushResult) - local sources = localID.getSources(obj, key) + local sources = vm.getLocalSources(obj, key) if sources then for _, src in ipairs(sources) do if guide.isSet(src) then @@ -189,7 +194,7 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchByLocalID(source, pushResult) - local idSources = localID.getSources(source) + local idSources = vm.getLocalSources(source) if not idSources then return end diff --git a/script/vm/doc.lua b/script/vm/doc.lua index 5a92a103..e2b383b6 100644 --- a/script/vm/doc.lua +++ b/script/vm/doc.lua @@ -3,7 +3,6 @@ local guide = require 'parser.guide' ---@class vm local vm = require 'vm.vm' local config = require 'config' -local globalMgr = require 'vm.global-manager' ---获取class与alias ---@param suri uri @@ -11,13 +10,13 @@ local globalMgr = require 'vm.global-manager' ---@return parser.object[] function vm.getDocSets(suri, name) if name then - local global = globalMgr.getGlobal('type', name) + local global = vm.getGlobal('type', name) if not global then return {} end return global:getSets(suri) else - return globalMgr.getGlobalSets(suri, 'type') + return vm.getGlobalSets(suri, 'type') end end @@ -27,6 +26,9 @@ function vm.isMetaFile(uri) return false end local cache = files.getCache(uri) + if not cache then + return false + end if cache.isMeta ~= nil then return cache.isMeta end @@ -332,6 +334,9 @@ function vm.isDiagDisabledAt(uri, position, name) return false end local cache = files.getCache(uri) + if not cache then + return false + end if not cache.diagnosticRanges then cache.diagnosticRanges = {} for _, doc in ipairs(status.ast.docs) do diff --git a/script/vm/field.lua b/script/vm/field.lua index ba7cd4c1..5de838be 100644 --- a/script/vm/field.lua +++ b/script/vm/field.lua @@ -15,6 +15,15 @@ local searchByNodeSwitch = util.switch() pushResult(source) end) +local function searchByLocalID(source, pushResult) + local fields = vm.getLocalFields(source) + if fields then + for _, field in ipairs(fields) do + pushResult(field) + end + end +end + local function searchByNode(source, pushResult) local uri = guide.getUri(source) vm.compileByParentNode(source, nil, true, function (field) @@ -35,6 +44,7 @@ function vm.getFields(source) end end + searchByLocalID(source, pushResult) searchByNode(source, pushResult) return results diff --git a/script/vm/generic.lua b/script/vm/generic.lua index b3981ff8..6462028e 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -1,3 +1,4 @@ +---@class vm local vm = require 'vm.vm' ---@class parser.object @@ -114,7 +115,7 @@ end ---@param uri uri ---@param args parser.object ----@return parser.object +---@return vm.node function mt:resolve(uri, args) local resolved = self.sign:resolve(uri, args) local protoNode = vm.compileNode(self.proto) @@ -129,7 +130,7 @@ end ---@param proto vm.object ---@param sign vm.sign ---@return vm.generic -return function (proto, sign) +function vm.createGeneric(proto, sign) local generic = setmetatable({ sign = sign, proto = proto, diff --git a/script/vm/global-manager.lua b/script/vm/global-manager.lua deleted file mode 100644 index f25bb5a0..00000000 --- a/script/vm/global-manager.lua +++ /dev/null @@ -1,364 +0,0 @@ -local util = require 'utility' -local guide = require 'parser.guide' -local globalBuilder = require 'vm.global' -local signMgr = require 'vm.sign' -local genericMgr = require 'vm.generic' ----@class vm -local vm = require 'vm.vm' - ----@class parser.object ----@field _globalNode vm.global - ----@class vm.global-manager -local m = {} ----@type table<string, vm.global> -m.globals = {} ----@type table<uri, table<string, boolean>> -m.globalSubs = util.multiTable(2) - -local compilerGlobalSwitch = util.switch() - : case 'local' - : call(function (source) - if source.special ~= '_G' then - return - end - if source.ref then - for _, ref in ipairs(source.ref) do - m.compileObject(ref) - end - end - end) - : case 'getlocal' - : call(function (source) - if source.special ~= '_G' then - return - end - if not source.next then - return - end - m.compileObject(source.next) - end) - : case 'setglobal' - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local global = m.declareGlobal('variable', name, uri) - global:addSet(uri, source) - source._globalNode = global - end) - : case 'getglobal' - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local global = m.declareGlobal('variable', name, uri) - global:addGet(uri, source) - source._globalNode = global - - local nxt = source.next - if nxt then - m.compileObject(nxt) - end - end) - : case 'setfield' - : case 'setmethod' - : case 'setindex' - ---@param source parser.object - : call(function (source) - local name - local keyName = guide.getKeyName(source) - if not keyName then - return - end - if source.node._globalNode then - local parentName = source.node._globalNode:getName() - if parentName == '_G' then - name = keyName - else - name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) - end - elseif source.node.special == '_G' then - name = keyName - end - if not name then - return - end - local uri = guide.getUri(source) - local global = m.declareGlobal('variable', name, uri) - global:addSet(uri, source) - source._globalNode = global - end) - : case 'getfield' - : case 'getmethod' - : case 'getindex' - ---@param source parser.object - : call(function (source) - local name - local keyName = guide.getKeyName(source) - if not keyName then - return - end - if source.node._globalNode then - local parentName = source.node._globalNode:getName() - if parentName == '_G' then - name = keyName - else - name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) - end - elseif source.node.special == '_G' then - name = keyName - end - local uri = guide.getUri(source) - local global = m.declareGlobal('variable', name, uri) - global:addGet(uri, source) - source._globalNode = global - - local nxt = source.next - if nxt then - m.compileObject(nxt) - end - end) - : case 'call' - : call(function (source) - if source.node.special == 'rawset' - or source.node.special == 'rawget' then - if not source.args then - return - end - local g = source.args[1] - local key = source.args[2] - if g and key and g.special == '_G' then - local name = guide.getKeyName(key) - if name then - local uri = guide.getUri(source) - local global = m.declareGlobal('variable', name, uri) - if source.node.special == 'rawset' then - global:addSet(uri, source) - source.value = source.args[3] - else - global:addGet(uri, source) - end - source._globalNode = global - - local nxt = source.next - if nxt then - m.compileObject(nxt) - end - end - end - end - end) - : case 'doc.class' - ---@param source parser.object - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local class = m.declareGlobal('type', name, uri) - class:addSet(uri, source) - source._globalNode = class - - if source.signs then - source._sign = signMgr() - for _, sign in ipairs(source.signs) do - source._sign:addSign(vm.compileNode(sign)) - end - if source.extends then - for _, ext in ipairs(source.extends) do - if ext.type == 'doc.type.table' then - ext._generic = genericMgr(ext, source._sign) - end - end - end - end - end) - : case 'doc.alias' - : call(function (source) - local uri = guide.getUri(source) - local name = guide.getKeyName(source) - local alias = m.declareGlobal('type', name, uri) - alias:addSet(uri, source) - source._globalNode = alias - - if source.signs then - source._sign = signMgr() - for _, sign in ipairs(source.signs) do - source._sign:addSign(vm.compileNode(sign)) - end - source.extends._generic = genericMgr(source.extends, source._sign) - end - end) - : case 'doc.type.name' - : call(function (source) - local uri = guide.getUri(source) - local name = source[1] - local type = m.declareGlobal('type', name, uri) - type:addGet(uri, source) - source._globalNode = type - end) - : case 'doc.extends.name' - : call(function (source) - local uri = guide.getUri(source) - local name = source[1] - local class = m.declareGlobal('type', name, uri) - class:addGet(uri, source) - source._globalNode = class - end) - - ----@alias vm.global.cate '"variable"' | '"type"' - ----@param cate vm.global.cate ----@param name string ----@param uri uri ----@return vm.global -function m.declareGlobal(cate, name, uri) - local key = cate .. '|' .. name - m.globalSubs[uri][key] = true - if not m.globals[key] then - m.globals[key] = globalBuilder(name, cate) - end - return m.globals[key] -end - ----@param cate vm.global.cate ----@param name string ----@param field? string ----@return vm.global? -function m.getGlobal(cate, name, field) - local key = cate .. '|' .. name - if field then - key = key .. vm.ID_SPLITE .. field - end - return m.globals[key] -end - ----@param cate vm.global.cate ----@param name string ----@return vm.global[] -function m.getFields(cate, name) - local globals = {} - local key = cate .. '|' .. name - - -- TODO: optimize - local clock = os.clock() - for gid, global in pairs(m.globals) do - if gid ~= key - and util.stringStartWith(gid, key) - and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE - and not gid:find(vm.ID_SPLITE, #key + 2) then - globals[#globals+1] = global - end - end - local cost = os.clock() - clock - if cost > 0.1 then - log.warn('global-manager getFields cost %.3f', cost) - end - - return globals -end - ----@param cate vm.global.cate ----@return vm.global[] -function m.getGlobals(cate) - local globals = {} - - -- TODO: optimize - local clock = os.clock() - for gid, global in pairs(m.globals) do - if util.stringStartWith(gid, cate) - and not gid:find(vm.ID_SPLITE) then - globals[#globals+1] = global - end - end - local cost = os.clock() - clock - if cost > 0.1 then - log.warn('global-manager getGlobals cost %.3f', cost) - end - - return globals -end - ----@param suri uri ----@param cate vm.global.cate ----@return parser.object[] -function m.getGlobalSets(suri, cate) - local globals = m.getGlobals(cate) - local result = {} - for _, global in ipairs(globals) do - local sets = global:getSets(suri) - for _, set in ipairs(sets) do - result[#result+1] = set - end - end - return result -end - ----@param suri uri ----@param cate vm.global.cate ----@param name string ----@return boolean -function m.hasGlobalSets(suri, cate, name) - local global = m.getGlobal(cate, name) - if not global then - return false - end - local sets = global:getSets(suri) - if #sets == 0 then - return false - end - return true -end - ----@param source parser.object -function m.compileObject(source) - if source._globalNode ~= nil then - return - end - source._globalNode = false - compilerGlobalSwitch(source.type, source) -end - ----@param source parser.object -function m.compileAst(source) - local env = guide.getENV(source) - m.compileObject(env) - guide.eachSpecialOf(source, 'rawset', function (src) - m.compileObject(src.parent) - end) - guide.eachSpecialOf(source, 'rawget', function (src) - m.compileObject(src.parent) - end) - guide.eachSourceTypes(source.docs, { - 'doc.class', - 'doc.alias', - 'doc.type.name', - 'doc.extends.name', - }, function (src) - m.compileObject(src) - end) -end - ----@return vm.global -function m.getNode(source) - if source.type == 'field' - or source.type == 'method' then - source = source.parent - end - return source._globalNode -end - ----@param uri uri -function m.dropUri(uri) - local globalSub = m.globalSubs[uri] - m.globalSubs[uri] = nil - for key in pairs(globalSub) do - local global = m.globals[key] - if global then - global:dropUri(uri) - if not global:isAlive() then - m.globals[key] = nil - end - end - end -end - -return m diff --git a/script/vm/global.lua b/script/vm/global.lua index 1c46c9a3..a54ab552 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -1,5 +1,9 @@ -local util = require 'utility' -local scope= require 'workspace.scope' +local util = require 'utility' +local scope = require 'workspace.scope' +local guide = require 'parser.guide' +local files = require 'files' +---@class vm +local vm = require 'vm.vm' ---@class vm.global.link ---@field gets parser.object[] @@ -15,8 +19,6 @@ mt.__index = mt mt.type = 'global' mt.name = '' -local ID_SPLITE = '\x1F' - ---@param uri uri ---@param source parser.object function mt:addSet(uri, source) @@ -106,7 +108,7 @@ end ---@return string function mt:getKeyName() - return self.name:match('[^' .. ID_SPLITE .. ']+$') + return self.name:match('[^' .. vm.ID_SPLITE .. ']+$') end ---@return boolean @@ -116,10 +118,427 @@ end ---@param cate vm.global.cate ---@return vm.global -return function (name, cate) +local function createGlobal(name, cate) return setmetatable({ name = name, cate = cate, links = util.multiTable(2), }, mt) end + +---@class parser.object +---@field _globalNode vm.global + +---@type table<string, vm.global> +local allGlobals = {} +---@type table<uri, table<string, boolean>> +local globalSubs = util.multiTable(2) + +local compileObject +local compilerGlobalSwitch = util.switch() + : case 'local' + : call(function (source) + if source.special ~= '_G' then + return + end + if source.ref then + for _, ref in ipairs(source.ref) do + compileObject(ref) + end + end + end) + : case 'getlocal' + : call(function (source) + if source.special ~= '_G' then + return + end + if not source.next then + return + end + compileObject(source.next) + end) + : case 'setglobal' + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local global = vm.declareGlobal('variable', name, uri) + global:addSet(uri, source) + source._globalNode = global + end) + : case 'getglobal' + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local global = vm.declareGlobal('variable', name, uri) + global:addGet(uri, source) + source._globalNode = global + + local nxt = source.next + if nxt then + compileObject(nxt) + end + end) + : case 'setfield' + : case 'setmethod' + : case 'setindex' + ---@param source parser.object + : call(function (source) + local name + local keyName = guide.getKeyName(source) + if not keyName then + return + end + if source.node._globalNode then + local parentName = source.node._globalNode:getName() + if parentName == '_G' then + name = keyName + else + name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) + end + elseif source.node.special == '_G' then + name = keyName + end + if not name then + return + end + local uri = guide.getUri(source) + local global = vm.declareGlobal('variable', name, uri) + global:addSet(uri, source) + source._globalNode = global + end) + : case 'getfield' + : case 'getmethod' + : case 'getindex' + ---@param source parser.object + : call(function (source) + local name + local keyName = guide.getKeyName(source) + if not keyName then + return + end + if source.node._globalNode then + local parentName = source.node._globalNode:getName() + if parentName == '_G' then + name = keyName + else + name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName) + end + elseif source.node.special == '_G' then + name = keyName + end + local uri = guide.getUri(source) + local global = vm.declareGlobal('variable', name, uri) + global:addGet(uri, source) + source._globalNode = global + + local nxt = source.next + if nxt then + compileObject(nxt) + end + end) + : case 'call' + : call(function (source) + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + if not source.args then + return + end + local g = source.args[1] + local key = source.args[2] + if g and key and g.special == '_G' then + local name = guide.getKeyName(key) + if name then + local uri = guide.getUri(source) + local global = vm.declareGlobal('variable', name, uri) + if source.node.special == 'rawset' then + global:addSet(uri, source) + source.value = source.args[3] + else + global:addGet(uri, source) + end + source._globalNode = global + + local nxt = source.next + if nxt then + compileObject(nxt) + end + end + end + end + end) + : case 'doc.class' + ---@param source parser.object + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local class = vm.declareGlobal('type', name, uri) + class:addSet(uri, source) + source._globalNode = class + + if source.signs then + source._sign = vm.createSign() + for _, sign in ipairs(source.signs) do + source._sign:addSign(vm.compileNode(sign)) + end + if source.extends then + for _, ext in ipairs(source.extends) do + if ext.type == 'doc.type.table' then + ext._generic = vm.createGeneric(ext, source._sign) + end + end + end + end + end) + : case 'doc.alias' + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + local alias = vm.declareGlobal('type', name, uri) + alias:addSet(uri, source) + source._globalNode = alias + + if source.signs then + source._sign = vm.createSign() + for _, sign in ipairs(source.signs) do + source._sign:addSign(vm.compileNode(sign)) + end + source.extends._generic = vm.createGeneric(source.extends, source._sign) + end + end) + : case 'doc.type.name' + : call(function (source) + local uri = guide.getUri(source) + local name = source[1] + local type = vm.declareGlobal('type', name, uri) + type:addGet(uri, source) + source._globalNode = type + end) + : case 'doc.extends.name' + : call(function (source) + local uri = guide.getUri(source) + local name = source[1] + local class = vm.declareGlobal('type', name, uri) + class:addGet(uri, source) + source._globalNode = class + end) + + +---@alias vm.global.cate '"variable"' | '"type"' + +---@param cate vm.global.cate +---@param name string +---@param uri? uri +---@return vm.global +function vm.declareGlobal(cate, name, uri) + local key = cate .. '|' .. name + if uri then + globalSubs[uri][key] = true + end + if not allGlobals[key] then + allGlobals[key] = createGlobal(name, cate) + end + return allGlobals[key] +end + +---@param cate vm.global.cate +---@param name string +---@param field? string +---@return vm.global? +function vm.getGlobal(cate, name, field) + local key = cate .. '|' .. name + if field then + key = key .. vm.ID_SPLITE .. field + end + return allGlobals[key] +end + +---@param cate vm.global.cate +---@param name string +---@return vm.global[] +function vm.getGlobalFields(cate, name) + local globals = {} + local key = cate .. '|' .. name + + local clock = os.clock() + for gid, global in pairs(allGlobals) do + if gid ~= key + and util.stringStartWith(gid, key) + and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE + and not gid:find(vm.ID_SPLITE, #key + 2) then + globals[#globals+1] = global + end + end + local cost = os.clock() - clock + if cost > 0.1 then + log.warn('global-manager getFields cost %.3f', cost) + end + + return globals +end + +---@param cate vm.global.cate +---@return vm.global[] +function vm.getGlobals(cate) + local globals = {} + + local clock = os.clock() + for gid, global in pairs(allGlobals) do + if util.stringStartWith(gid, cate) + and not gid:find(vm.ID_SPLITE) then + globals[#globals+1] = global + end + end + local cost = os.clock() - clock + if cost > 0.1 then + log.warn('global-manager getGlobals cost %.3f', cost) + end + + return globals +end + +---@param suri uri +---@param cate vm.global.cate +---@return parser.object[] +function vm.getGlobalSets(suri, cate) + local globals = vm.getGlobals(cate) + local result = {} + for _, global in ipairs(globals) do + local sets = global:getSets(suri) + for _, set in ipairs(sets) do + result[#result+1] = set + end + end + return result +end + +---@param suri uri +---@param cate vm.global.cate +---@param name string +---@return boolean +function vm.hasGlobalSets(suri, cate, name) + local global = vm.getGlobal(cate, name) + if not global then + return false + end + local sets = global:getSets(suri) + if #sets == 0 then + return false + end + return true +end + +---@param source parser.object +function compileObject(source) + if source._globalNode ~= nil then + return + end + source._globalNode = false + compilerGlobalSwitch(source.type, source) +end + +---@param source parser.object +local function compileSelf(source) + if source.parent.type ~= 'funcargs' then + return + end + ---@type parser.object + local node = source.parent.parent and source.parent.parent.parent and source.parent.parent.parent.node + if not node then + return + end + local fields = vm.getLocalFields(source) + if not fields then + return + end + local nodeLocalID = vm.getLocalID(node) + local globalNode = node._globalNode + if not nodeLocalID and not globalNode then + return + end + for _, field in ipairs(fields) do + if field.type == 'setfield' then + local key = guide.getKeyName(field) + if key then + if nodeLocalID then + local myID = nodeLocalID .. vm.ID_SPLITE .. key + vm.insertLocalID(myID, field) + end + if globalNode then + local myID = globalNode:getName() .. vm.ID_SPLITE .. key + local myGlobal = vm.declareGlobal('variable', myID, guide.getUri(node)) + myGlobal:addSet(guide.getUri(node), field) + end + end + end + end +end + +---@param source parser.object +local function compileAst(source) + local env = guide.getENV(source) + if not env then + return + end + compileObject(env) + guide.eachSpecialOf(source, 'rawset', function (src) + compileObject(src.parent) + end) + guide.eachSpecialOf(source, 'rawget', function (src) + compileObject(src.parent) + end) + guide.eachSourceTypes(source.docs, { + 'doc.class', + 'doc.alias', + 'doc.type.name', + 'doc.extends.name', + }, function (src) + compileObject(src) + end) + + --[[ + local mt + function mt:xxx() + self.a = 1 + end + + mt.a --> find this definition + ]] + guide.eachSourceType(source, 'self', function (src) + compileSelf(src) + end) +end + +---@param uri uri +local function dropUri(uri) + local globalSub = globalSubs[uri] + globalSubs[uri] = nil + for key in pairs(globalSub) do + local global = allGlobals[key] + if global then + global:dropUri(uri) + if not global:isAlive() then + allGlobals[key] = nil + end + end + end +end + +for uri in files.eachFile() do + local state = files.getState(uri) + if state then + compileAst(state.ast) + end +end + +files.watch(function (ev, uri) + if ev == 'update' then + dropUri(uri) + local state = files.getState(uri) + if state then + compileAst(state.ast) + end + end + if ev == 'remove' then + dropUri(uri) + end +end) diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 2a64ed52..fabc9828 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -1,11 +1,9 @@ local util = require 'utility' local config = require 'config' local guide = require 'parser.guide' +---@class vm local vm = require 'vm.vm' ----@class vm.infer-manager -local m = {} - ---@class vm.infer ---@field views table<string, boolean> ---@field cachedView? string @@ -21,7 +19,7 @@ mt._hasDocFunction = false mt._isParam = false mt._isLocal = false -m.NULL = setmetatable({}, mt) +vm.NULL = setmetatable({}, mt) local inferSorted = { ['boolean'] = - 100, @@ -52,7 +50,7 @@ local viewNodeSwitch = util.switch() : call(function (source, infer) if source.type == 'table' then if #source == 1 and source[1].type == 'varargs' then - local node = m.getInfer(source[1]):view() + local node = vm.getInfer(source[1]):view() return ('%s[]'):format(node) end end @@ -90,7 +88,7 @@ local viewNodeSwitch = util.switch() if source.signs then local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = m.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view() end return ('%s<%s>'):format(source[1], table.concat(buf, ', ')) else @@ -99,7 +97,7 @@ local viewNodeSwitch = util.switch() end) : case 'generic' : call(function (source, infer) - return m.getInfer(source.proto):view() + return vm.getInfer(source.proto):view() end) : case 'doc.generic.name' : call(function (source, infer) @@ -108,7 +106,7 @@ local viewNodeSwitch = util.switch() : case 'doc.type.array' : call(function (source, infer) infer._hasClass = true - local view = m.getInfer(source.node):view() + local view = vm.getInfer(source.node):view() if source.node.type == 'doc.type' then view = '(' .. view .. ')' end @@ -119,7 +117,7 @@ local viewNodeSwitch = util.switch() infer._hasClass = true local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = m.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view() end return ('%s<%s>'):format(source.node[1], table.concat(buf, ', ')) end) @@ -144,20 +142,23 @@ local viewNodeSwitch = util.switch() local argView = '' local regView = '' for i, arg in ipairs(source.args) do + local argNode = vm.compileNode(arg) + local isOptional = argNode:isOptional() + if isOptional then + argNode = argNode:copy() + argNode:removeOptional() + end args[i] = string.format('%s%s: %s' , arg.name[1] - , arg.optional and '?' or '' - , m.getInfer(arg):view() + , isOptional and '?' or '' + , vm.getInfer(argNode):view() ) end if #args > 0 then argView = table.concat(args, ', ') end for i, ret in ipairs(source.returns) do - rets[i] = string.format('%s%s' - , m.getInfer(ret):view() - , ret.optional and '?' or '' - ) + rets[i] = vm.getInfer(ret):view() end if #rets > 0 then regView = ':' .. table.concat(rets, ', ') @@ -165,16 +166,21 @@ local viewNodeSwitch = util.switch() return ('fun(%s)%s'):format(argView, regView) end) ----@param source parser.object +---@param source parser.object | vm.node ---@return vm.infer -function m.getInfer(source) - local node = vm.compileNode(source) +function vm.getInfer(source) + local node + if source.type == 'vm.node' then + node = source + else + node = vm.compileNode(source) + end if node.lastInfer then return node.lastInfer end local infer = setmetatable({ node = node, - uri = guide.getUri(source), + uri = source.type ~= 'vm.node' and guide.getUri(source), }, mt) node.lastInfer = infer @@ -199,24 +205,24 @@ function mt:_trim() if self._hasTable and not self._hasClass then self.views['table'] = true end - if self._hasClass then - self:_eraseAlias() - end end -function mt:_eraseAlias() - local expandAlias = config.get(self.uri, 'Lua.hover.expandAlias') +---@param uri uri +---@return table<string, true> +function mt:_eraseAlias(uri) + local drop = {} + local expandAlias = config.get(uri, 'Lua.hover.expandAlias') for n in self.node:eachObject() do if n.type == 'global' and n.cate == 'type' then - for _, set in ipairs(n:getSets(self.uri)) do + for _, set in ipairs(n:getSets(uri)) do if set.type == 'doc.alias' then if expandAlias then - self.views[n.name] = nil + drop[n.name] = true else for _, ext in ipairs(set.extends.types) do local view = viewNodeSwitch(ext.type, ext, {}) if view and view ~= n.name then - self.views[view] = nil + drop[view] = true end end end @@ -224,6 +230,7 @@ function mt:_eraseAlias() end end end + return drop end ---@param tp string @@ -273,17 +280,16 @@ function mt:view(default, uri) return 'any' end - if not next(self.views) then - return default or 'unknown' - end - - if self.cachedView then - return self.cachedView + local drop + if self._hasClass then + drop = self:_eraseAlias(uri or self.uri) end local array = {} for view in pairs(self.views) do - array[#array+1] = view + if not drop or not drop[view] then + array[#array+1] = view + end end table.sort(array, function (a, b) @@ -298,22 +304,29 @@ function mt:view(default, uri) local max = #array local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') - if max > limit then - local view = string.format('%s...(+%d)' - , table.concat(array, '|', 1, limit) - , max - limit - ) - - self.cachedView = view - - return view + local view + if #array == 0 then + view = default or 'unknown' else - local view = table.concat(array, '|') - - self.cachedView = view + if max > limit then + view = string.format('%s...(+%d)' + , table.concat(array, '|', 1, limit) + , max - limit + ) + else + view = table.concat(array, '|') + end + end - return view + if self.node:isOptional() then + if max > 1 then + view = '(' .. view .. ')?' + else + view = view .. '?' + end end + + return view end function mt:eachView() @@ -324,10 +337,10 @@ end ---@param other vm.infer ---@return vm.infer function mt:merge(other) - if self == m.NULL then + if self == vm.NULL then return other end - if other == m.NULL then + if other == vm.NULL then return self end @@ -390,8 +403,6 @@ end ---@param source parser.object ---@return string? -function m.viewObject(source) +function vm.viewObject(source) return viewNodeSwitch(source.type, source, {}) end - -return m diff --git a/script/vm/init.lua b/script/vm/init.lua index 0058c698..f5003c11 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -1,4 +1,7 @@ local vm = require 'vm.vm' + +---@alias vm.object parser.object | vm.global | vm.generic + require 'vm.compiler' require 'vm.value' require 'vm.node' @@ -8,5 +11,10 @@ require 'vm.field' require 'vm.doc' require 'vm.type' require 'vm.library' -require 'vm.manager' +require 'vm.runner' +require 'vm.infer' +require 'vm.generic' +require 'vm.sign' +require 'vm.local-id' +require 'vm.global' return vm diff --git a/script/vm/library.lua b/script/vm/library.lua index 49f7adb0..e7bf4f42 100644 --- a/script/vm/library.lua +++ b/script/vm/library.lua @@ -13,24 +13,3 @@ function vm.getLibraryName(source) end return nil end - -local globalLibraryNames = { - 'arg', 'assert', 'error', 'collectgarbage', 'dofile', '_G', 'getfenv', - 'getmetatable', 'ipairs', 'load', 'loadfile', 'loadstring', - 'module', 'next', 'pairs', 'pcall', 'print', 'rawequal', - 'rawget', 'rawlen', 'rawset', 'select', 'setfenv', - 'setmetatable', 'tonumber', 'tostring', 'type', '_VERSION', - 'warn', 'xpcall', 'require', 'unpack', 'bit32', 'coroutine', - 'debug', 'io', 'math', 'os', 'package', 'string', 'table', - 'utf8', 'newproxy', -} -local globalLibraryNamesMap -function vm.isGlobalLibraryName(name) - if not globalLibraryNamesMap then - globalLibraryNamesMap = {} - for _, v in ipairs(globalLibraryNames) do - globalLibraryNamesMap[v] = true - end - end - return globalLibraryNamesMap[name] or false -end diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua index 728de301..80c68769 100644 --- a/script/vm/local-id.lua +++ b/script/vm/local-id.lua @@ -1,13 +1,13 @@ local util = require 'utility' local guide = require 'parser.guide' +---@class vm local vm = require 'vm.vm' ---@class parser.object ---@field _localID string ---@field _localIDs table<string, parser.object[]> ----@class vm.local-id -local m = {} +local compileLocalID, getLocal local compileSwitch = util.switch() : case 'local' @@ -18,13 +18,13 @@ local compileSwitch = util.switch() return end for _, ref in ipairs(source.ref) do - m.compileLocalID(ref) + compileLocalID(ref) end end) : case 'getlocal' : call(function (source) source._localID = ('%d'):format(source.node.start) - m.compileLocalID(source.next) + compileLocalID(source.next) end) : case 'getfield' : case 'setfield' @@ -40,7 +40,7 @@ local compileSwitch = util.switch() source._localID = parentID .. vm.ID_SPLITE .. key source.field._localID = source._localID if source.type == 'getfield' then - m.compileLocalID(source.next) + compileLocalID(source.next) end end) : case 'getmethod' @@ -57,7 +57,7 @@ local compileSwitch = util.switch() source._localID = parentID .. vm.ID_SPLITE .. key source.method._localID = source._localID if source.type == 'getmethod' then - m.compileLocalID(source.next) + compileLocalID(source.next) end end) : case 'getindex' @@ -74,7 +74,7 @@ local compileSwitch = util.switch() source._localID = parentID .. vm.ID_SPLITE .. key source.index._localID = source._localID if source.type == 'setindex' then - m.compileLocalID(source.next) + compileLocalID(source.next) end end) @@ -82,7 +82,7 @@ local leftSwitch = util.switch() : case 'field' : case 'method' : call(function (source) - return m.getLocal(source.parent) + return getLocal(source.parent) end) : case 'getfield' : case 'setfield' @@ -91,24 +91,36 @@ local leftSwitch = util.switch() : case 'getindex' : case 'setindex' : call(function (source) - return m.getLocal(source.node) + return getLocal(source.node) end) : case 'getlocal' : call(function (source) return source.node end) : case 'local' + : case 'self' : call(function (source) return source end) ---@param source parser.object ---@return parser.object? -function m.getLocal(source) +function getLocal(source) return leftSwitch(source.type, source) end -function m.compileLocalID(source) +---@param id string +---@param source parser.object +function vm.insertLocalID(id, source) + local root = guide.getRoot(source) + if not root._localIDs then + root._localIDs = util.multiTable(2) + end + local sources = root._localIDs[id] + sources[#sources+1] = source +end + +function compileLocalID(source) if not source then return end @@ -117,37 +129,33 @@ function m.compileLocalID(source) return end compileSwitch(source.type, source) - if not source._localID then + local id = source._localID + if not id then return end - local root = guide.getRoot(source) - if not root._localIDs then - root._localIDs = util.multiTable(2) - end - local sources = root._localIDs[source._localID] - sources[#sources+1] = source + vm.insertLocalID(id, source) end ---@param source parser.object ----@return string|boolean -function m.getID(source) +---@return string? +function vm.getLocalID(source) if source._localID ~= nil then return source._localID end source._localID = false - local loc = m.getLocal(source) + local loc = getLocal(source) if not loc then return source._localID end - m.compileLocalID(loc) + compileLocalID(loc) return source._localID end ---@param source parser.object ---@param key? string ---@return parser.object[]? -function m.getSources(source, key) - local id = m.getID(source) +function vm.getLocalSources(source, key) + local id = vm.getLocalID(source) if not id then return nil end @@ -166,8 +174,8 @@ end ---@param source parser.object ---@return parser.object[] -function m.getFields(source) - local id = m.getID(source) +function vm.getLocalFields(source) + local id = vm.getLocalID(source) if not id then return nil end @@ -195,5 +203,3 @@ function m.getFields(source) end return fields end - -return m diff --git a/script/vm/local-manager.lua b/script/vm/local-manager.lua deleted file mode 100644 index 51bafb24..00000000 --- a/script/vm/local-manager.lua +++ /dev/null @@ -1,40 +0,0 @@ -local util = require 'utility' -local guide = require 'parser.guide' - ----@class vm.local-node -local m = {} ----@type table<uri, parser.object[]> -m.locals = util.multiTable(2) ----@type table<parser.object, table<parser.object, boolean>> -m.localSubs = util.multiTable(2, function () - return setmetatable({}, util.MODE_K) -end) ----@type table<parser.object, boolean> -m.allLocals = {} - ----@param source parser.object -function m.declareLocal(source) - if m.allLocals[source] then - return - end - m.allLocals[source] = true - local uri = guide.getUri(source) - local locals = m.locals[uri] - locals[#locals+1] = source -end - ----@param uri uri -function m.dropUri(uri) - local locals = m.locals[uri] - m.locals[uri] = nil - for _, loc in ipairs(locals) do - m.allLocals[loc] = nil - local localSubs = m.localSubs[loc] - m.localSubs[loc] = nil - for source in pairs(localSubs) do - source._node = nil - end - end -end - -return m diff --git a/script/vm/manager.lua b/script/vm/manager.lua deleted file mode 100644 index 58255fca..00000000 --- a/script/vm/manager.lua +++ /dev/null @@ -1,26 +0,0 @@ - -local files = require 'files' -local globalManager = require 'vm.global-manager' -local localManager = require 'vm.local-manager' - ----@alias vm.object parser.object | vm.global | vm.generic - ----@class vm.state -local m = {} - -files.watch(function (ev, uri) - if ev == 'update' then - globalManager.dropUri(uri) - localManager.dropUri(uri) - local state = files.getState(uri) - if state then - globalManager.compileAst(state.ast) - end - end - if ev == 'remove' then - globalManager.dropUri(uri) - localManager.dropUri(uri) - end -end) - -return m diff --git a/script/vm/node.lua b/script/vm/node.lua index 6906da7e..e76542aa 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -1,5 +1,4 @@ local files = require 'files' -local localMgr = require 'vm.local-manager' ---@class vm local vm = require 'vm.vm' local ws = require 'workspace.workspace' @@ -8,15 +7,14 @@ local ws = require 'workspace.workspace' vm.nodeCache = {} ---@class vm.node +---@field [integer] vm.object local mt = {} mt.__index = mt +mt.id = 0 mt.type = 'vm.node' mt.optional = nil mt.lastInfer = nil mt.data = nil ----@type vm.node[] -mt._childs = nil -mt._locked = false ---@param node vm.node | vm.object function mt:merge(node) @@ -30,20 +28,10 @@ function mt:merge(node) if node:isOptional() then self.optional = true end - if node._locked then - if not self._childs then - self._childs = {} - end - if not self._childs[node] then - self._childs[#self._childs+1] = node - self._childs[node] = true - end - else - for _, obj in ipairs(node) do - if not self[obj] then - self[obj] = true - self[#self+1] = obj - end + for _, obj in ipairs(node) do + if not self[obj] then + self[obj] = true + self[#self+1] = obj end end else @@ -54,84 +42,25 @@ function mt:merge(node) end end -function mt:_each(mark, callback) - if mark[self] then - return - end - mark[self] = true - for i = 1, #self do - callback(self[i]) - end - local childs = self._childs - if not childs then - return - end - for i = 1, #childs do - local child = childs[i] - if not child:isLocked() then - child:_each(mark, callback) - end - end -end - -function mt:_expand() - local childs = self._childs - if not childs then - return - end - self._childs = nil - - local mark = {} - mark[self] = true - - local function insert(obj) - if not self[obj] then - self[obj] = true - self[#self+1] = obj - end - end - - for i = 1, #childs do - local child = childs[i] - if child:isLocked() then - if not self._childs then - self._childs = {} - end - if not self._childs[child] then - self._childs[#self._childs+1] = child - self._childs[child] = true - end - else - child:_each(mark, insert) - end - end -end - ---@return boolean function mt:isEmpty() - self:_expand() return #self == 0 end +function mt:clear() + self.optional = nil + for i, c in ipairs(self) do + self[i] = nil + self[c] = nil + end +end + ---@param n integer ---@return vm.object? function mt:get(n) - self:_expand() return self[n] end -function mt:lock() - self._locked = true -end - -function mt:unlock() - self._locked = false -end - -function mt:isLocked() - return self._locked == true -end - function mt:setData(k, v) if not self.data then self.data = {} @@ -147,49 +76,143 @@ function mt:getData(k) end function mt:addOptional() - if self:isOptional() then - return self - end self.optional = true end function mt:removeOptional() - if not self:isOptional() then - return self - end - self:_expand() - for i = #self, 1, -1 do - local n = self[i] - if n.type == 'nil' - or (n.type == 'boolean' and n[1] == false) - or (n.type == 'doc.type.boolean' and n[1] == false) then - self[i] = self[#self] - self[#self] = nil - end - end + self:remove 'nil' end ---@return boolean function mt:isOptional() - if self.optional ~= nil then - return self.optional + return self.optional == true +end + +---@return boolean +function mt:hasFalsy() + if self.optional then + return true end - self:_expand() for _, c in ipairs(self) do if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'false') or (c.type == 'boolean' and c[1] == false) or (c.type == 'doc.type.boolean' and c[1] == false) then - self.optional = true return true end end - self.optional = false return false end +---@return boolean +function mt:isNullable() + if self.optional then + return true + end + if #self == 0 then + return true + end + for _, c in ipairs(self) do + if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'any') then + return true + end + end + return false +end + +---@return vm.node +function mt:setTruthy() + if self.optional == true then + self.optional = nil + end + local hasBoolean + for index = #self, 1, -1 do + local c = self[index] + if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'false') + or (c.type == 'boolean' and c[1] == false) + or (c.type == 'doc.type.boolean' and c[1] == false) then + table.remove(self, index) + self[c] = nil + goto CONTINUE + end + if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean') + or (c.type == 'boolean' or c.type == 'doc.type.boolean') then + hasBoolean = true + table.remove(self, index) + self[c] = nil + goto CONTINUE + end + ::CONTINUE:: + end + if hasBoolean then + self[#self+1] = vm.declareGlobal('type', 'true') + end +end + +---@return vm.node +function mt:setFalsy() + if self.optional == false then + self.optional = nil + end + local hasBoolean + for index = #self, 1, -1 do + local c = self[index] + if c.type == 'nil' + or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') + or (c.type == 'global' and c.cate == 'type' and c.name == 'false') + or (c.type == 'boolean' and c[1] == true) + or (c.type == 'doc.type.boolean' and c[1] == true) then + goto CONTINUE + end + if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean') + or (c.type == 'boolean' or c.type == 'doc.type.boolean') then + hasBoolean = true + table.remove(self, index) + self[c] = nil + end + ::CONTINUE:: + end + if hasBoolean then + self[#self+1] = vm.declareGlobal('type', 'false') + end +end + +---@param name string +function mt:remove(name) + if name == 'nil' and self.optional == true then + self.optional = nil + end + for index = #self, 1, -1 do + local c = self[index] + if (c.type == 'global' and c.cate == 'type' and c.name == name) + or (c.type == name) + or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer')) + or (c.type == 'doc.type.boolean' and name == 'boolean') + or (c.type == 'doc.type.table' and name == 'table') + or (c.type == 'doc.type.array' and name == 'table') + or (c.type == 'doc.type.function' and name == 'function') then + table.remove(self, index) + self[c] = nil + end + end +end + +---@param node vm.node +function mt:removeNode(node) + for _, c in ipairs(node) do + if c.type == 'global' and c.cate == 'type' then + self:remove(c.name) + end + end +end + ---@return fun():vm.object function mt:eachObject() - self:_expand() local i = 0 return function () i = i + 1 @@ -197,12 +220,21 @@ function mt:eachObject() end end ----@param source parser.object | vm.generic +---@return vm.node +function mt:copy() + return vm.createNode(self) +end + +---@param source vm.object ---@param node vm.node | vm.object ---@param cover? boolean function vm.setNode(source, node, cover) if not node then - error('Can not set nil node') + if TEST then + error('Can not set nil node') + else + log.error('Can not set nil node') + end end if source.type == 'global' then error('Can not set node to global') @@ -216,13 +248,14 @@ function vm.setNode(source, node, cover) me:merge(node) else if node.type == 'vm.node' then - vm.nodeCache[source] = node + vm.nodeCache[source] = node:copy() else vm.nodeCache[source] = vm.createNode(node) end end end +---@param source vm.object ---@return vm.node? function vm.getNode(source) return vm.nodeCache[source] @@ -256,11 +289,16 @@ function vm.clearNodeCache() vm.nodeCache = {} end +local ID = 0 + ---@param a? vm.node | vm.object ---@param b? vm.node | vm.object ---@return vm.node function vm.createNode(a, b) - local node = setmetatable({}, mt) + ID = ID + 1 + local node = setmetatable({ + id = ID, + }, mt) if a then node:merge(a) end diff --git a/script/vm/ref.lua b/script/vm/ref.lua index 65e8fdab..545c294a 100644 --- a/script/vm/ref.lua +++ b/script/vm/ref.lua @@ -2,8 +2,6 @@ local vm = require 'vm.vm' local util = require 'utility' local guide = require 'parser.guide' -local localID = require 'vm.local-id' -local globalMgr = require 'vm.global-manager' local files = require 'files' local await = require 'await' local progress = require 'progress' @@ -242,7 +240,7 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchByLocalID(source, pushResult) - local idSources = localID.getSources(source) + local idSources = vm.getLocalSources(source) if not idSources then return end @@ -291,7 +289,7 @@ end ---@async ---@param source parser.object ----@param fileNotify fun(uri: uri): boolean +---@param fileNotify? fun(uri: uri): boolean function vm.getRefs(source, fileNotify) local results = {} local mark = {} diff --git a/script/vm/runner.lua b/script/vm/runner.lua new file mode 100644 index 00000000..9fe0f172 --- /dev/null +++ b/script/vm/runner.lua @@ -0,0 +1,444 @@ +---@class vm +local vm = require 'vm.vm' +local guide = require 'parser.guide' + +---@class vm.runner +---@field loc parser.object +---@field mainBlock parser.object +---@field blocks table<parser.object, true> +---@field steps vm.runner.step[] +local mt = {} +mt.__index = mt +mt.index = 1 + +---@class parser.object +---@field _casts parser.object[] + +---@class vm.runner.step +---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast' +---@field pos integer +---@field order? integer +---@field node? vm.node +---@field object? parser.object +---@field name? string +---@field cast? parser.object +---@field tag? string +---@field copy? boolean +---@field new? boolean +---@field ref1? vm.runner.step +---@field ref2? vm.runner.step + +---@param filter parser.object +---@param outStep vm.runner.step +---@param blockStep vm.runner.step +function mt:_compileNarrowByFilter(filter, outStep, blockStep) + if not filter then + return + end + if filter.type == 'paren' then + if filter.exp then + self:_compileNarrowByFilter(filter.exp, outStep, blockStep) + end + return + end + if filter.type == 'unary' then + if not filter.op + or not filter[1] then + return + end + if filter.op.type == 'not' then + local exp = filter[1] + if exp.type == 'getlocal' and exp.node == self.loc then + self.steps[#self.steps+1] = { + type = 'falsy', + pos = filter.finish, + new = true, + } + self.steps[#self.steps+1] = { + type = 'truthy', + pos = filter.finish, + ref1 = outStep, + } + end + end + elseif filter.type == 'binary' then + if not filter.op + or not filter[1] + or not filter[2] then + return + end + if filter.op.type == 'and' then + local dummyStep = { + type = 'save', + copy = true, + ref1 = outStep, + pos = filter.start - 1, + } + self.steps[#self.steps+1] = dummyStep + self:_compileNarrowByFilter(filter[1], dummyStep, blockStep) + self:_compileNarrowByFilter(filter[2], dummyStep, blockStep) + end + if filter.op.type == 'or' then + self:_compileNarrowByFilter(filter[1], outStep, blockStep) + local dummyStep = { + type = 'push', + copy = true, + ref1 = outStep, + pos = filter.op.finish, + } + self.steps[#self.steps+1] = dummyStep + self:_compileNarrowByFilter(filter[2], outStep, dummyStep) + self.steps[#self.steps+1] = { + type = 'push', + tag = 'or reset', + ref1 = blockStep, + pos = filter.finish, + } + end + if filter.op.type == '==' + or filter.op.type == '~=' then + local loc, exp + for i = 1, 2 do + loc = filter[i] + if loc.type == 'getlocal' and loc.node == self.loc then + exp = filter[i % 2 + 1] + break + end + end + if not loc or not exp then + return + end + if guide.isLiteral(exp) then + if filter.op.type == '==' then + self.steps[#self.steps+1] = { + type = 'remove', + name = exp.type, + pos = filter.finish, + ref1 = outStep, + } + self.steps[#self.steps+1] = { + type = 'as', + name = exp.type, + pos = filter.finish, + new = true, + } + end + if filter.op.type == '~=' then + self.steps[#self.steps+1] = { + type = 'as', + name = exp.type, + pos = filter.finish, + ref1 = outStep, + } + self.steps[#self.steps+1] = { + type = 'remove', + name = exp.type, + pos = filter.finish, + new = true, + } + end + end + end + else + if filter.type == 'getlocal' and filter.node == self.loc then + self.steps[#self.steps+1] = { + type = 'truthy', + pos = filter.finish, + new = true, + } + self.steps[#self.steps+1] = { + type = 'falsy', + pos = filter.finish, + ref1 = outStep, + } + end + end +end + +---@param block parser.object +function mt:_compileBlock(block) + if self.blocks[block] then + return + end + self.blocks[block] = true + if block == self.mainBlock then + return + end + + local parentBlock = guide.getParentBlock(block) + self:_compileBlock(parentBlock) + + if block.type == 'if' then + ---@type vm.runner.step[] + local finals = {} + for i, childBlock in ipairs(block) do + local blockStep = { + type = 'save', + tag = 'block', + copy = true, + pos = childBlock.start, + } + local outStep = { + type = 'save', + tag = 'out', + copy = true, + pos = childBlock.start, + } + self.steps[#self.steps+1] = blockStep + self.steps[#self.steps+1] = outStep + self.steps[#self.steps+1] = { + type = 'push', + ref1 = blockStep, + pos = childBlock.start, + } + self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep) + if not childBlock.hasReturn + and not childBlock.hasGoTo + and not childBlock.hasBreak then + local finalStep = { + type = 'save', + pos = childBlock.finish, + tag = 'final #' .. i, + } + finals[#finals+1] = finalStep + self.steps[#self.steps+1] = finalStep + end + self.steps[#self.steps+1] = { + type = 'push', + tag = 'reset child', + ref1 = outStep, + pos = childBlock.finish, + } + end + self.steps[#self.steps+1] = { + type = 'push', + tag = 'reset if', + pos = block.finish, + copy = true, + } + for _, final in ipairs(finals) do + self.steps[#self.steps+1] = { + type = 'merge', + ref2 = final, + pos = block.finish, + } + end + end + + if block.type == 'function' + or block.type == 'while' + or block.type == 'loop' + or block.type == 'in' + or block.type == 'repeat' + or block.type == 'for' then + local savePoint = { + type = 'save', + copy = true, + pos = block.start, + } + self.steps[#self.steps+1] = { + type = 'push', + copy = true, + pos = block.start, + } + self.steps[#self.steps+1] = savePoint + self.steps[#self.steps+1] = { + type = 'push', + pos = block.finish, + ref1 = savePoint, + } + end +end + +---@return parser.object[] +function mt:_getCasts() + local root = guide.getRoot(self.loc) + if not root._casts then + root._casts = {} + local docs = root.docs + for _, doc in ipairs(docs) do + if doc.type == 'doc.cast' and doc.loc then + root._casts[#root._casts+1] = doc + end + end + end + return root._casts +end + +function mt:_preCompile() + local startPos = self.loc.start + local finishPos = 0 + + for _, ref in ipairs(self.loc.ref) do + self.steps[#self.steps+1] = { + type = 'object', + object = ref, + pos = ref.range or ref.start, + } + if ref.start > finishPos then + finishPos = ref.start + end + local block = guide.getParentBlock(ref) + self:_compileBlock(block) + end + + for i, step in ipairs(self.steps) do + if step.type ~= 'object' then + step.order = i + end + end + + local casts = self:_getCasts() + for _, cast in ipairs(casts) do + if cast.loc[1] == self.loc[1] + and cast.start > startPos + and cast.finish < finishPos + and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then + self.steps[#self.steps+1] = { + type = 'cast', + cast = cast, + pos = cast.start, + } + end + end + + table.sort(self.steps, function (a, b) + if a.pos == b.pos then + return (a.order or 0) < (b.order or 0) + else + return a.pos < b.pos + end + end) +end + +---@param loc parser.object +---@param node vm.node +---@return vm.node +local function checkAssert(loc, node) + local parent = loc.parent + if parent.type == 'binary' then + if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then + local exp + for i = 1, 2 do + if parent[i] == loc then + exp = parent[i % 2 + 1] + end + end + if exp and guide.isLiteral(exp) then + local callargs = parent.parent + if callargs.type == 'callargs' + and callargs.parent.node.special == 'assert' + and callargs[1] == parent then + if parent.op.type == '~=' then + node:remove(exp.type) + end + if parent.op.type == '==' then + node = vm.compileNode(exp) + end + end + end + end + end + if parent.type == 'callargs' + and parent.parent.node.special == 'assert' + and parent[1] == loc then + node:setTruthy() + end + return node +end + +---@param callback fun(src: parser.object, node: vm.node) +function mt:launch(callback) + local topNode = vm.getNode(self.loc):copy() + for _, step in ipairs(self.steps) do + local node = step.ref1 and step.ref1.node or topNode + if step.type == 'truthy' then + if step.new then + node = node:copy() + topNode = node + end + node:setTruthy() + elseif step.type == 'falsy' then + if step.new then + node = node:copy() + topNode = node + end + node:setFalsy() + elseif step.type == 'as' then + if step.new then + topNode = vm.createNode(vm.getGlobal('type', step.name)) + else + node:clear() + node:merge(vm.getGlobal('type', step.name)) + end + elseif step.type == 'add' then + if step.new then + node = node:copy() + topNode = node + end + node:merge(vm.getGlobal('type', step.name)) + elseif step.type == 'remove' then + if step.new then + node = node:copy() + topNode = node + end + node:remove(step.name) + elseif step.type == 'object' then + topNode = callback(step.object, node) or node + if step.object.type == 'getlocal' then + topNode = checkAssert(step.object, node) + end + elseif step.type == 'save' then + if step.copy then + node = node:copy() + end + step.node = node + elseif step.type == 'push' then + if step.copy then + node = node:copy() + end + topNode = node + elseif step.type == 'merge' then + node:merge(step.ref2.node) + elseif step.type == 'cast' then + topNode = node:copy() + for _, cast in ipairs(step.cast.casts) do + if cast.mode == '+' then + if cast.optional then + topNode:addOptional() + end + if cast.extends then + topNode:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + topNode:removeOptional() + end + if cast.extends then + topNode:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + topNode:clear() + topNode:merge(vm.compileNode(cast.extends)) + end + end + end + end + end +end + +---@param loc parser.object +---@return vm.runner +function vm.createRunner(loc) + local self = setmetatable({ + loc = loc, + mainBlock = guide.getParentBlock(loc), + blocks = {}, + steps = {}, + }, mt) + + self:_preCompile() + + return self +end diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 2d45a5a7..fe112bc2 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -1,6 +1,6 @@ local guide = require 'parser.guide' +---@class vm local vm = require 'vm.vm' -local infer = require 'vm.infer' ---@class vm.sign ---@field parent parser.object @@ -16,12 +16,12 @@ end ---@param uri uri ---@param args parser.object +---@param removeGeneric true? ---@return table<string, vm.node> -function mt:resolve(uri, args) +function mt:resolve(uri, args, removeGeneric) if not args then return nil end - local globalMgr = require 'vm.global-manager' local resolved = {} ---@param object parser.object @@ -33,7 +33,7 @@ function mt:resolve(uri, args) -- 'number' -> `T` for n in node:eachObject() do if n.type == 'string' then - local type = globalMgr.declareGlobal('type', n[1], guide.getUri(n)) + local type = vm.declareGlobal('type', n[1], guide.getUri(n)) resolved[key] = vm.createNode(type, resolved[key]) end end @@ -48,6 +48,19 @@ function mt:resolve(uri, args) -- number[] -> T[] resolve(object.node, vm.compileNode(n.node)) end + if n.type == 'doc.type.table' then + -- { [integer]: number } -> T[] + local tvalueNode = vm.getTableValue(uri, node, 'integer') + if tvalueNode then + resolve(object.node, tvalueNode) + end + end + if n.type == 'global' and n.cate == 'type' then + -- ---@field [integer]: number -> T[] + vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field) + resolve(object.node, vm.compileNode(field.extends)) + end) + end end end if object.type == 'doc.type.table' then @@ -98,7 +111,7 @@ function mt:resolve(uri, args) goto CONTINUE end end - local view = infer.viewObject(obj) + local view = vm.viewObject(obj) if view then knownTypes[view] = true end @@ -114,10 +127,10 @@ function mt:resolve(uri, args) local function buildArgNode(argNode, knownTypes) local newArgNode = vm.createNode() for n in argNode:eachObject() do - if argNode:isOptional() and vm.isFalsy(n) then + if argNode:hasFalsy() then goto CONTINUE end - local view = infer.viewObject(n) + local view = vm.viewObject(n) if knownTypes[view] then goto CONTINUE end @@ -156,7 +169,7 @@ function mt:resolve(uri, args) end ---@return vm.sign -return function () +function vm.createSign() local genericMgr = setmetatable({ signList = {}, }, mt) diff --git a/script/vm/type.lua b/script/vm/type.lua index fa02d19e..c3264993 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -1,4 +1,3 @@ -local globalMgr = require 'vm.global-manager' ---@class vm local vm = require 'vm.vm' @@ -9,10 +8,10 @@ local vm = require 'vm.vm' ---@return boolean function vm.isSubType(uri, child, parent, mark) if type(parent) == 'string' then - parent = vm.createNode(globalMgr.getGlobal('type', parent)) + parent = vm.createNode(vm.getGlobal('type', parent)) end if type(child) == 'string' then - child = vm.createNode(globalMgr.getGlobal('type', child)) + child = vm.createNode(vm.getGlobal('type', child)) end if not child or not parent then @@ -134,7 +133,7 @@ function vm.getTableKey(uri, tnode, vnode) end end if tn.type == 'doc.type.array' then - result:merge(globalMgr.getGlobal('type', 'integer')) + result:merge(vm.declareGlobal('type', 'integer')) end if tn.type == 'table' then for _, field in ipairs(tn) do @@ -144,10 +143,10 @@ function vm.getTableKey(uri, tnode, vnode) end end if field.type == 'tablefield' then - result:merge(globalMgr.getGlobal('type', 'string')) + result:merge(vm.declareGlobal('type', 'string')) end if field.type == 'tableexp' then - result:merge(globalMgr.getGlobal('type', 'integer')) + result:merge(vm.declareGlobal('type', 'integer')) end end end diff --git a/script/vm/value.lua b/script/vm/value.lua index a784be2a..d29ca9d0 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -17,7 +17,16 @@ function vm.test(source) hasTrue = true end if n[1] == false then - hasTrue = false + hasFalse = true + end + end + if n.type == 'global' and n.cate == 'type' then + if n.name == 'true' then + hasTrue = true + end + if n.name == 'false' + or n.name == 'nil' then + hasFalse = true end end if n.type == 'nil' then @@ -41,28 +50,9 @@ function vm.test(source) end end ----@param source parser.object ----@return boolean -function vm.isFalsy(source) - if source.type == 'nil' then - return true - end - if source.type == 'boolean' - or source.type == 'doc.type.boolean' then - return source[1] == false - end - return false -end - ---@param v vm.object ---@return string? local function getUnique(v) - if v.type == 'local' then - return ('loc:%s@%d'):format(guide.getUri(v), v.start) - end - if v.type == 'global' then - return ('%s:%s'):format(v.cate, v.name) - end if v.type == 'boolean' then if v[1] == nil then return false diff --git a/script/vm/vm.lua b/script/vm/vm.lua index 3c1762bf..8117d311 100644 --- a/script/vm/vm.lua +++ b/script/vm/vm.lua @@ -23,6 +23,7 @@ function m.getSpecial(source) return source.special end +---@return string? function m.getKeyName(source) if not source then return nil diff --git a/script/workspace/loading.lua b/script/workspace/loading.lua index f40c08c6..66e0a3aa 100644 --- a/script/workspace/loading.lua +++ b/script/workspace/loading.lua @@ -65,7 +65,7 @@ function mt:checkMaxPreload(uri) end ---@param uri uri ----@param libraryUri boolean +---@param libraryUri? uri ---@async function mt:loadFile(uri, libraryUri) if files.isLua(uri) then diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua index 91923bb8..33f8784d 100644 --- a/script/workspace/workspace.lua +++ b/script/workspace/workspace.lua @@ -68,9 +68,10 @@ local globInteferFace = { type = function (path) local result pcall(function () - if fs.is_directory(fs.path(path)) then + local status = fs.symlink_status(path):type() + if status == 'directory' then result = 'directory' - else + elseif status == 'regular' then result = 'file' end end) @@ -78,7 +79,7 @@ local globInteferFace = { end, list = function (path) local fullPath = fs.path(path) - if not fs.exists(fullPath) then + if fs.symlink_status(fullPath):type() ~= 'directory' then return nil end local paths = {} @@ -332,6 +333,8 @@ function m.findUrisByFilePath(path) return results end +---@param path string +---@return string function m.normalize(path) if not path then return nil |