diff options
author | CppCXY <812125110@qq.com> | 2022-08-11 19:36:36 +0800 |
---|---|---|
committer | CppCXY <812125110@qq.com> | 2022-08-11 19:36:36 +0800 |
commit | ff9103ae4001d8e520171b99cd192997fc689bc9 (patch) | |
tree | 04c0b685e81aac48210604dc12d24b91862a36d9 /script | |
parent | 40f191a85ea21bb64c427f9dab4bc597e2a0ea1b (diff) | |
parent | 82bcfef9037c26681993c94b2f92b68d335de3c6 (diff) | |
download | lua-language-server-ff9103ae4001d8e520171b99cd192997fc689bc9.zip |
Merge branch 'master' of github.com:CppCXY/lua-language-server
Diffstat (limited to 'script')
139 files changed, 11313 insertions, 10249 deletions
diff --git a/script/SDBMHash.lua b/script/SDBMHash.lua new file mode 100644 index 00000000..48728aec --- /dev/null +++ b/script/SDBMHash.lua @@ -0,0 +1,60 @@ +local byte = string.byte +local max = 0x7fffffff + +---@class SDBMHash +local mt = {} +mt.__index = mt + +mt.cache = nil + +---@param str string +---@return integer +function mt:rawHash(str) + local id = 0 + for i = 1, #str do + local b = byte(str, i, i) + id = id * 65599 + b + end + return id & max +end + +---@param str string +---@return integer +function mt:hash(str) + local id = self:rawHash(str) + local other = self.cache[id] + if other == nil or str == other then + self.cache[id] = str + self.cache[str] = id + return id + else + log.warn(('哈希碰撞:[%s] -> [%s]: [%d]'):format(str, other, id)) + for i = 1, max do + local newId = (id + i) % max + if not self.cache[newId] then + self.cache[newId] = str + self.cache[str] = newId + return newId + end + end + error(('哈希碰撞解决失败:[%s] -> [%s]: [%d]'):format(str, other, id)) + end +end + +function mt:setCache(t) + self.cache = t +end + +function mt:getCache() + return self.cache +end + +mt.__call = mt.hash + +---@return SDBMHash +return function () + local self = setmetatable({ + cache = {} + }, mt) + return self +end diff --git a/script/await.lua b/script/await.lua index 4fb81cd8..fa2aea13 100644 --- a/script/await.lua +++ b/script/await.lua @@ -132,9 +132,6 @@ end ---@param callback function ---@async function m.wait(callback, ...) - if not coroutine.isyieldable() then - return - end local co = coroutine.running() local resumed callback(function (...) diff --git a/script/cli/check.lua b/script/cli/check.lua index dd2e7737..4df94c59 100644 --- a/script/cli/check.lua +++ b/script/cli/check.lua @@ -50,14 +50,14 @@ lclient():start(function (client) ws.awaitReady(rootUri) - local disables = config.get(rootUri, 'Lua.diagnostics.disable') + local disables = util.arrayToHash(config.get(rootUri, 'Lua.diagnostics.disable')) for name, serverity in pairs(define.DiagnosticDefaultSeverity) do serverity = config.get(rootUri, 'Lua.diagnostics.severity')[name] or 'Warning' if define.DiagnosticSeverity[serverity] > checkLevel then disables[name] = true end end - config.set(nil, 'Lua.diagnostics.disable', disables) + config.set(nil, 'Lua.diagnostics.disable', util.getTableKeys(disables, true)) local uris = files.getAllUris(rootUri) local max = #uris diff --git a/script/client.lua b/script/client.lua index 2ce29803..7432e60b 100644 --- a/script/client.lua +++ b/script/client.lua @@ -12,6 +12,7 @@ local scope = require 'workspace.scope' local inspect = require 'inspect' local m = {} +m._eventList = {} function m.client(newClient) if newClient then @@ -112,7 +113,7 @@ end ---@param type message.type ---@param message string ---@param titles string[] ----@param callback fun(action: string, index: integer) +---@param callback fun(action?: string, index?: integer) function m.requestMessage(type, message, titles, callback) proto.notify('window/logMessage', { type = define.MessageType[type] or 3, @@ -239,6 +240,9 @@ local function tryModifySpecifiedConfig(uri, finalChanges) return false end local path = workspace.getAbsolutePath(uri, CONFIGPATH) + if not path then + return false + end util.saveFile(path, json.beautify(scp:get('lastLocalConfig'), { indent = ' ' })) return true end @@ -252,7 +256,10 @@ local function tryModifyRC(uri, finalChanges, create) if not path then return false end - path = fs.exists(path) and path or workspace.getAbsolutePath(uri, '.luarc.json') + path = fs.exists(fs.path(path)) and path or workspace.getAbsolutePath(uri, '.luarc.json') + if not path then + return false + end local buf = util.loadFile(path) if not buf and not create then return false @@ -389,8 +396,22 @@ function m.editText(uri, edits) }) end +---@param callback async fun() +function m.event(callback) + m._eventList[#m._eventList+1] = callback +end + +function m._callEvent(ev) + for _, callback in ipairs(m._eventList) do + await.call(function () + callback(ev) + end) + end +end + function m.setReady() m._ready = true + m._callEvent('ready') end function m.isReady() @@ -415,6 +436,7 @@ function m.init(t) lang(LOCALE or t.locale) converter.setOffsetEncoding(m.getOffsetEncoding()) hookPrint() + m._callEvent('init') end return m diff --git a/script/config/config.lua b/script/config/config.lua index bde214b0..70e83fc8 100644 --- a/script/config/config.lua +++ b/script/config/config.lua @@ -1,224 +1,10 @@ -local util = require 'utility' -local define = require 'proto.define' -local timer = require 'timer' -local scope = require 'workspace.scope' +local util = require 'utility' +local timer = require 'timer' +local scope = require 'workspace.scope' +local template = require 'config.template' ---@alias config.source '"client"'|'"path"'|'"local"' ----@class config.unit ----@field caller function -local mt = {} -mt.__index = mt - -function mt:__call(...) - self:caller(...) - return self -end - -function mt:__shr(default) - self.default = default - return self -end - -local units = {} - -local function register(name, default, checker, loader, caller) - units[name] = { - default = default, - checker = checker, - loader = loader, - caller = caller, - } -end - -local Type = setmetatable({}, { __index = function (_, name) - local unit = {} - for k, v in pairs(units[name]) do - unit[k] = v - end - return setmetatable(unit, mt) -end }) - -register('Boolean', false, function (self, v) - return type(v) == 'boolean' -end, function (self, v) - return v -end) - -register('Integer', 0, function (self, v) - return type(v) == 'number' -end, function (self, v) - return math.floor(v) -end) - -register('String', '', function (self, v) - return type(v) == 'string' -end, function (self, v) - return tostring(v) -end) - -register('Nil', nil, function (self, v) - return type(v) == 'nil' -end, function (self, v) - return nil -end) - -register('Array', {}, function (self, value) - return type(value) == 'table' -end, function (self, value) - local t = {} - for _, v in ipairs(value) do - if self.sub:checker(v) then - t[#t+1] = self.sub:loader(v) - end - end - return t -end, function (self, sub) - self.sub = sub -end) - -register('Hash', {}, function (self, value) - if type(value) == 'table' then - if #value == 0 then - for k, v in pairs(value) do - if not self.subkey:checker(k) - or not self.subvalue:checker(v) then - return false - end - end - else - if not self.subvalue:checker(true) then - return false - end - for _, v in ipairs(value) do - if not self.subkey:checker(v) then - return false - end - end - end - return true - end - if type(value) == 'string' then - return self.subkey:checker('') - and self.subvalue:checker(true) - end -end, function (self, value) - if type(value) == 'table' then - local t = {} - if #value == 0 then - for k, v in pairs(value) do - t[k] = v - end - else - for _, k in pairs(value) do - t[k] = true - end - end - return t - end - if type(value) == 'string' then - local t = {} - for s in value:gmatch('[^' .. self.sep .. ']+') do - t[s] = true - end - return t - end -end, function (self, subkey, subvalue, sep) - self.subkey = subkey - self.subvalue = subvalue - self.sep = sep -end) - -register('Or', nil, function (self, value) - for _, sub in ipairs(self.subs) do - if sub:checker(value) then - return true - end - end - return false -end, function (self, value) - for _, sub in ipairs(self.subs) do - if sub:checker(value) then - return sub:loader(value) - end - end -end, function (self, ...) - self.subs = { ... } -end) - -local Template = { - ['Lua.runtime.version'] = Type.String >> 'Lua 5.4', - ['Lua.runtime.path'] = Type.Array(Type.String) >> { - "?.lua", - "?/init.lua", - }, - ['Lua.runtime.pathStrict'] = Type.Boolean >> false, - ['Lua.runtime.special'] = Type.Hash(Type.String, Type.String), - ['Lua.runtime.meta'] = Type.String >> '${version} ${language} ${encoding}', - ['Lua.runtime.unicodeName'] = Type.Boolean, - ['Lua.runtime.nonstandardSymbol'] = Type.Hash(Type.String, Type.Boolean, ';'), - ['Lua.runtime.plugin'] = Type.String, - ['Lua.runtime.fileEncoding'] = Type.String >> 'utf8', - ['Lua.runtime.builtin'] = Type.Hash(Type.String, Type.String), - ['Lua.diagnostics.enable'] = Type.Boolean >> true, - ['Lua.diagnostics.globals'] = Type.Hash(Type.String, Type.Boolean, ';'), - ['Lua.diagnostics.disable'] = Type.Hash(Type.String, Type.Boolean, ';'), - ['Lua.diagnostics.severity'] = Type.Hash(Type.String, Type.String) - >> util.deepCopy(define.DiagnosticDefaultSeverity), - ['Lua.diagnostics.neededFileStatus'] = Type.Hash(Type.String, Type.String) - >> util.deepCopy(define.DiagnosticDefaultNeededFileStatus), - ['Lua.diagnostics.workspaceDelay'] = Type.Integer >> 5, - ['Lua.diagnostics.workspaceRate'] = Type.Integer >> 100, - ['Lua.diagnostics.libraryFiles'] = Type.String >> 'Opened', - ['Lua.diagnostics.ignoredFiles'] = Type.String >> 'Opened', - ['Lua.workspace.ignoreDir'] = Type.Array(Type.String), - ['Lua.workspace.ignoreSubmodules'] = Type.Boolean >> true, - ['Lua.workspace.useGitIgnore'] = Type.Boolean >> true, - ['Lua.workspace.maxPreload'] = Type.Integer >> 3000, - ['Lua.workspace.preloadFileSize'] = Type.Integer >> 500, - ['Lua.workspace.library'] = Type.Hash(Type.String, Type.Boolean, ';'), - ['Lua.workspace.checkThirdParty'] = Type.Boolean >> true, - ['Lua.workspace.userThirdParty'] = Type.Array(Type.String), - ['Lua.completion.enable'] = Type.Boolean >> true, - ['Lua.completion.callSnippet'] = Type.String >> 'Disable', - ['Lua.completion.keywordSnippet'] = Type.String >> 'Replace', - ['Lua.completion.displayContext'] = Type.Integer >> 0, - ['Lua.completion.workspaceWord'] = Type.Boolean >> true, - ['Lua.completion.showWord'] = Type.String >> 'Fallback', - ['Lua.completion.autoRequire'] = Type.Boolean >> true, - ['Lua.completion.showParams'] = Type.Boolean >> true, - ['Lua.completion.requireSeparator'] = Type.String >> '.', - ['Lua.completion.postfix'] = Type.String >> '@', - ['Lua.signatureHelp.enable'] = Type.Boolean >> true, - ['Lua.hover.enable'] = Type.Boolean >> true, - ['Lua.hover.viewString'] = Type.Boolean >> true, - ['Lua.hover.viewStringMax'] = Type.Integer >> 1000, - ['Lua.hover.viewNumber'] = Type.Boolean >> true, - ['Lua.hover.previewFields'] = Type.Integer >> 20, - ['Lua.hover.enumsLimit'] = Type.Integer >> 5, - ['Lua.hover.expandAlias'] = Type.Boolean >> true, - ['Lua.semantic.enable'] = Type.Boolean >> true, - ['Lua.semantic.variable'] = Type.Boolean >> true, - ['Lua.semantic.annotation'] = Type.Boolean >> true, - ['Lua.semantic.keyword'] = Type.Boolean >> false, - ['Lua.hint.enable'] = Type.Boolean >> false, - ['Lua.hint.paramType'] = Type.Boolean >> true, - ['Lua.hint.setType'] = Type.Boolean >> false, - ['Lua.hint.paramName'] = Type.String >> 'All', - ['Lua.hint.await'] = Type.Boolean >> true, - ['Lua.hint.arrayIndex'] = Type.Boolean >> 'Auto', - ['Lua.window.statusBar'] = Type.Boolean >> true, - ['Lua.window.progressBar'] = Type.Boolean >> true, - ['Lua.format.enable'] = Type.Boolean >> true, - ['Lua.format.defaultConfig'] = Type.Hash(Type.String, Type.String) - >> {}, - ['Lua.telemetry.enable'] = Type.Or(Type.Boolean >> false, Type.Nil) >> nil, - ['files.associations'] = Type.Hash(Type.String, Type.String), - ['files.exclude'] = Type.Hash(Type.String, Type.Boolean), - ['editor.semanticHighlighting.enabled'] = Type.Or(Type.Boolean, Type.String), - ['editor.acceptSuggestionOnEnter'] = Type.String >> 'on', -} - ---@class config.api local m = {} m.watchList = {} @@ -241,7 +27,7 @@ local function update(scp, key, nowValue, rawValue) raw[key] = rawValue end ----@param uri uri +---@param uri? uri ---@param key? string ---@return scope local function getScope(uri, key) @@ -252,7 +38,7 @@ local function getScope(uri, key) end end if uri then - ---@type scope + ---@type scope? local scp = scope.getFolder(uri) or scope.getLinkedScope(uri) if scp then if not key @@ -268,7 +54,7 @@ end ---@param key string ---@param value any function m.setByScope(scp, key, value) - local unit = Template[key] + local unit = template[key] if not unit then return false end @@ -284,10 +70,12 @@ function m.setByScope(scp, key, value) return true end ----@param uri uri +---@param uri? uri ---@param key string ---@param value any function m.set(uri, key, value) + local unit = template[key] + assert(unit, 'unknown key: ' .. key) local scp = getScope(uri) local oldValue = m.get(uri, key) m.setByScope(scp, key, value) @@ -300,14 +88,10 @@ function m.set(uri, key, value) end function m.add(uri, key, value) - local unit = Template[key] - if not unit then - return false - end + local unit = template[key] + assert(unit, 'unknown key: ' .. key) local list = m.getRaw(uri, key) - if type(list) ~= 'table' then - return false - end + assert(type(list) == 'table', 'not a list: ' .. key) local copyed = {} for i, v in ipairs(list) do if util.equal(v, value) then @@ -326,15 +110,32 @@ function m.add(uri, key, value) return false end -function m.prop(uri, key, prop, value) - local unit = Template[key] - if not unit then - return false +function m.remove(uri, key, value) + local unit = template[key] + assert(unit, 'unknown key: ' .. key) + local list = m.getRaw(uri, key) + assert(type(list) == 'table', 'not a list: ' .. key) + local copyed = {} + for i, v in ipairs(list) do + if not util.equal(v, value) then + copyed[i] = v + end end - local map = m.getRaw(uri, key) - if type(map) ~= 'table' then - return false + local oldValue = m.get(uri, key) + m.set(uri, key, copyed) + local newValue = m.get(uri, key) + if not util.equal(oldValue, newValue) then + m.event(uri, key, newValue, oldValue) + return true end + return false +end + +function m.prop(uri, key, prop, value) + local unit = template[key] + assert(unit, 'unknown key: ' .. key) + local map = m.getRaw(uri, key) + assert(type(map) == 'table', 'not a map: ' .. key) if util.equal(map[prop], value) then return false end @@ -353,14 +154,14 @@ function m.prop(uri, key, prop, value) return false end ----@param uri uri +---@param uri? uri ---@param key string ---@return any function m.get(uri, key) local scp = getScope(uri, key) local value = m.getNowTable(scp)[key] if value == nil then - value = Template[key].default + value = template[key].default end if value == m.NULL then value = nil @@ -375,7 +176,7 @@ function m.getRaw(uri, key) local scp = getScope(uri, key) local value = m.getRawTable(scp)[key] if value == nil then - value = Template[key].default + value = template[key].default end if value == m.NULL then value = nil @@ -412,9 +213,9 @@ function m.update(scp, ...) if m.nullSymbols[value] then value = m.NULL end - if Template[fullKey] then + if template[fullKey] then m.setByScope(scp, fullKey, value) - elseif Template['Lua.' .. fullKey] then + elseif template['Lua.' .. fullKey] then m.setByScope(scp, 'Lua.' .. fullKey, value) elseif type(value) == 'table' then expand(value, fullKey) @@ -424,7 +225,7 @@ function m.update(scp, ...) local news = table.pack(...) for i = 1, news.n do - if news[i] then + if type(news[i]) == 'table' then expand(news[i]) end end diff --git a/script/config/loader.lua b/script/config/loader.lua index 30711dde..5cc7139f 100644 --- a/script/config/loader.lua +++ b/script/config/loader.lua @@ -17,6 +17,7 @@ end ---@class config.loader local m = {} +---@return table? function m.loadRCConfig(uri, filename) local scp = scope.getScope(uri) local path = workspace.getAbsolutePath(uri, filename) @@ -34,11 +35,16 @@ function m.loadRCConfig(uri, filename) errorMessage(lang.script('CONFIG_LOAD_ERROR', res)) return scp:get('lastRCConfig') end + ---@cast res table scp:set('lastRCConfig', res) return res end +---@return table? function m.loadLocalConfig(uri, filename) + if not filename then + return nil + end local scp = scope.getScope(uri) local path = workspace.getAbsolutePath(uri, filename) if not path then @@ -60,6 +66,7 @@ function m.loadLocalConfig(uri, filename) errorMessage(lang.script('CONFIG_LOAD_ERROR', res)) return scp:get('lastLocalConfig') end + ---@cast res table scp:set('lastLocalConfig', res) scp:set('lastLocalType', 'json') return res @@ -79,7 +86,7 @@ end ---@async ---@param uri? uri ----@return table +---@return table? function m.loadClientConfig(uri) local configs = proto.awaitRequest('workspace/configuration', { items = { diff --git a/script/config/template.lua b/script/config/template.lua new file mode 100644 index 00000000..60f3dbca --- /dev/null +++ b/script/config/template.lua @@ -0,0 +1,388 @@ +local util = require 'utility' +local define = require 'proto.define' +local diag = require 'proto.diagnostic' + +---@class config.unit +---@field caller function +---@field _checker fun(self: config.unit, value: any): boolean +---@field name string +---@field [string] config.unit +---@operator shl: config.unit +---@operator shr: config.unit +---@operator call: config.unit +local mt = {} +mt.__index = mt + +function mt:__call(...) + self:caller(...) + return self +end + +function mt:__shr(default) + self.default = default + self.hasDefault = true + return self +end + +function mt:__shl(enums) + self.enums = enums + return self +end + +function mt:checker(v) + if self.enums then + local ok + for _, enum in ipairs(self.enums) do + if util.equal(enum, v) then + ok = true + break + end + end + if not ok then + return false + end + end + return self:_checker(v) +end + +local units = {} + +local function register(name, default, checker, loader, caller) + units[name] = { + name = name, + default = default, + _checker = checker, + loader = loader, + caller = caller, + } +end + +---@type config.unit +local Type = setmetatable({}, { __index = function (_, name) + local unit = {} + for k, v in pairs(units[name]) do + unit[k] = v + end + return setmetatable(unit, mt) +end }) + +register('Boolean', false, function (self, v) + return type(v) == 'boolean' +end, function (self, v) + return v +end) + +register('Integer', 0, function (self, v) + return type(v) == 'number' +end, function (self, v) + return math.floor(v) +end) + +register('String', '', function (self, v) + return type(v) == 'string' +end, function (self, v) + return tostring(v) +end) + +register('Nil', nil, function (self, v) + return type(v) == 'nil' +end, function (self, v) + return nil +end) + +register('Array', {}, function (self, value) + return type(value) == 'table' +end, function (self, value) + local t = {} + if #value == 0 then + for k in pairs(value) do + if self.sub:checker(k) then + t[#t+1] = self.sub:loader(k) + end + end + else + for _, v in ipairs(value) do + if self.sub:checker(v) then + t[#t+1] = self.sub:loader(v) + end + end + end + return t +end, function (self, sub) + self.sub = sub +end) + +register('Hash', {}, function (self, value) + if type(value) == 'table' then + if #value == 0 then + for k, v in pairs(value) do + if not self.subkey:checker(k) + or not self.subvalue:checker(v) then + return false + end + end + else + if not self.subvalue:checker(true) then + return false + end + for _, v in ipairs(value) do + if not self.subkey:checker(v) then + return false + end + end + end + return true + end + if type(value) == 'string' then + return self.subkey:checker('') + and self.subvalue:checker(true) + end +end, function (self, value) + if type(value) == 'table' then + local t = {} + if #value == 0 then + for k, v in pairs(value) do + t[k] = v + end + else + for _, k in pairs(value) do + t[k] = true + end + end + return t + end + if type(value) == 'string' then + local t = {} + for s in value:gmatch('[^' .. self.sep .. ']+') do + t[s] = true + end + return t + end +end, function (self, subkey, subvalue, sep) + self.subkey = subkey + self.subvalue = subvalue + self.sep = sep +end) + +register('Or', nil, function (self, value) + for _, sub in ipairs(self.subs) do + if sub:checker(value) then + return true + end + end + return false +end, function (self, value) + for _, sub in ipairs(self.subs) do + if sub:checker(value) then + return sub:loader(value) + end + end +end, function (self, ...) + self.subs = { ... } +end) + +local template = { + ['Lua.runtime.version'] = Type.String >> 'Lua 5.4' << { + 'Lua 5.1', + 'Lua 5.2', + 'Lua 5.3', + 'Lua 5.4', + 'LuaJIT', + }, + ['Lua.runtime.path'] = Type.Array(Type.String) >> { + "?.lua", + "?/init.lua", + }, + ['Lua.runtime.pathStrict'] = Type.Boolean >> false, + ['Lua.runtime.special'] = Type.Hash( + Type.String, + Type.String >> 'require' << { + '_G', + 'rawset', + 'rawget', + 'setmetatable', + 'require', + 'dofile', + 'loadfile', + 'pcall', + 'xpcall', + 'assert', + 'error', + 'type', + } + ), + ['Lua.runtime.meta'] = Type.String >> '${version} ${language} ${encoding}', + ['Lua.runtime.unicodeName'] = Type.Boolean, + ['Lua.runtime.nonstandardSymbol'] = Type.Array(Type.String << { + '//', '/**/', + '`', + '+=', '-=', '*=', '/=', '%=', '^=', '//=', + '|=', '&=', '<<=', '>>=', + '||', '&&', '!', '!=', + 'continue', + }), + ['Lua.runtime.plugin'] = Type.String, + ['Lua.runtime.pluginArgs'] = Type.Array(Type.String), + ['Lua.runtime.fileEncoding'] = Type.String >> 'utf8' << { + 'utf8', + 'ansi', + 'utf16le', + 'utf16be', + }, + ['Lua.runtime.builtin'] = Type.Hash( + Type.String << util.getTableKeys(define.BuiltIn, true), + Type.String >> 'default' << { + 'default', + 'enable', + 'disable', + } + ) + >> util.deepCopy(define.BuiltIn), + ['Lua.diagnostics.enable'] = Type.Boolean >> true, + ['Lua.diagnostics.globals'] = Type.Array(Type.String), + ['Lua.diagnostics.disable'] = Type.Array(Type.String << util.getTableKeys(diag.getDiagAndErrNameMap(), true)), + ['Lua.diagnostics.severity'] = Type.Hash( + Type.String << util.getTableKeys(define.DiagnosticDefaultNeededFileStatus, true), + Type.String << { + 'Error', + 'Warning', + 'Information', + 'Hint', + 'Error!', + 'Warning!', + 'Information!', + 'Hint!', + } + ) + >> util.deepCopy(define.DiagnosticDefaultSeverity), + ['Lua.diagnostics.neededFileStatus'] = Type.Hash( + Type.String << util.getTableKeys(define.DiagnosticDefaultNeededFileStatus, true), + Type.String << { + 'Any', + 'Opened', + 'None', + 'Any!', + 'Opened!', + 'None!', + } + ) + >> util.deepCopy(define.DiagnosticDefaultNeededFileStatus), + ['Lua.diagnostics.groupSeverity'] = Type.Hash( + Type.String << util.getTableKeys(define.DiagnosticDefaultGroupSeverity, true), + Type.String << { + 'Error', + 'Warning', + 'Information', + 'Hint', + 'Fallback', + } + ) + >> util.deepCopy(define.DiagnosticDefaultGroupSeverity), + ['Lua.diagnostics.groupFileStatus'] = Type.Hash( + Type.String << util.getTableKeys(define.DiagnosticDefaultGroupFileStatus, true), + Type.String << { + 'Any', + 'Opened', + 'None', + 'Fallback', + } + ) + >> util.deepCopy(define.DiagnosticDefaultGroupFileStatus), + ['Lua.diagnostics.disableScheme'] = Type.Array(Type.String) >> { 'git' }, + ['Lua.diagnostics.workspaceDelay'] = Type.Integer >> 3000, + ['Lua.diagnostics.workspaceRate'] = Type.Integer >> 100, + ['Lua.diagnostics.libraryFiles'] = Type.String >> 'Opened' << { + 'Enable', + 'Opened', + 'Disable', + }, + ['Lua.diagnostics.ignoredFiles'] = Type.String >> 'Opened' << { + 'Enable', + 'Opened', + 'Disable', + }, + ['Lua.diagnostics.unusedLocalExclude'] = Type.Array(Type.String), + ['Lua.workspace.ignoreDir'] = Type.Array(Type.String) >> { + '.vscode', + }, + ['Lua.workspace.ignoreSubmodules'] = Type.Boolean >> true, + ['Lua.workspace.useGitIgnore'] = Type.Boolean >> true, + ['Lua.workspace.maxPreload'] = Type.Integer >> 5000, + ['Lua.workspace.preloadFileSize'] = Type.Integer >> 500, + ['Lua.workspace.library'] = Type.Array(Type.String), + ['Lua.workspace.checkThirdParty'] = Type.Boolean >> true, + ['Lua.workspace.userThirdParty'] = Type.Array(Type.String), + ['Lua.workspace.supportScheme'] = Type.Array(Type.String) >> { 'file', 'untitled', 'git' }, + ['Lua.completion.enable'] = Type.Boolean >> true, + ['Lua.completion.callSnippet'] = Type.String >> 'Disable' << { + 'Disable', + 'Both', + 'Replace', + }, + ['Lua.completion.keywordSnippet'] = Type.String >> 'Replace' << { + 'Disable', + 'Both', + 'Replace', + }, + ['Lua.completion.displayContext'] = Type.Integer >> 0, + ['Lua.completion.workspaceWord'] = Type.Boolean >> true, + ['Lua.completion.showWord'] = Type.String >> 'Fallback' << { + 'Enable', + 'Fallback', + 'Disable', + }, + ['Lua.completion.autoRequire'] = Type.Boolean >> true, + ['Lua.completion.showParams'] = Type.Boolean >> true, + ['Lua.completion.requireSeparator'] = Type.String >> '.', + ['Lua.completion.postfix'] = Type.String >> '@', + ['Lua.signatureHelp.enable'] = Type.Boolean >> true, + ['Lua.hover.enable'] = Type.Boolean >> true, + ['Lua.hover.viewString'] = Type.Boolean >> true, + ['Lua.hover.viewStringMax'] = Type.Integer >> 1000, + ['Lua.hover.viewNumber'] = Type.Boolean >> true, + ['Lua.hover.previewFields'] = Type.Integer >> 50, + ['Lua.hover.enumsLimit'] = Type.Integer >> 5, + ['Lua.hover.expandAlias'] = Type.Boolean >> true, + ['Lua.semantic.enable'] = Type.Boolean >> true, + ['Lua.semantic.variable'] = Type.Boolean >> true, + ['Lua.semantic.annotation'] = Type.Boolean >> true, + ['Lua.semantic.keyword'] = Type.Boolean >> false, + ['Lua.hint.enable'] = Type.Boolean >> false, + ['Lua.hint.paramType'] = Type.Boolean >> true, + ['Lua.hint.setType'] = Type.Boolean >> false, + ['Lua.hint.paramName'] = Type.String >> 'All' << { + 'All', + 'Literal', + 'Disable', + }, + ['Lua.hint.await'] = Type.Boolean >> true, + ['Lua.hint.arrayIndex'] = Type.String >> 'Auto' << { + 'Enable', + 'Auto', + 'Disable', + }, + ['Lua.hint.semicolon'] = Type.String >> 'SameLine' << { + 'All', + 'SameLine', + 'Disable', + }, + ['Lua.window.statusBar'] = Type.Boolean >> true, + ['Lua.window.progressBar'] = Type.Boolean >> true, + ['Lua.format.enable'] = Type.Boolean >> true, + ['Lua.format.defaultConfig'] = Type.Hash(Type.String, Type.String) + >> {}, + ['Lua.spell.dict'] = Type.Array(Type.String), + ['Lua.telemetry.enable'] = Type.Or(Type.Boolean >> false, Type.Nil) >> nil, + ['Lua.misc.parameters'] = Type.Array(Type.String), + ['Lua.type.castNumberToInteger'] = Type.Boolean >> true, + ['Lua.type.weakUnionCheck'] = Type.Boolean >> false, + ['Lua.type.weakNilCheck'] = Type.Boolean >> false, + + -- VSCode + ['files.associations'] = Type.Hash(Type.String, Type.String), + ['files.exclude'] = Type.Hash(Type.String, Type.Boolean), + ['editor.semanticHighlighting.enabled'] = Type.Or(Type.Boolean, Type.String), + ['editor.acceptSuggestionOnEnter'] = Type.String >> 'on', +} + +return template diff --git a/script/core/code-action.lua b/script/core/code-action.lua index 6bba0a82..4eb21ff8 100644 --- a/script/core/code-action.lua +++ b/script/core/code-action.lua @@ -5,11 +5,18 @@ local sp = require 'bee.subprocess' local guide = require "parser.guide" local converter = require 'proto.converter' +---@param uri uri +---@param row integer +---@param mode string +---@param code string local function checkDisableByLuaDocExits(uri, row, mode, code) if row < 0 then return nil end local state = files.getState(uri) + if not state then + return nil + end local lines = state.lines if state.ast.docs and lines then return guide.eachSourceBetween( @@ -124,9 +131,12 @@ local function changeVersion(uri, version, results) end local function solveUndefinedGlobal(uri, diag, results) - local ast = files.getState(uri) - local start = converter.unpackRange(uri, diag.range) - guide.eachSourceContain(ast.ast, start, function (source) + local state = files.getState(uri) + if not state then + return + end + local start = converter.unpackRange(uri, diag.range) + guide.eachSourceContain(state.ast, start, function (source) if source.type ~= 'getglobal' then return end @@ -143,9 +153,12 @@ local function solveUndefinedGlobal(uri, diag, results) end local function solveLowercaseGlobal(uri, diag, results) - local ast = files.getState(uri) - local start = converter.unpackRange(uri, diag.range) - guide.eachSourceContain(ast.ast, start, function (source) + local state = files.getState(uri) + if not state then + return + end + local start = converter.unpackRange(uri, diag.range) + guide.eachSourceContain(state.ast, start, function (source) if source.type ~= 'setglobal' then return end @@ -156,8 +169,11 @@ local function solveLowercaseGlobal(uri, diag, results) end local function findSyntax(uri, diag) - local ast = files.getState(uri) - for _, err in ipairs(ast.errs) do + local state = files.getState(uri) + if not state then + return + end + for _, err in ipairs(state.errs) do if err.type:lower():gsub('_', '-') == diag.code then local range = converter.packRange(uri, err.start, err.finish) if util.equal(range, diag.range) then @@ -333,6 +349,8 @@ local function solveAwaitInSync(uri, diag, results) end local row = guide.rowColOf(parentFunction.start) local pos = guide.positionOf(row, 0) + local offset = guide.positionToOffset(state, pos + 1) + local space = state.lua:match('[ \t]*', offset) results[#results+1] = { title = lang.script.ACTION_MARK_ASYNC, kind = 'quickfix', @@ -342,7 +360,7 @@ local function solveAwaitInSync(uri, diag, results) { start = pos, finish = pos, - newText = '---@async\n', + newText = space .. '---@async\n', } } } @@ -350,6 +368,51 @@ local function solveAwaitInSync(uri, diag, results) } end +local function solveSpell(uri, diag, results) + local spell = require 'provider.spell' + local word = diag.data + if word == nil then + return + end + + results[#results+1] = { + title = lang.script('ACTION_ADD_DICT', word), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_ADD_DICT, + command = 'lua.setConfig', + arguments = { + { + key = 'Lua.spell.dict', + action = 'add', + value = word, + uri = uri, + } + } + } + } + + local suggests = spell.getSpellSuggest(word) + for _, suggest in ipairs(suggests) do + results[#results+1] = { + title = suggest, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + start = converter.unpackPosition(uri, diag.range.start), + finish = converter.unpackPosition(uri, diag.range["end"]), + newText = suggest + } + } + } + } + } + end + +end + local function solveDiagnostic(uri, diag, start, results) if diag.source == lang.script.DIAG_SYNTAX_CHECK then solveSyntax(uri, diag, results) @@ -370,6 +433,8 @@ local function solveDiagnostic(uri, diag, start, results) solveTrailingSpace(uri, diag, results) elseif diag.code == 'await-in-sync' then solveAwaitInSync(uri, diag, results) + elseif diag.code == 'spell-check' then + solveSpell(uri, diag, results) end disableDiagnostic(uri, diag.code, start, results) end @@ -386,7 +451,7 @@ end local function checkSwapParams(results, uri, start, finish) local state = files.getState(uri) local text = files.getText(uri) - if not state then + if not state or not text then return end local args = {} @@ -554,6 +619,9 @@ end local function checkJsonToLua(results, uri, start, finish) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return + end local startOffset = guide.positionToOffset(state, start) local finishOffset = guide.positionToOffset(state, finish) local jsonStart = text:match('()[%{%[]', startOffset + 1) diff --git a/script/core/collector.lua b/script/core/collector.lua deleted file mode 100644 index a2e3ca08..00000000 --- a/script/core/collector.lua +++ /dev/null @@ -1,188 +0,0 @@ -local scope = require 'workspace.scope' - ----@class collector ----@field subscribed table<uri, table<string, any>> ----@field collect table<string, table<uri, any>> -local mt = {} -mt.__index = mt - ---- 订阅一个名字 ----@param uri uri ----@param name string ----@param value any -function mt:subscribe(uri, name, value) - uri = uri or '<fallback>' - -- 订阅部分 - local uriSubscribed = self.subscribed[uri] - if not uriSubscribed then - uriSubscribed = {} - self.subscribed[uri] = uriSubscribed - end - uriSubscribed[name] = true - -- 收集部分 - local nameCollect = self.collect[name] - if not nameCollect then - nameCollect = {} - self.collect[name] = nameCollect - end - if value == nil then - value = true - end - nameCollect[uri] = value -end - ---- 丢弃掉某个 uri 中收集的所有信息 ----@param uri uri -function mt:dropUri(uri) - uri = uri or '<fallback>' - local uriSubscribed = self.subscribed[uri] - if not uriSubscribed then - return - end - self.subscribed[uri] = nil - for name in pairs(uriSubscribed) do - self.collect[name][uri] = nil - if not next(self.collect[name]) then - self.collect[name] = nil - end - end -end - -function mt:dropAll() - self.subscribed = {} - self.collect = {} -end - ---- 是否包含某个名字的订阅 ----@param uri uri ----@param name string ----@return boolean -function mt:has(uri, name) - if self:each(uri, name)() then - return true - else - return false - end -end - -local DUMMY_FUNCTION = function () end - ----@param scp scope -local function eachOfFolder(nameCollect, scp) - local curi, value - - local function getNext() - curi, value = next(nameCollect, curi) - if not curi then - return nil, nil - end - if scp:isChildUri(curi) - or scp:isLinkedUri(curi) then - return value, curi - end - return getNext() - end - - return getNext -end - ----@param scp scope -local function eachOfLinked(nameCollect, scp) - local curi, value - - local function getNext() - curi, value = next(nameCollect, curi) - if not curi then - return nil, nil - end - if scp:isChildUri(curi) - and scp:isLinkedUri(curi) then - return value, curi - end - - local cscp = scope.getFolder(curi) - or scope.getLinkedScope(curi) - or scope.fallback - - if cscp == scp - or cscp:isChildUri(scp.uri) - or cscp:isLinkedUri(scp.uri) then - return value, curi - end - - return getNext() - end - - return getNext -end - ----@param scp scope -local function eachOfFallback(nameCollect, scp) - local curi, value - - local function getNext() - curi, value = next(nameCollect, curi) - if not curi then - return nil, nil - end - if scp:isLinkedUri(curi) then - return value, curi - end - - local cscp = scope.getFolder(curi) - or scope.getLinkedScope(curi) - or scope.fallback - - if cscp == scp then - return value, curi - end - - return getNext() - end - - return getNext -end - ---- 迭代某个名字的订阅 ----@param uri uri ----@param name string -function mt:each(uri, name) - uri = uri or '<fallback>' - local nameCollect = self.collect[name] - if not nameCollect then - return DUMMY_FUNCTION - end - - local scp = scope.getFolder(uri) - - if scp then - return eachOfFolder(nameCollect, scp) - end - - scp = scope.getLinkedScope(uri) - - if scp then - return eachOfLinked(nameCollect, scp) - end - - return eachOfFallback(nameCollect, scope.fallback) -end - -local collectors = {} - -local function new() - return setmetatable({ - collect = {}, - subscribed = {}, - }, mt) -end - ----@return collector -return function (name) - if name then - collectors[name] = collectors[name] or new() - return collectors[name] - else - return new() - end -end diff --git a/script/core/color.lua b/script/core/color.lua new file mode 100644 index 00000000..2cbcce11 --- /dev/null +++ b/script/core/color.lua @@ -0,0 +1,79 @@ +local files = require "files" +local guide = require "parser.guide" + +local colorPattern = string.rep('%x', 8) +---@param source parser.object +---@return boolean +local function isColor(source) + ---@type string + local text = source[1] + if text:len() ~= 8 then + return false + end + return text:match(colorPattern) +end + + +---@param colorText string +---@return Color +local function textToColor(colorText) + return { + alpha = tonumber(colorText:sub(1, 2), 16) / 255, + red = tonumber(colorText:sub(3, 4), 16) / 255, + green = tonumber(colorText:sub(5, 6), 16) / 255, + blue = tonumber(colorText:sub(7, 8), 16) / 255, + } +end + + +---@param color Color +---@return string +local function colorToText(color) + return string.format('%02X%02X%02X%02X' + , math.floor(color.alpha * 255) + , math.floor(color.red * 255) + , math.floor(color.green * 255) + , math.floor(color.blue * 255) + ) +end + +---@class Color +---@field red number +---@field green number +---@field blue number +---@field alpha number + +---@class ColorValue +---@field color Color +---@field start integer +---@field finish integer + +---@async +local function colors(uri) + local state = files.getState(uri) + local text = files.getText(uri) + if not state or not text then + return nil + end + ---@type ColorValue[] + local colorValues = {} + + guide.eachSource(state.ast, function (source) ---@async + if source.type == 'string' and isColor(source) then + ---@type string + local colorText = source[1] + + colorValues[#colorValues+1] = { + start = source.start + 1, + finish = source.finish - 1, + color = textToColor(colorText) + } + end + end) + return colorValues +end + +return { + colors = colors, + colorToText = colorToText +} diff --git a/script/core/command/autoRequire.lua b/script/core/command/autoRequire.lua index c0deecfc..32911d92 100644 --- a/script/core/command/autoRequire.lua +++ b/script/core/command/autoRequire.lua @@ -21,6 +21,9 @@ end local function findInsertRow(uri) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return + end local lines = state.lines local fmt = { pair = false, @@ -68,7 +71,7 @@ local function askAutoRequire(uri, visiblePaths) local selects = {} local nameMap = {} for _, visible in ipairs(visiblePaths) do - local expect = visible.expect + local expect = visible.name local select = lang.script(expect) if not nameMap[select] then nameMap[select] = expect @@ -143,7 +146,7 @@ return function (data) return end table.sort(visiblePaths, function (a, b) - return #a.expect < #b.expect + return #a.name < #b.name end) local result = askAutoRequire(uri, visiblePaths) diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua index aa565f7f..992a0705 100644 --- a/script/core/command/removeSpace.lua +++ b/script/core/command/removeSpace.lua @@ -4,20 +4,12 @@ local proto = require 'proto' local lang = require 'language' local converter = require 'proto.converter' -local function isInString(ast, offset) - return guide.eachSourceContain(ast.ast, offset, function (source) - if source.type == 'string' then - return true - end - end) or false -end - ---@async return function (data) local uri = data.uri local text = files.getText(uri) local state = files.getState(uri) - if not state then + if not state or not text then return end @@ -32,7 +24,8 @@ return function (data) goto NEXT_LINE end local lastPos = guide.offsetToPosition(state, lastOffset) - if isInString(state.ast, lastPos) then + if guide.isInString(state.ast, lastPos) + or guide.isInComment(state.ast, lastPos) then goto NEXT_LINE end local firstOffset = startOffset diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua index 8065aa9d..98ceaa58 100644 --- a/script/core/command/solve.lua +++ b/script/core/command/solve.lua @@ -32,7 +32,7 @@ return function (data) local uri = data.uri local text = files.getText(uri) local state = files.getState(uri) - if not state then + if not state or not text then return end diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index d4c20c60..8f28e450 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -18,6 +18,7 @@ local lookBackward = require 'core.look-backward' local guide = require 'parser.guide' local await = require 'await' local postfix = require 'core.completion.postfix' +local diag = require 'proto.diagnostic' local diagnosticModes = { 'disable-next-line', @@ -56,6 +57,7 @@ local function trim(str) end local function findNearestSource(state, position) + ---@type parser.object local source guide.eachSourceContain(state.ast, position, function (src) source = src @@ -66,6 +68,9 @@ end local function findNearestTableField(state, position) local uri = state.uri local text = files.getText(uri) + if not text then + return nil + end local offset = guide.positionToOffset(state, position) local soffset = lookBackward.findAnyOffset(text, offset) if not soffset then @@ -155,36 +160,24 @@ local function buildFunctionSnip(source, value, oop) if oop then table.remove(args, 1) end - local len = #args - local truncated = false - if len > 0 and args[len]:match('^%s*%.%.%.:') then - table.remove(args) - truncated = true - end - for i = #args, 1, -1 do - if args[i]:match('^%s*[^?]+%?:') then - table.remove(args) - truncated = true - else - break - end - end local snipArgs = {} for id, arg in ipairs(args) do - local str = arg:gsub('^(%s*)(.+)', function (sp, word) + local str, count = arg:gsub('^(%s*)(%.%.%.)(.+)', function (sp, word) return ('%s${%d:%s}'):format(sp, id, word) end) + if count == 0 then + str = arg:gsub('^(%s*)([^:]+)(.+)', function (sp, word) + return ('%s${%d:%s}'):format(sp, id, word) + end) + end table.insert(snipArgs, str) end - if truncated and #snipArgs == 0 then - snipArgs = {'$1'} - end return ('%s(%s)'):format(name, table.concat(snipArgs, ', ')) end local function buildDetail(source) - local types = vm.getInfer(source):view() + local types = vm.getInfer(source):view(guide.getUri(source)) local literals = vm.getInfer(source):viewLiterals() if literals then return types .. ' = ' .. literals @@ -204,6 +197,9 @@ local function getSnip(source) local uri = guide.getUri(def) local text = files.getText(uri) local state = files.getState(uri) + if not state then + goto CONTINUE + end local lines = state.lines if not text then goto CONTINUE @@ -302,7 +298,7 @@ local function checkLocal(state, word, position, results) if name:sub(1, 1) == '@' then goto CONTINUE end - if vm.getInfer(source):hasFunction() then + if vm.getInfer(source):hasFunction(state.uri) then local defs = vm.getDefs(source) -- make sure `function` is before `doc.type.function` local orders = {} @@ -356,6 +352,7 @@ local function checkModule(state, word, position, results) if not config.get(state.uri, 'Lua.completion.autoRequire') then return end + local globals = util.arrayToHash(config.get(state.uri, 'Lua.diagnostics.globals')) local locals = guide.getVisibleLocals(state.ast, position) for uri in files.eachFile(state.uri) do if uri == guide.getUri(state.ast) then @@ -366,7 +363,7 @@ local function checkModule(state, word, position, results) local stemName = fileName:gsub('%..+', '') if not locals[stemName] and not vm.hasGlobalSets(state.uri, 'variable', stemName) - and not config.get(state.uri, 'Lua.diagnostics.globals')[stemName] + and not globals[stemName] and stemName:match '^[%a_][%w_]*$' and matchKey(word, stemName) then local targetState = files.getState(uri) @@ -488,7 +485,7 @@ local function checkFieldFromFieldToIndex(state, name, src, parent, word, startP end local function checkFieldThen(state, name, src, word, startPos, position, parent, oop, results) - local value = vm.getObjectValue(src) or src + local value = vm.getObjectFunctionValue(src) or src local kind = define.CompletionItemKind.Field if value.type == 'function' or value.type == 'doc.type.function' then @@ -512,7 +509,7 @@ local function checkFieldThen(state, name, src, word, startPos, position, parent }) return end - if oop and not vm.getInfer(src):hasFunction() then + if oop and not vm.getInfer(src):hasFunction(state.uri) then return end local literal = guide.getLiteral(value) @@ -568,7 +565,8 @@ local function checkFieldOfRefs(refs, state, word, startPos, position, parent, o end local funcLabel if config.get(state.uri, 'Lua.completion.showParams') then - local value = vm.getObjectValue(src) or src + --- TODO determine if getlocal should be a function here too + local value = vm.getObjectFunctionValue(src) or src if value.type == 'function' or value.type == 'doc.type.function' then funcLabel = name .. getParams(value, oop) @@ -916,24 +914,24 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position goto CONTINUE end local path = furi.decode(uri) - local infos = rpath.getVisiblePath(uri, path) + local infos = rpath.getVisiblePath(myUri, path) local relative = workspace.getRelativePath(path) for _, info in ipairs(infos) do - if matchKey(literal, info.expect) then - if not collect[info.expect] then - collect[info.expect] = { + if matchKey(literal, info.name) then + if not collect[info.name] then + collect[info.name] = { textEdit = { start = smark and (source.start + #smark) or position, finish = smark and (source.finish - #smark) or position, - newText = smark and info.expect or util.viewString(info.expect), + newText = smark and info.name or util.viewString(info.name), }, path = relative, } end if vm.isMetaFile(uri) then - collect[info.expect][#collect[info.expect]+1] = ('* [[meta]](%s)'):format(uri) + collect[info.name][#collect[info.name]+1] = ('* [[meta]](%s)'):format(uri) else - collect[info.expect][#collect[info.expect]+1] = ([=[* [%s](%s) %s]=]):format( + collect[info.name][#collect[info.name]+1] = ([=[* [%s](%s) %s]=]):format( relative, uri, lang.script('HOVER_USE_LUA_PATH', info.searcher) @@ -1098,11 +1096,11 @@ local function tryLabelInString(label, source) if not source or source.type ~= 'string' then return label end - local state = parser.parse(label, 'String') + local state = parser.compile(label, 'String') if not state or not state.ast then return label end - if not matchKey(source[1], state.ast[1]) then + if not matchKey(source[1], state.ast[1]--[[@as string]]) then return nil end return util.viewString(state.ast[1], source[2]) @@ -1124,18 +1122,112 @@ local function cleanEnums(enums, source) return enums end -local function checkTypingEnum(state, position, defs, str, results) +---@param state parser.state +---@param pos integer +---@param doc vm.node.object +---@param enums table[] +---@return table[]? +local function insertDocEnum(state, pos, doc, enums) + local tbl = doc.bindSource + if not tbl then + return nil + end + local parent = tbl.parent + local parentName + if parent._globalNode then + parentName = parent._globalNode:getName() + else + local locals = guide.getVisibleLocals(state.ast, pos) + for _, loc in pairs(locals) do + if util.arrayHas(vm.getDefs(loc), tbl) then + parentName = loc[1] + break + end + end + end + local valueEnums = {} + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + if field.value.type == 'integer' + or field.value.type == 'string' then + if parentName then + enums[#enums+1] = { + label = parentName .. '.' .. key, + kind = define.CompletionItemKind.EnumMember, + id = stack(function () ---@async + return { + detail = buildDetail(field), + description = buildDesc(field), + } + end), + } + end + valueEnums[#valueEnums+1] = { + label = util.viewLiteral(field.value[1]), + kind = define.CompletionItemKind.EnumMember, + id = stack(function () ---@async + return { + detail = buildDetail(field), + description = buildDesc(field), + } + end), + } + end + ::CONTINUE:: + end + end + for _, enum in ipairs(valueEnums) do + enums[#enums+1] = enum + end + return enums +end + +---@param state parser.state +---@param pos integer +---@param src vm.node.object +---@param enums table[] +---@param isInArray boolean? +local function insertEnum(state, pos, src, enums, isInArray) + if src.type == 'doc.type.string' + or src.type == 'doc.type.integer' + or src.type == 'doc.type.boolean' then + ---@cast src parser.object + enums[#enums+1] = { + label = vm.viewObject(src, state.uri), + description = src.comment, + kind = define.CompletionItemKind.EnumMember, + } + elseif src.type == 'doc.type.code' then + enums[#enums+1] = { + label = src[1], + description = src.comment, + kind = define.CompletionItemKind.EnumMember, + } + elseif isInArray and src.type == 'doc.type.array' then + for i, d in ipairs(vm.getDefs(src.node)) do + insertEnum(state, pos, d, enums, isInArray) + end + elseif src.type == 'global' and src.cate == 'type' then + for _, set in ipairs(src:getSets(state.uri)) do + if set.type == 'doc.enum' then + insertDocEnum(state, pos, set, enums) + end + end + end +end + +local function checkTypingEnum(state, position, defs, str, results, isInArray) local enums = {} for _, def in ipairs(defs) do - if def.type == 'doc.type.string' - or def.type == 'doc.type.integer' - or def.type == 'doc.type.boolean' then - enums[#enums+1] = { - label = vm.viewObject(def), - description = def.comment and def.comment.text, - kind = define.CompletionItemKind.EnumMember, - } - end + insertEnum(state, position, def, enums, isInArray) end cleanEnums(enums, str) for _, res in ipairs(enums) do @@ -1143,7 +1235,7 @@ local function checkTypingEnum(state, position, defs, str, results) end end -local function checkEqualEnumLeft(state, position, source, results) +local function checkEqualEnumLeft(state, position, source, results, isInArray) if not source then return end @@ -1153,7 +1245,7 @@ local function checkEqualEnumLeft(state, position, source, results) end end) local defs = vm.getDefs(source) - checkTypingEnum(state, position, defs, str, results) + checkTypingEnum(state, position, defs, str, results, isInArray) end local function checkEqualEnum(state, position, results) @@ -1197,15 +1289,24 @@ local function checkEqualEnumInString(state, position, results) end checkEqualEnumLeft(state, position, parent[1], results) end + if (parent.type == 'tableexp') then + checkEqualEnumLeft(state, position, parent.parent.parent, results, true) + return + end if parent.type == 'local' then checkEqualEnumLeft(state, position, parent, results) end + if parent.type == 'setlocal' or parent.type == 'setglobal' or parent.type == 'setfield' or parent.type == 'setindex' then checkEqualEnumLeft(state, position, parent.node, results) end + if parent.type == 'tablefield' + or parent.type == 'tableindex' then + checkEqualEnumLeft(state, position, parent, results) + end end local function isFuncArg(state, position) @@ -1234,7 +1335,10 @@ local function tryIndex(state, position, results) if not parent then return end - local word = parent.next.index[1] + local word = parent.next and parent.next.index and parent.next.index[1] + if not word then + return + end checkField(state, word, position, position, parent, oop, results) end @@ -1414,18 +1518,12 @@ local function tryCallArg(state, position, results) if not node then return end + local enums = {} for src in node:eachObject() do - if src.type == 'doc.type.string' - or src.type == 'doc.type.integer' - or src.type == 'doc.type.boolean' then - enums[#enums+1] = { - label = vm.viewObject(src), - description = src.comment, - kind = define.CompletionItemKind.EnumMember, - } - end + insertEnum(state, position, src, enums, arg and arg.type == 'table') if src.type == 'doc.type.function' then + ---@cast src parser.object local insertText = buildInsertDocFunction(src) local description if src.comment then @@ -1439,7 +1537,7 @@ local function tryCallArg(state, position, results) : string() end enums[#enums+1] = { - label = vm.getInfer(src):view(), + label = vm.getInfer(src):view(state.uri), description = description, kind = define.CompletionItemKind.Function, insertText = insertText, @@ -1467,6 +1565,7 @@ local function tryTable(state, position, results) if source.type ~= 'table' then tbl = source.parent end + local defs = vm.getFields(tbl) for _, field in ipairs(defs) do local name = guide.getKeyName(field) @@ -1478,9 +1577,28 @@ local function tryTable(state, position, results) checkTableLiteralField(state, position, tbl, fields, results) end +local function tryArray(state, position, results) + local source = findNearestSource(state, position) + if not source then + return + end + if source.type ~= 'table' and (not source.parent or source.parent.type ~= 'table') then + return + end + local tbl = source + if source.type ~= 'table' then + tbl = source.parent + end + if source.parent.type == 'callargs' and source.parent.parent.type == 'call' then + return + end + -- { } inside when enum + checkEqualEnumLeft(state, position, tbl, results, true) +end + local function getComment(state, position) local offset = guide.positionToOffset(state, position) - local symbolOffset = lookBackward.findAnyOffset(state.lua, offset) + local symbolOffset = lookBackward.findAnyOffset(state.lua, offset, true) if not symbolOffset then return end @@ -1493,9 +1611,9 @@ local function getComment(state, position) return nil end -local function getluaDoc(state, position) +local function getLuaDoc(state, position) local offset = guide.positionToOffset(state, position) - local symbolOffset = lookBackward.findAnyOffset(state.lua, offset) + local symbolOffset = lookBackward.findAnyOffset(state.lua, offset, true) if not symbolOffset then return end @@ -1528,11 +1646,15 @@ local function tryluaDocCate(word, results) 'async', 'nodiscard', 'cast', + 'operator', + 'source', + 'enum', } do if matchKey(word, docType) then results[#results+1] = { label = docType, kind = define.CompletionItemKind.Event, + description = lang.script('LUADOC_DESC_' .. docType:upper()) } end end @@ -1608,6 +1730,7 @@ local function tryluaDocBySource(state, position, source, results) for _, doc in ipairs(vm.getDocSets(state.uri)) do local name = (doc.type == 'doc.class' and doc.class[1]) or (doc.type == 'doc.alias' and doc.alias[1]) + or (doc.type == 'doc.enum' and doc.enum[1]) if name and not used[name] and matchKey(source[1], name) then @@ -1697,6 +1820,35 @@ local function tryluaDocBySource(state, position, source, results) end end return true + elseif source.type == 'doc.operator.name' then + for _, name in ipairs(vm.UNARY_OP) do + if matchKey(source[1], name) then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_UNARY_MAP[name]), + } + end + end + for _, name in ipairs(vm.BINARY_OP) do + if matchKey(source[1], name) then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_BINARY_MAP[name]), + } + end + end + for _, name in ipairs(vm.OTHER_OP) do + if matchKey(source[1], name) then + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_OTHER_MAP[name]), + } + end + end + return true end return false end @@ -1734,6 +1886,14 @@ local function tryluaDocByErr(state, position, err, docState, results) kind = define.CompletionItemKind.Class, } end + if doc.type == 'doc.enum' + and not used[doc.enum[1]] then + used[doc.enum[1]] = true + results[#results+1] = { + label = doc.enum[1], + kind = define.CompletionItemKind.Enum, + } + end end elseif err.type == 'LUADOC_MISS_PARAM_NAME' then local funcs = {} @@ -1783,7 +1943,7 @@ local function tryluaDocByErr(state, position, err, docState, results) } end elseif err.type == 'LUADOC_MISS_DIAG_NAME' then - for name in util.sortPairs(define.DiagnosticDefaultSeverity) do + for name in util.sortPairs(diag.getDiagAndErrNameMap()) do results[#results+1] = { label = name, kind = define.CompletionItemKind.Value, @@ -1807,6 +1967,28 @@ local function tryluaDocByErr(state, position, err, docState, results) } end end + elseif err.type == 'LUADOC_MISS_OPERATOR_NAME' then + for _, name in ipairs(vm.UNARY_OP) do + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_UNARY_MAP[name]), + } + end + for _, name in ipairs(vm.BINARY_OP) do + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_BINARY_MAP[name]), + } + end + for _, name in ipairs(vm.OTHER_OP) do + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Operator, + description = ('```lua\n%s\n```'):format(vm.OP_OTHER_MAP[name]), + } + end end end @@ -1818,14 +2000,14 @@ local function buildluaDocOfFunction(func) local returns = {} if func.args then for _, arg in ipairs(func.args) do - args[#args+1] = vm.getInfer(arg):view() + args[#args+1] = vm.getInfer(arg):view(guide.getUri(func)) end end if func.returns then for _, rtns in ipairs(func.returns) do for n = 1, #rtns do if not returns[n] then - returns[n] = vm.getInfer(rtns[n]):view() + returns[n] = vm.getInfer(rtns[n]):view(guide.getUri(func)) end end end @@ -1853,25 +2035,27 @@ local function buildluaDocOfFunction(func) end local function tryluaDocOfFunction(doc, results) - if not doc.bindSources then + if not doc.bindSource then return end - local func - for _, source in ipairs(doc.bindSources) do - if source.type == 'function' then - func = source - break - end - end + local func = (doc.bindSource.type == 'function' and doc.bindSource) + or (doc.bindSource.value and doc.bindSource.value.type == 'function' and doc.bindSource.value) + or nil if not func then return end for _, otherDoc in ipairs(doc.bindGroup) do - if otherDoc.type == 'doc.param' - or otherDoc.type == 'doc.return' then + if otherDoc.type == 'doc.return' then return end end + if func.args then + for _, param in ipairs(func.args) do + if param.bindDocs then + return + end + end + end local insertText = buildluaDocOfFunction(func) results[#results+1] = { label = '@param;@return', @@ -1882,8 +2066,8 @@ local function tryluaDocOfFunction(doc, results) } end -local function tryluaDoc(state, position, results) - local doc = getluaDoc(state, position) +local function tryLuaDoc(state, position, results) + local doc = getLuaDoc(state, position) if not doc then return end @@ -1922,7 +2106,7 @@ local function tryComment(state, position, results) return end local word = lookBackward.findWord(state.lua, guide.positionToOffset(state, position)) - local doc = getluaDoc(state, position) + local doc = getLuaDoc(state, position) if not word then local comment = getComment(state, position) if not comment then @@ -1961,7 +2145,7 @@ local function tryCompletions(state, position, triggerCharacter, results) return end if getComment(state, position) then - tryluaDoc(state, position, results) + tryLuaDoc(state, position, results) tryComment(state, position, results) return end @@ -1971,6 +2155,7 @@ local function tryCompletions(state, position, triggerCharacter, results) trySpecial(state, position, results) tryCallArg(state, position, results) tryTable(state, position, results) + tryArray(state, position, results) tryWord(state, position, triggerCharacter, results) tryIndex(state, position, results) trySymbol(state, position, results) @@ -1983,8 +2168,6 @@ local function completion(uri, position, triggerCharacter) return nil end clearStack() - vm.lockCache() - local _ <close> = vm.unlockCache local results = {} tracy.ZoneBeginN 'completion #2' tryCompletions(state, position, triggerCharacter, results) diff --git a/script/core/definition.lua b/script/core/definition.lua index e4868532..866e8f84 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -4,6 +4,7 @@ local vm = require 'vm' local findSource = require 'core.find-source' local guide = require 'parser.guide' local rpath = require 'workspace.require-path' +local jumpSource = require 'core.jump-source' local function sortResults(results) -- 先按照顺序排序 @@ -54,6 +55,7 @@ local accept = { ['doc.see.name'] = true, ['doc.see.field'] = true, ['doc.cast.name'] = true, + ['doc.enum.name'] = true, } local function checkRequire(source, offset) @@ -75,7 +77,7 @@ local function checkRequire(source, offset) return nil end if libName == 'require' then - return rpath.findUrisByRequirePath(guide.getUri(source), literal) + return rpath.findUrisByRequireName(guide.getUri(source), literal) elseif libName == 'dofile' or libName == 'loadfile' then return workspace.findUrisByFilePath(literal) @@ -169,8 +171,12 @@ return function (uri, offset) if src.type == 'doc.alias' then src = src.alias end + if src.type == 'doc.enum' then + src = src.enum + end if src.type == 'doc.class.name' - or src.type == 'doc.alias.name' then + or src.type == 'doc.alias.name' + or src.type == 'doc.enum.name' then if source.type ~= 'doc.type.name' and source.type ~= 'doc.extends.name' and source.type ~= 'doc.see.name' then @@ -197,6 +203,7 @@ return function (uri, offset) end sortResults(results) + jumpSource(results) return results end diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua index f03f4361..830b2f2f 100644 --- a/script/core/diagnostics/ambiguity-1.lua +++ b/script/core/diagnostics/ambiguity-1.lua @@ -27,10 +27,10 @@ local literalMap = { return function (uri, callback) local state = files.getState(uri) - if not state then + local text = files.getText(uri) + if not state or not text then return end - local text = files.getText(uri) guide.eachSourceType(state.ast, 'binary', function (source) if source.op.type ~= 'or' then return diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua new file mode 100644 index 00000000..2d5c3f98 --- /dev/null +++ b/script/core/diagnostics/assign-type-mismatch.lua @@ -0,0 +1,117 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +local checkTypes = { + 'local', + 'setlocal', + 'setglobal', + 'setfield', + 'setindex', + 'setmethod', + 'tablefield', + 'tableindex' +} + +---@param source parser.object +---@return boolean +local function hasMarkType(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' + or doc.type == 'doc.class' then + return true + end + end + return false +end + +---@param source parser.object +---@return boolean +local function hasMarkClass(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' then + return true + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceTypes(state.ast, checkTypes, function (source) + local value = source.value + if not value then + return + end + await.delay() + if source.type == 'setlocal' then + local locNode = vm.compileNode(source.node) + if not locNode:getData 'hasDefined' then + return + end + end + if value.type == 'nil' then + --[[ + ---@class A + local mt + ---@type X + mt._x = nil -- don't warn this + ]] + if hasMarkType(source) then + return + end + if source.type == 'setfield' + or source.type == 'setindex' then + return + end + end + + local valueNode = vm.compileNode(value) + if source.type == 'setindex' then + -- boolean[1] = nil + valueNode = valueNode:copy():removeOptional() + end + + if value.type == 'getfield' + or value.type == 'getindex' then + -- 由于无法对字段进行类型收窄, + -- 因此将假值移除再进行检查 + valueNode = valueNode:copy():setTruthy() + end + + local varNode = vm.compileNode(source) + if vm.canCastType(uri, varNode, valueNode) then + return + end + + -- local Cat = setmetatable({}, {__index = Animal}) 允许逆变 + if hasMarkClass(source) then + if vm.canCastType(uri, valueNode:copy():remove 'table', varNode) then + return + end + end + + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_ASSIGN_TYPE_MISMATCH', { + def = vm.getInfer(varNode):view(uri), + ref = vm.getInfer(valueNode):view(uri), + }), + } + end) +end diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua new file mode 100644 index 00000000..c3d6e1bb --- /dev/null +++ b/script/core/diagnostics/cast-local-type.lua @@ -0,0 +1,50 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'local', function (loc) + if not loc.ref then + return + end + await.delay() + local locNode = vm.compileNode(loc) + if not locNode:getData 'hasDefined' then + return + end + for _, ref in ipairs(loc.ref) do + if ref.type == 'setlocal' and ref.value then + await.delay() + local refNode = vm.compileNode(ref) + local value = ref.value + + if value.type == 'getfield' + or value.type == 'getindex' then + -- 由于无法对字段进行类型收窄, + -- 因此将假值移除再进行检查 + refNode = refNode:copy():setTruthy() + end + + if not vm.canCastType(uri, locNode, refNode) then + callback { + start = ref.start, + finish = ref.finish, + message = lang.script('DIAG_CAST_LOCAL_TYPE', { + def = vm.getInfer(locNode):view(uri), + ref = vm.getInfer(refNode):view(uri), + }), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/cast-type-mismatch.lua b/script/core/diagnostics/cast-type-mismatch.lua new file mode 100644 index 00000000..a48e6cca --- /dev/null +++ b/script/core/diagnostics/cast-type-mismatch.lua @@ -0,0 +1,45 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.cast' and doc.loc then + await.delay() + local defs = vm.getDefs(doc.loc) + local loc = defs[1] + if loc then + local defNode = vm.compileNode(loc) + if defNode:getData 'hasDefined' then + for _, cast in ipairs(doc.casts) do + if not cast.mode and cast.extends then + local refNode = vm.compileNode(cast.extends) + if not vm.canCastType(uri, defNode, refNode) then + callback { + start = cast.extends.start, + finish = cast.extends.finish, + message = lang.script('DIAG_CAST_TYPE_MISMATCH', { + def = vm.getInfer(defNode):view(uri), + ref = vm.getInfer(refNode):view(uri), + }) + } + end + end + end + end + end + end + end +end diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua index 40d4afeb..fcd2021d 100644 --- a/script/core/diagnostics/circle-doc-class.lua +++ b/script/core/diagnostics/circle-doc-class.lua @@ -2,7 +2,9 @@ local files = require 'files' local lang = require 'language' local vm = require 'vm' local guide = require 'parser.guide' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -18,6 +20,7 @@ return function (uri, callback) if not doc.extends then goto CONTINUE end + await.delay() local myName = guide.getKeyName(doc) local list = { doc } local mark = {} diff --git a/script/core/diagnostics/close-non-object.lua b/script/core/diagnostics/close-non-object.lua index c97014fa..1a42b800 100644 --- a/script/core/diagnostics/close-non-object.lua +++ b/script/core/diagnostics/close-non-object.lua @@ -25,10 +25,11 @@ return function (uri, callback) return end local infer = vm.getInfer(source.value) - if not infer:hasClass() - and not infer:hasType 'nil' - and not infer:hasType 'table' - and infer:view('any', uri) ~= 'any' then + if not infer:hasClass(uri) + and not infer:hasType(uri, 'nil') + and not infer:hasType(uri, 'table') + and not infer:hasUnknown(uri) + and not infer:hasAny(uri) then callback { start = source.value.start, finish = source.value.finish, diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua index 21f7e83a..963fd9ed 100644 --- a/script/core/diagnostics/code-after-break.lua +++ b/script/core/diagnostics/code-after-break.lua @@ -2,7 +2,9 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -10,12 +12,14 @@ return function (uri, callback) end local mark = {} + ---@async guide.eachSourceType(state.ast, 'break', function (source) local list = source.parent if mark[list] then return end mark[list] = true + await.delay() for i = #list, 1, -1 do local src = list[i] if src == source then diff --git a/script/core/diagnostics/codestyle-check.lua b/script/core/diagnostics/codestyle-check.lua index 34d55ee2..25603b4b 100644 --- a/script/core/diagnostics/codestyle-check.lua +++ b/script/core/diagnostics/codestyle-check.lua @@ -7,7 +7,7 @@ local pformatting = require 'provider.formatting' ---@async return function(uri, callback) - local text = files.getText(uri) + local text = files.getOriginText(uri) if not text then return end diff --git a/script/core/diagnostics/count-down-loop.lua b/script/core/diagnostics/count-down-loop.lua index 9bc4b273..bd6e5ee3 100644 --- a/script/core/diagnostics/count-down-loop.lua +++ b/script/core/diagnostics/count-down-loop.lua @@ -10,12 +10,15 @@ return function (uri, callback) end guide.eachSourceType(state.ast, 'loop', function (source) - local maxNumer = source.max and tonumber(source.max[1]) - if maxNumer ~= 1 then + local maxNumber = source.max and tonumber(source.max[1]) + if not maxNumber then return end local minNumber = source.init and tonumber(source.init[1]) - if minNumber and minNumber <= 1 then + if minNumber and maxNumber and minNumber <= maxNumber then + return + end + if not minNumber and maxNumber ~= 1 then return end if not source.step then @@ -24,7 +27,7 @@ return function (uri, callback) finish = source.max.finish, message = lang.script('DIAG_COUNT_DOWN_LOOP' , ('%s, %s'):format(text:sub( - guide.positionToOffset(state, source.init.start), + guide.positionToOffset(state, source.init.start + 1), guide.positionToOffset(state, source.max.finish) ), '-1') ) @@ -37,7 +40,7 @@ return function (uri, callback) finish = source.step.finish, message = lang.script('DIAG_COUNT_DOWN_LOOP' , ('%s, -%s'):format(text:sub( - guide.positionToOffset(state, source.init.start), + guide.positionToOffset(state, source.init.start + 1), guide.positionToOffset(state, source.max.finish) ), source.step[1]) ) diff --git a/script/core/diagnostics/deprecated.lua b/script/core/diagnostics/deprecated.lua index 27920c43..85ae2d19 100644 --- a/script/core/diagnostics/deprecated.lua +++ b/script/core/diagnostics/deprecated.lua @@ -15,7 +15,7 @@ return function (uri, callback) return end - local dglobals = config.get(uri, 'Lua.diagnostics.globals') + local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) local rspecial = config.get(uri, 'Lua.runtime.special') guide.eachSourceTypes(ast.ast, types, function (src) ---@async diff --git a/script/core/diagnostics/different-requires.lua b/script/core/diagnostics/different-requires.lua index de063c9f..22e3e681 100644 --- a/script/core/diagnostics/different-requires.lua +++ b/script/core/diagnostics/different-requires.lua @@ -21,7 +21,7 @@ return function (uri, callback) return end local literal = arg1[1] - local results = rpath.findUrisByRequirePath(uri, literal) + local results = rpath.findUrisByRequireName(uri, literal) if not results or #results ~= 1 then return end diff --git a/script/core/diagnostics/duplicate-doc-alias.lua b/script/core/diagnostics/duplicate-doc-alias.lua index 3df6f972..360358e4 100644 --- a/script/core/diagnostics/duplicate-doc-alias.lua +++ b/script/core/diagnostics/duplicate-doc-alias.lua @@ -2,7 +2,9 @@ local files = require 'files' local lang = require 'language' local vm = require 'vm' local guide = require 'parser.guide' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -15,14 +17,20 @@ return function (uri, callback) local cache = {} for _, doc in ipairs(state.ast.docs) do - if doc.type == 'doc.alias' then + if doc.type == 'doc.alias' + or doc.type == 'doc.enum' then local name = guide.getKeyName(doc) + if not name then + return + end + await.delay() if not cache[name] then local docs = vm.getDocSets(uri, name) cache[name] = {} for _, otherDoc in ipairs(docs) do if otherDoc.type == 'doc.alias' - or otherDoc.type == 'doc.class' then + or otherDoc.type == 'doc.class' + or otherDoc.type == 'doc.enum' then cache[name][#cache[name]+1] = { start = otherDoc.start, finish = otherDoc.finish, @@ -33,10 +41,10 @@ return function (uri, callback) end if #cache[name] > 1 then callback { - start = doc.alias.start, - finish = doc.alias.finish, + start = (doc.alias or doc.enum).start, + finish = (doc.alias or doc.enum).finish, related = cache, - message = lang.script('DIAG_DUPLICATE_DOC_CLASS', name) + message = lang.script('DIAG_DUPLICATE_DOC_ALIAS', name) } end end diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua index d4116b9b..a30dfa88 100644 --- a/script/core/diagnostics/duplicate-doc-field.lua +++ b/script/core/diagnostics/duplicate-doc-field.lua @@ -1,5 +1,7 @@ local files = require 'files' local lang = require 'language' +local vm = require 'vm.vm' +local await = require 'await' local function getFieldEventName(doc) if not doc.extends then @@ -28,6 +30,7 @@ local function getFieldEventName(doc) return nil end +---@async return function (uri, callback) local state = files.getState(uri) if not state then @@ -45,7 +48,13 @@ return function (uri, callback) mark = {} elseif doc.type == 'doc.field' then if mark then - local name = ('%q'):format(doc.field[1]) + await.delay() + local name + if doc.field.type == 'doc.type' then + name = ('[%s]'):format(vm.getInfer(doc.field):view(uri)) + else + name = ('%q'):format(doc.field[1]) + end local eventName = getFieldEventName(doc) if eventName then name = name .. '|' .. eventName diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua index 5097ab3a..dfd9bd4b 100644 --- a/script/core/diagnostics/duplicate-index.lua +++ b/script/core/diagnostics/duplicate-index.lua @@ -2,14 +2,17 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' +local await = require 'await' +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end - + ---@async guide.eachSourceType(ast.ast, 'table', function (source) + await.delay() local mark = {} for _, obj in ipairs(source) do if obj.type == 'tablefield' diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua index 8052c420..ce67ab46 100644 --- a/script/core/diagnostics/duplicate-set-field.lua +++ b/script/core/diagnostics/duplicate-set-field.lua @@ -3,17 +3,21 @@ local lang = require 'language' local define = require 'proto.define' local guide = require 'parser.guide' local vm = require 'vm' +local await = require 'await' +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end + ---@async guide.eachSourceType(ast.ast, 'local', function (source) if not source.ref then return end + await.delay() local sets = {} for _, ref in ipairs(source.ref) do if ref.type ~= 'getlocal' then @@ -48,10 +52,12 @@ return function (uri, callback) local blocks = {} for _, value in ipairs(values) do local block = guide.getBlock(value) - if not blocks[block] then - blocks[block] = {} + if block then + if not blocks[block] then + blocks[block] = {} + end + blocks[block][#blocks[block]+1] = value end - blocks[block][#blocks[block]+1] = value end for _, defs in pairs(blocks) do if #defs <= 1 then diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua index fc205d7e..e05b6aef 100644 --- a/script/core/diagnostics/empty-block.lua +++ b/script/core/diagnostics/empty-block.lua @@ -2,15 +2,18 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local define = require 'proto.define' +local await = require 'await' --- 检查空代码块 +-- 检查空代码块 -- 但是排除忙等待(repeat/while) +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end + await.delay() guide.eachSourceType(ast.ast, 'if', function (source) for _, block in ipairs(source) do if #block > 0 then @@ -24,6 +27,7 @@ return function (uri, callback) message = lang.script.DIAG_EMPTY_BLOCK, } end) + await.delay() guide.eachSourceType(ast.ast, 'loop', function (source) if #source > 0 then return @@ -35,6 +39,7 @@ return function (uri, callback) message = lang.script.DIAG_EMPTY_BLOCK, } end) + await.delay() guide.eachSourceType(ast.ast, 'in', function (source) if #source > 0 then return diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua index 334fd81a..e154080c 100644 --- a/script/core/diagnostics/global-in-nil-env.lua +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -2,65 +2,35 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' --- TODO: 检查路径是否可达 -local function mayRun(path) - return true -end - return function (uri, callback) - local ast = files.getState(uri) - if not ast then - return - end - local root = guide.getRoot(ast.ast) - local env = guide.getENV(root) - - local nilDefs = {} - if not env or not env.ref then - return - end - for _, ref in ipairs(env.ref) do - if ref.type == 'setlocal' then - if ref.value and ref.value.type == 'nil' then - nilDefs[#nilDefs+1] = ref - end - end - end - - if #nilDefs == 0 then + local state = files.getState(uri) + if not state then return end local function check(source) local node = source.node if node.tag == '_ENV' then - local ok - for _, nilDef in ipairs(nilDefs) do - local mode, pathA = guide.getPath(nilDef, source) - if mode == 'before' - and mayRun(pathA) then - ok = nilDef - break - end - end - if ok then - callback { - start = source.start, - finish = source.finish, - uri = uri, - message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, - related = { - { - start = ok.start, - finish = ok.finish, - uri = uri, - } + return + end + + if not node.value or node.value.type == 'nil' then + callback { + start = source.start, + finish = source.finish, + uri = uri, + message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, + related = { + { + start = node.start, + finish = node.finish, + uri = uri, } } - end + } end end - guide.eachSourceType(ast.ast, 'getglobal', check) - guide.eachSourceType(ast.ast, 'setglobal', check) + guide.eachSourceType(state.ast, 'getglobal', check) + guide.eachSourceType(state.ast, 'setglobal', check) end diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua index b4ae3715..c33de6ce 100644 --- a/script/core/diagnostics/init.lua +++ b/script/core/diagnostics/init.lua @@ -3,14 +3,22 @@ local define = require 'proto.define' local config = require 'config' local await = require 'await' local vm = require "vm.vm" +local util = require 'utility' +local diagd = require 'proto.diagnostic' -- 把耗时最长的诊断放到最后面 local diagSort = { - ['redundant-value'] = 96, - ['not-yieldable'] = 97, - ['deprecated'] = 98, - ['undefined-field'] = 99, - ['redundant-parameter'] = 100, + ['redundant-value'] = 100, + ['not-yieldable'] = 100, + ['deprecated'] = 100, + ['undefined-field'] = 110, + ['redundant-parameter'] = 110, + ['cast-local-type'] = 120, + ['assign-type-mismatch'] = 120, + ['param-type-mismatch'] = 120, + ['missing-return'] = 120, + ['missing-return-value'] = 120, + ['redundant-return-value'] = 120, } local diagList = {} @@ -46,30 +54,86 @@ local function checkSleep(uri, passed) sleepRest = sleepRest - sleeped end +---@param uri uri +---@param name string +---@return string +local function getSeverity(uri, name) + local severity = config.get(uri, 'Lua.diagnostics.severity')[name] + or define.DiagnosticDefaultSeverity[name] + if severity:sub(-1) == '!' then + return severity:sub(1, -2) + end + local groupSeverity = config.get(uri, 'Lua.diagnostics.groupSeverity') + local groups = diagd.getGroups(name) + local groupLevel = 999 + for _, groupName in ipairs(groups) do + local gseverity = groupSeverity[groupName] + if gseverity and gseverity ~= 'Fallback' then + groupLevel = math.min(groupLevel, define.DiagnosticSeverity[gseverity]) + end + end + if groupLevel == 999 then + return severity + end + for severityName, level in pairs(define.DiagnosticSeverity) do + if level == groupLevel then + return severityName + end + end + return severity +end + +---@param uri uri +---@param name string +---@return string +local function getStatus(uri, name) + local status = config.get(uri, 'Lua.diagnostics.neededFileStatus')[name] + or define.DiagnosticDefaultNeededFileStatus[name] + if status:sub(-1) == '!' then + return status:sub(1, -2) + end + local groupStatus = config.get(uri, 'Lua.diagnostics.groupFileStatus') + local groups = diagd.getGroups(name) + local groupLevel = 0 + for _, groupName in ipairs(groups) do + local gstatus = groupStatus[groupName] + if gstatus and gstatus ~= 'Fallback' then + groupLevel = math.max(groupLevel, define.DiagnosticFileStatus[gstatus]) + end + end + if groupLevel == 0 then + return status + end + for statusName, level in pairs(define.DiagnosticFileStatus) do + if level == groupLevel then + return statusName + end + end + return status +end + ---@async ---@param uri uri ---@param name string ---@param isScopeDiag boolean ---@param response async fun(result: any) local function check(uri, name, isScopeDiag, response) - if config.get(uri, 'Lua.diagnostics.disable')[name] then + local disables = config.get(uri, 'Lua.diagnostics.disable') + if util.arrayHas(disables, name) then return end - local level = config.get(uri, 'Lua.diagnostics.severity')[name] - or define.DiagnosticDefaultSeverity[name] - - local neededFileStatus = config.get(uri, 'Lua.diagnostics.neededFileStatus')[name] - or define.DiagnosticDefaultNeededFileStatus[name] + local severity = getSeverity(uri, name) + local status = getStatus(uri, name) - if neededFileStatus == 'None' then + if status == 'None' then return end - if neededFileStatus == 'Opened' and not files.isOpen(uri) then + if status == 'Opened' and not files.isOpen(uri) then return end - local severity = define.DiagnosticSeverity[level] + local level = define.DiagnosticSeverity[severity] local clock = os.clock() local mark = {} ---@async @@ -85,7 +149,7 @@ local function check(uri, name, isScopeDiag, response) end mark[result.start] = true - result.level = severity or result.level + result.level = level or result.level result.code = name response(result) end, name) diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua index d03e8c70..68bec234 100644 --- a/script/core/diagnostics/lowercase-global.lua +++ b/script/core/diagnostics/lowercase-global.lua @@ -3,6 +3,7 @@ local guide = require 'parser.guide' local lang = require 'language' local config = require 'config' local vm = require 'vm' +local util = require 'utility' local function isDocClass(source) if not source.bindDocs then @@ -23,10 +24,7 @@ return function (uri, callback) return end - local definedGlobal = {} - for name in pairs(config.get(uri, 'Lua.diagnostics.globals')) do - definedGlobal[name] = true - end + local definedGlobal = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) guide.eachSourceType(ast.ast, 'setglobal', function (source) local name = guide.getKeyName(source) diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua index 698680ca..78b94a09 100644 --- a/script/core/diagnostics/missing-parameter.lua +++ b/script/core/diagnostics/missing-parameter.lua @@ -2,68 +2,27 @@ local files = require 'files' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' +local await = require 'await' -local function countCallArgs(source) - local result = 0 - if not source.args then - return 0 - end - result = result + #source.args - return result -end - ----@return integer -local function countFuncArgs(source) - if not source.args or #source.args == 0 then - return 0 - end - local count = 0 - for i = #source.args, 1, -1 do - local arg = source.args[i] - if arg.type ~= '...' - and not (arg.name and arg.name[1] =='...') - and not vm.compileNode(arg):isNullable() then - return i - end - end - return count -end - -local function getFuncArgs(func) - local funcArgs - local defs = vm.getDefs(func) - for _, def in ipairs(defs) do - if def.type == 'function' - or def.type == 'doc.type.function' then - local args = countFuncArgs(def) - if not funcArgs or args < funcArgs then - funcArgs = args - end - end - end - return funcArgs -end - +---@async return function (uri, callback) local state = files.getState(uri) if not state then return end + ---@async guide.eachSourceType(state.ast, 'call', function (source) - local callArgs = countCallArgs(source) + await.delay() + local _, callArgs = vm.countList(source.args) - local func = source.node - local funcArgs = getFuncArgs(func) + local funcNode = vm.compileNode(source.node) + local funcArgs = vm.countParamsOfNode(funcNode) - if not funcArgs then + if callArgs >= funcArgs then return end - local delta = callArgs - funcArgs - if delta >= 0 then - return - end callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/missing-return-value.lua b/script/core/diagnostics/missing-return-value.lua new file mode 100644 index 00000000..2156d66c --- /dev/null +++ b/script/core/diagnostics/missing-return-value.lua @@ -0,0 +1,66 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' + +local function hasDocReturn(func) + if not func.bindDocs then + return false + end + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + return true + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + await.delay() + if not hasDocReturn(source) then + return + end + local min = vm.countReturnsOfFunction(source) + if min == 0 then + return + end + local returns = source.returns + if not returns then + return + end + for _, ret in ipairs(returns) do + local rmin, rmax = vm.countList(ret) + if rmax < min then + if rmin == rmax then + callback { + start = ret.start, + finish = ret.start + #'return', + message = lang.script('DIAG_MISSING_RETURN_VALUE', { + min = min, + rmax = rmax, + }), + } + else + callback { + start = ret.start, + finish = ret.start + #'return', + message = lang.script('DIAG_MISSING_RETURN_VALUE_RANGE', { + min = min, + rmin = rmin, + rmax = rmax, + }), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/missing-return.lua b/script/core/diagnostics/missing-return.lua new file mode 100644 index 00000000..e3539ac0 --- /dev/null +++ b/script/core/diagnostics/missing-return.lua @@ -0,0 +1,86 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' + +---@param block parser.object +---@return boolean +local function hasReturn(block) + if block.hasReturn or block.hasError then + return true + end + if block.type == 'if' then + local hasElse + for _, subBlock in ipairs(block) do + if not hasReturn(subBlock) then + return false + end + if subBlock.type == 'elseblock' then + hasElse = true + end + end + return hasElse == true + else + if block.type == 'while' then + if vm.testCondition(block.filter) then + return true + end + end + for _, action in ipairs(block) do + if guide.isBlockType(action) then + if hasReturn(action) then + return true + end + end + end + end + return false +end + +---@param func parser.object +---@return boolean +local function isEmptyFunction(func) + if #func > 0 then + return false + end + local startRow = guide.rowColOf(func.start) + local finishRow = guide.rowColOf(func.finish) + return finishRow - startRow <= 1 +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + -- check declare only + if isEmptyFunction(source) then + return + end + await.delay() + if vm.countReturnsOfFunction(source, true) == 0 then + return + end + if hasReturn(source) then + return + end + local lastAction = source[#source] + local pos + if lastAction then + pos = lastAction.range or lastAction.finish + else + local row = guide.rowColOf(source.finish) + pos = guide.positionOf(row - 1, 0) + end + callback { + start = pos, + finish = pos, + message = lang.script('DIAG_MISSING_RETURN'), + } + end) +end diff --git a/script/core/diagnostics/need-check-nil.lua b/script/core/diagnostics/need-check-nil.lua index 98fdfd08..9c86939a 100644 --- a/script/core/diagnostics/need-check-nil.lua +++ b/script/core/diagnostics/need-check-nil.lua @@ -2,14 +2,18 @@ local files = require 'files' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' +local await = require 'await' +---@async return function (uri, callback) local state = files.getState(uri) if not state then return end + ---@async guide.eachSourceType(state.ast, 'getlocal', function (src) + await.delay() local checkNil local nxt = src.next if nxt then @@ -24,11 +28,15 @@ return function (uri, callback) if call and call.type == 'call' and call.node == src then checkNil = true end + local setIndex = src.parent + if setIndex and setIndex.type == 'setindex' and setIndex.index == src then + checkNil = true + end if not checkNil then return end local node = vm.compileNode(src) - if node:hasFalsy() then + if node:hasFalsy() and not vm.getInfer(src):hasType(uri, 'any') then callback { start = src.start, finish = src.finish, diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua index 669ed2bb..bd114959 100644 --- a/script/core/diagnostics/newfield-call.lua +++ b/script/core/diagnostics/newfield-call.lua @@ -1,16 +1,20 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' +local await = require 'await' +local sub = require 'core.substring' +---@async return function (uri, callback) - local ast = files.getState(uri) - if not ast then + local state = files.getState(uri) + local text = files.getText(uri) + if not state or not text then return end - local text = files.getText(uri) - - guide.eachSourceType(ast.ast, 'table', function (source) + ---@async + guide.eachSourceType(state.ast, 'table', function (source) + await.delay() for i = 1, #source do local field = source[i] if field.type ~= 'tableexp' then @@ -33,8 +37,8 @@ return function (uri, callback) start = call.start, finish = call.finish, message = lang.script('DIAG_PREFIELD_CALL' - , text:sub(func.start, func.finish) - , text:sub(args.start, args.finish) + , sub(state)(func.start + 1, func.finish) + , sub(state)(args.start + 1, args.finish) ) } end diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua index 3f2d5ca5..2ba2ce03 100644 --- a/script/core/diagnostics/newline-call.lua +++ b/script/core/diagnostics/newline-call.lua @@ -1,14 +1,18 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' +local await = require 'await' +local sub = require 'core.substring' +---@async return function (uri, callback) local state = files.getState(uri) local text = files.getText(uri) - if not state then + if not state or not text then return end + ---@async guide.eachSourceType(state.ast, 'call', function (source) local node = source.node local args = source.args @@ -20,6 +24,9 @@ return function (uri, callback) if not source.next then return end + + await.delay() + local startOffset = guide.positionToOffset(state, args.start) + 1 local finishOffset = guide.positionToOffset(state, args.finish) if text:sub(startOffset, startOffset) ~= '(' @@ -38,8 +45,8 @@ return function (uri, callback) start = node.start, finish = args.finish, message = lang.script('DIAG_PREVIOUS_CALL' - , text:sub(node.start, node.finish) - , text:sub(args.start, args.finish) + , sub(state)(node.start + 1, node.finish) + , sub(state)(args.start + 1, args.finish) ), } end diff --git a/script/core/diagnostics/no-unknown.lua b/script/core/diagnostics/no-unknown.lua index 48aab5da..e706931a 100644 --- a/script/core/diagnostics/no-unknown.lua +++ b/script/core/diagnostics/no-unknown.lua @@ -2,25 +2,30 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' local vm = require 'vm' +local await = require 'await' +local types = { + 'local', + 'setlocal', + 'setglobal', + 'getglobal', + 'setfield', + 'setindex', + 'tablefield', + 'tableindex', +} + +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end - guide.eachSource(ast.ast, function (source) - if source.type ~= 'local' - and source.type ~= 'setlocal' - and source.type ~= 'setglobal' - and source.type ~= 'getglobal' - and source.type ~= 'setfield' - and source.type ~= 'setindex' - and source.type ~= 'tablefield' - and source.type ~= 'tableindex' then - return - end - if vm.getInfer(source):view() == 'unknown' then + ---@async + guide.eachSourceTypes(ast.ast, types, function (source) + await.delay() + if vm.getInfer(source):view(uri) == 'unknown' then callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/not-yieldable.lua b/script/core/diagnostics/not-yieldable.lua index a1c84276..055025d4 100644 --- a/script/core/diagnostics/not-yieldable.lua +++ b/script/core/diagnostics/not-yieldable.lua @@ -11,7 +11,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if vm.getInfer(arg):hasType 'any' + if vm.getInfer(arg):hasType(guide.getUri(def), 'any') or vm.isAsync(arg, true) or arg.type == '...' then return true @@ -22,7 +22,7 @@ local function isYieldAble(defs, i) local arg = def.args and def.args[i] if arg then hasFuncDef = true - if vm.getInfer(arg.extends):hasType 'any' + if vm.getInfer(arg.extends):hasType(guide.getUri(def), 'any') or vm.isAsync(arg.extends, true) then return true end diff --git a/script/core/diagnostics/param-type-mismatch.lua b/script/core/diagnostics/param-type-mismatch.lua new file mode 100644 index 00000000..6f34f579 --- /dev/null +++ b/script/core/diagnostics/param-type-mismatch.lua @@ -0,0 +1,72 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@param funcNode vm.node + ---@param i integer + ---@return vm.node? + local function getDefNode(funcNode, i) + local defNode = vm.createNode() + for f in funcNode:eachObject() do + if f.type == 'function' + or f.type == 'doc.type.function' then + local param = f.args and f.args[i] + if param then + defNode:merge(vm.compileNode(param)) + if param[1] == '...' then + defNode:addOptional() + end + end + end + end + if defNode:isEmpty() then + return nil + end + return defNode + end + + ---@async + guide.eachSourceType(state.ast, 'call', function (source) + if not source.args then + return + end + await.delay() + local funcNode = vm.compileNode(source.node) + for i, arg in ipairs(source.args) do + if i == 1 and source.node.type == 'getmethod' then + goto CONTINUE + end + local refNode = vm.compileNode(arg) + local defNode = getDefNode(funcNode, i) + if not defNode then + goto CONTINUE + end + if arg.type == 'getfield' + or arg.type == 'getindex' then + -- 由于无法对字段进行类型收窄, + -- 因此将假值移除再进行检查 + refNode = refNode:copy():setTruthy() + end + if not vm.canCastType(uri, defNode, refNode) then + callback { + start = arg.start, + finish = arg.finish, + message = lang.script('DIAG_PARAM_TYPE_MISMATCH', { + def = vm.getInfer(defNode):view(uri), + ref = vm.getInfer(refNode):view(uri), + }) + } + end + ::CONTINUE:: + end + end) +end diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua index 2157ae71..1fb3ca6b 100644 --- a/script/core/diagnostics/redefined-local.lua +++ b/script/core/diagnostics/redefined-local.lua @@ -1,18 +1,23 @@ local files = require 'files' local guide = require 'parser.guide' local lang = require 'language' +local await = require 'await' +---@async return function (uri, callback) local ast = files.getState(uri) if not ast then return end + + ---@async guide.eachSourceType(ast.ast, 'local', function (source) local name = source[1] if name == '_' or name == ast.ENVMode then return end + await.delay() local exist = guide.getLocal(source, name, source.start-1) if exist then callback { diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index 41781df8..9898d9bd 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -2,73 +2,48 @@ local files = require 'files' local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' +local await = require 'await' -local function countCallArgs(source) - local result = 0 - if not source.args then - return 0 - end - result = result + #source.args - return result -end - -local function countFuncArgs(source) - if not source.args or #source.args == 0 then - return 0 - end - local lastArg = source.args[#source.args] - if lastArg.type == '...' - or (lastArg.name and lastArg.name[1] == '...') then - return math.maxinteger - else - return #source.args - end -end - -local function getFuncArgs(func) - local funcArgs - local defs = vm.getDefs(func) - for _, def in ipairs(defs) do - if def.type == 'function' - or def.type == 'doc.type.function' then - local args = countFuncArgs(def) - if not funcArgs or args > funcArgs then - funcArgs = args - end - end - end - return funcArgs -end - +---@async return function (uri, callback) local state = files.getState(uri) if not state then return end + ---@async guide.eachSourceType(state.ast, 'call', function (source) - local callArgs = countCallArgs(source) + await.delay() + local callArgs = vm.countList(source.args) if callArgs == 0 then return end - local func = source.node - local funcArgs = getFuncArgs(func) + local funcNode = vm.compileNode(source.node) + local _, funcArgs = vm.countParamsOfNode(funcNode) - if not funcArgs then - return - end - - local delta = callArgs - funcArgs - if delta <= 0 then + if callArgs <= funcArgs then return end if callArgs == 1 and source.node.type == 'getmethod' then return end - for i = #source.args - delta + 1, #source.args do - local arg = source.args[i] - if arg then + if funcArgs + 1 > #source.args then + local lastArg = source.args[#source.args] + if lastArg.type == 'call' and funcArgs > 0 then + -- 如果函数接收至少一个参数,那么调用方最后一个参数是函数调用 + -- 导致的参数数量太多可以忽略。 + -- 如果函数不接收任何参数,那么任何参数都是错误的。 + return + end + callback { + start = lastArg.start, + finish = lastArg.finish, + message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs) + } + else + for i = funcArgs + 1, #source.args do + local arg = source.args[i] callback { start = arg.start, finish = arg.finish, diff --git a/script/core/diagnostics/redundant-return-value.lua b/script/core/diagnostics/redundant-return-value.lua new file mode 100644 index 00000000..36432f98 --- /dev/null +++ b/script/core/diagnostics/redundant-return-value.lua @@ -0,0 +1,73 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' + +local function hasDocReturn(func) + if not func.bindDocs then + return false + end + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + return true + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + await.delay() + if not hasDocReturn(source) then + return + end + local _, max = vm.countReturnsOfFunction(source) + local returns = source.returns + if not returns then + return + end + for _, ret in ipairs(returns) do + local rmin, rmax = vm.countList(ret) + if rmin > max then + for i = max + 1, #ret - 1 do + callback { + start = ret[i].start, + finish = ret[i].finish, + message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', { + max = max, + rmax = i, + }), + } + end + if #ret == rmax then + callback { + start = ret[#ret].start, + finish = ret[#ret].finish, + message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', { + max = max, + rmax = rmax, + }), + } + else + callback { + start = ret[#ret].start, + finish = ret[#ret].finish, + message = lang.script('DIAG_REDUNDANT_RETURN_VALUE_RANGE', { + max = max, + rmin = #ret, + rmax = rmax, + }), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/return-type-mismatch.lua b/script/core/diagnostics/return-type-mismatch.lua new file mode 100644 index 00000000..cce4aad8 --- /dev/null +++ b/script/core/diagnostics/return-type-mismatch.lua @@ -0,0 +1,76 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +---@param func parser.object +---@return vm.node[]? +local function getDocReturns(func) + if not func.bindDocs then + return nil + end + local returns = {} + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + for _, ret in ipairs(doc.returns) do + returns[ret.returnIndex] = vm.compileNode(ret) + end + end + end + if #returns == 0 then + return nil + end + return returns +end +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@param docReturns vm.node[] + ---@param rets parser.object + local function checkReturn(docReturns, rets) + for i, docRet in ipairs(docReturns) do + local retNode, exp = vm.selectNode(rets, i) + if not exp then + break + end + if retNode:hasName 'nil' then + if exp.type == 'getfield' + or exp.type == 'getindex' then + retNode = retNode:copy():removeOptional() + end + end + if not vm.canCastType(uri, docRet, retNode) then + callback { + start = exp.start, + finish = exp.finish, + message = lang.script('DIAG_RETURN_TYPE_MISMATCH', { + def = vm.getInfer(docRet):view(uri), + ref = vm.getInfer(retNode):view(uri), + index = i, + }), + } + end + end + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + if not source.returns then + return + end + await.delay() + local docReturns = getDocReturns(source) + if not docReturns then + return + end + for _, ret in ipairs(source.returns) do + checkReturn(docReturns, ret) + await.delay() + end + end) +end diff --git a/script/core/diagnostics/spell-check.lua b/script/core/diagnostics/spell-check.lua new file mode 100644 index 00000000..7369a235 --- /dev/null +++ b/script/core/diagnostics/spell-check.lua @@ -0,0 +1,34 @@ +local files = require 'files' +local converter = require 'proto.converter' +local log = require 'log' +local spell = require 'provider.spell' + + +---@async +return function(uri, callback) + local text = files.getOriginText(uri) + if not text then + return + end + + local status, diagnosticInfos = spell.spellCheck(uri, text) + + if not status then + if diagnosticInfos ~= nil then + log.error(diagnosticInfos) + end + + return + end + + if diagnosticInfos then + for _, diagnosticInfo in ipairs(diagnosticInfos) do + callback { + start = converter.unpackPosition(uri, diagnosticInfo.range.start), + finish = converter.unpackPosition(uri, diagnosticInfo.range["end"]), + message = diagnosticInfo.message, + data = diagnosticInfo.data + } + end + end +end diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua index cc51cf77..2e0398b2 100644 --- a/script/core/diagnostics/trailing-space.lua +++ b/script/core/diagnostics/trailing-space.lua @@ -1,25 +1,18 @@ local files = require 'files' local lang = require 'language' local guide = require 'parser.guide' +local await = require 'await' -local function isInString(ast, offset) - local result = false - guide.eachSourceType(ast, 'string', function (source) - if offset >= source.start and offset <= source.finish then - result = true - end - end) - return result -end - +---@async return function (uri, callback) local state = files.getState(uri) - if not state then + local text = files.getText(uri) + if not state or not text then return end - local text = files.getText(uri) local lines = state.lines for i = 0, #lines do + await.delay() local startOffset = lines[i] local finishOffset = text:find('[\r\n]', startOffset) or (#text + 1) local lastOffset = finishOffset - 1 @@ -28,7 +21,8 @@ return function (uri, callback) goto NEXT_LINE end local lastPos = guide.offsetToPosition(state, lastOffset) - if isInString(state.ast, lastPos) then + if guide.isInString(state.ast, lastPos) + or guide.isInComment(state.ast, lastPos) then goto NEXT_LINE end local firstOffset = startOffset diff --git a/script/core/diagnostics/type-check.lua b/script/core/diagnostics/type-check.lua deleted file mode 100644 index cc2b3228..00000000 --- a/script/core/diagnostics/type-check.lua +++ /dev/null @@ -1,3 +0,0 @@ ----@async -return function(uri, callback) -end diff --git a/script/core/diagnostics/unbalanced-assignments.lua b/script/core/diagnostics/unbalanced-assignments.lua index df71f0c9..c21ca993 100644 --- a/script/core/diagnostics/unbalanced-assignments.lua +++ b/script/core/diagnostics/unbalanced-assignments.lua @@ -2,7 +2,17 @@ local files = require 'files' local define = require 'proto.define' local lang = require 'language' local guide = require 'parser.guide' +local await = require 'await' +local types = { + 'local', + 'setlocal', + 'setglobal', + 'setfield', + 'setindex' , +} + +---@async return function (uri, callback, code) local ast = files.getState(uri) if not ast then @@ -31,13 +41,9 @@ return function (uri, callback, code) end end - guide.eachSource(ast.ast, function (source) - if source.type == 'local' - or source.type == 'setlocal' - or source.type == 'setglobal' - or source.type == 'setfield' - or source.type == 'setindex' then - checkSet(source) - end + ---@async + guide.eachSourceTypes(ast.ast, types, function (source) + await.delay() + checkSet(source) end) end diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index 69edb380..bacd4288 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -32,7 +32,7 @@ return function (uri, callback) return end local name = source[1] - if name == '...' then + if name == '...' or name == '_' then return end if #vm.getDocSets(uri, name) > 0 diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua index 98919284..7a60a74f 100644 --- a/script/core/diagnostics/undefined-doc-param.lua +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -1,21 +1,6 @@ local files = require 'files' local lang = require 'language' -local function hasParamName(func, name) - if not func.args then - return false - end - for _, arg in ipairs(func.args) do - if arg[1] == name then - return true - end - if arg.type == '...' and name == '...' then - return true - end - end - return false -end - return function (uri, callback) local state = files.getState(uri) if not state then @@ -27,26 +12,13 @@ return function (uri, callback) end for _, doc in ipairs(state.ast.docs) do - if doc.type ~= 'doc.param' then - goto CONTINUE - end - local binds = doc.bindSources - if not binds then - goto CONTINUE - end - local param = doc.param - local name = param[1] - for _, source in ipairs(binds) do - if source.type == 'function' then - if not hasParamName(source, name) then - callback { - start = param.start, - finish = param.finish, - message = lang.script('DIAG_UNDEFINED_DOC_PARAM', name) - } - end - end + if doc.type == 'doc.param' + and not doc.bindSource then + callback { + start = doc.param.start, + finish = doc.param.finish, + message = lang.script('DIAG_UNDEFINED_DOC_PARAM', doc.param[1]) + } end - ::CONTINUE:: end end diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua index 2f559697..1dff575b 100644 --- a/script/core/diagnostics/undefined-env-child.lua +++ b/script/core/diagnostics/undefined-env-child.lua @@ -3,20 +3,40 @@ local guide = require 'parser.guide' local lang = require 'language' local vm = require "vm.vm" +---@param source parser.object +---@return boolean +local function isBindDoc(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' + or doc.type == 'doc.class' then + return true + end + end + return false +end + return function (uri, callback) - local ast = files.getState(uri) - if not ast then + local state = files.getState(uri) + if not state then return end - guide.eachSourceType(ast.ast, 'getglobal', function (source) - -- 单独验证自己是否在重载过的 _ENV 中有定义 + + guide.eachSourceType(state.ast, 'getglobal', function (source) if source.node.tag == '_ENV' then return end - local defs = vm.getDefs(source) - if #defs > 0 then + + if not isBindDoc(source.node) then return end + + if #vm.getDefs(source) > 0 then + return + end + local key = source[1] callback { start = source.start, diff --git a/script/core/diagnostics/undefined-field.lua b/script/core/diagnostics/undefined-field.lua index 41fcda48..a83241f5 100644 --- a/script/core/diagnostics/undefined-field.lua +++ b/script/core/diagnostics/undefined-field.lua @@ -34,11 +34,11 @@ return function (uri, callback) local node = src.node if node then local ok - for view in vm.getInfer(node):eachView() do - if not skipCheckClass[view] then - ok = true - break + for view in vm.getInfer(node):eachView(uri) do + if skipCheckClass[view] then + return end + ok = true end if not ok then return diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua index bd0aae69..179c9204 100644 --- a/script/core/diagnostics/undefined-global.lua +++ b/script/core/diagnostics/undefined-global.lua @@ -4,6 +4,7 @@ local lang = require 'language' local config = require 'config' local guide = require 'parser.guide' local await = require 'await' +local util = require 'utility' local requireLike = { ['include'] = true, @@ -14,17 +15,17 @@ local requireLike = { ---@async return function (uri, callback) - local ast = files.getState(uri) - if not ast then + local state = files.getState(uri) + if not state then return end - local dglobals = config.get(uri, 'Lua.diagnostics.globals') + local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals')) local rspecial = config.get(uri, 'Lua.runtime.special') local cache = {} -- 遍历全局变量,检查所有没有 set 模式的全局变量 - guide.eachSourceType(ast.ast, 'getglobal', function (src) ---@async + guide.eachSourceType(state.ast, 'getglobal', function (src) ---@async local key = src[1] if not key then return @@ -40,6 +41,7 @@ return function (uri, callback) return end if cache[key] == nil then + await.delay() cache[key] = vm.hasGlobalSets(uri, 'variable', key) end if cache[key] then diff --git a/script/core/diagnostics/unknown-cast-variable.lua b/script/core/diagnostics/unknown-cast-variable.lua new file mode 100644 index 00000000..3f082a50 --- /dev/null +++ b/script/core/diagnostics/unknown-cast-variable.lua @@ -0,0 +1,32 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local vm = require 'vm' +local await = require 'await' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.cast' and doc.loc then + await.delay() + local defs = vm.getDefs(doc.loc) + local loc = defs[1] + if not loc then + callback { + start = doc.loc.start, + finish = doc.loc.finish, + message = lang.script('DIAG_UNKNOWN_CAST_VARIABLE', doc.loc[1]) + } + end + end + end +end diff --git a/script/core/diagnostics/unknown-diag-code.lua b/script/core/diagnostics/unknown-diag-code.lua index 9e492a29..07128a27 100644 --- a/script/core/diagnostics/unknown-diag-code.lua +++ b/script/core/diagnostics/unknown-diag-code.lua @@ -1,6 +1,6 @@ local files = require 'files' local lang = require 'language' -local define = require 'proto.define' +local diag = require 'proto.diagnostic' return function (uri, callback) local state = files.getState(uri) @@ -17,7 +17,7 @@ return function (uri, callback) if doc.names then for _, nameUnit in ipairs(doc.names) do local code = nameUnit[1] - if not define.DiagnosticDefaultSeverity[code] then + if not diag.getDiagAndErrNameMap()[code] then callback { start = nameUnit.start, finish = nameUnit.finish, diff --git a/script/core/diagnostics/unknown-operator.lua b/script/core/diagnostics/unknown-operator.lua new file mode 100644 index 00000000..7404b5ef --- /dev/null +++ b/script/core/diagnostics/unknown-operator.lua @@ -0,0 +1,36 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local vm = require 'vm' +local await = require 'await' +local util = require 'utility' + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.operator' then + local op = doc.op + if op then + local opName = op[1] + if not vm.OP_BINARY_MAP[opName] + and not vm.OP_UNARY_MAP[opName] + and not vm.OP_OTHER_MAP[opName] then + callback { + start = doc.op.start, + finish = doc.op.finish, + message = lang.script('DIAG_UNKNOWN_OPERATOR', opName) + } + end + end + end + end +end diff --git a/script/core/diagnostics/unreachable-code.lua b/script/core/diagnostics/unreachable-code.lua new file mode 100644 index 00000000..4f0a38b7 --- /dev/null +++ b/script/core/diagnostics/unreachable-code.lua @@ -0,0 +1,84 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local await = require 'await' +local define = require 'proto.define' + +---@param source parser.object +---@return boolean +local function allLiteral(source) + local result = true + guide.eachSource(source, function (src) + if src.type ~= 'unary' + and src.type ~= 'binary' + and not guide.isLiteral(src) then + result = false + return false + end + end) + return result +end + +---@param block parser.object +---@return boolean +local function hasReturn(block) + if block.hasReturn or block.hasError then + return true + end + if block.type == 'if' then + local hasElse + for _, subBlock in ipairs(block) do + if not hasReturn(subBlock) then + return false + end + if subBlock.type == 'elseblock' then + hasElse = true + end + end + return hasElse == true + else + if block.type == 'while' then + if vm.testCondition(block.filter) + and not block.breaks + and allLiteral(block.filter) then + return true + end + end + for _, action in ipairs(block) do + if guide.isBlockType(action) then + if hasReturn(action) then + return true + end + end + end + end + return false +end + +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@async + guide.eachSourceTypes(state.ast, {'main', 'function'}, function (source) + await.delay() + for i, action in ipairs(source) do + if guide.isBlockType(action) + and hasReturn(action) then + if i < #source then + callback { + start = source[i+1].start, + finish = source[#source].finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNREACHABLE_CODE'), + } + end + return + end + end + end) +end diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua index 813ac804..a873375f 100644 --- a/script/core/diagnostics/unused-function.lua +++ b/script/core/diagnostics/unused-function.lua @@ -18,7 +18,8 @@ local function isToBeClosed(source) return false end ----@param source parser.object +---@param source parser.object? +---@return boolean local function isValidFunction(source) if not source then return false @@ -55,7 +56,7 @@ local function collect(ast, white, roots, links) for _, ref in ipairs(loc.ref or {}) do if ref.type == 'getlocal' then local func = guide.getParentFunction(ref) - if not isValidFunction(func) or roots[func] then + if not func or not isValidFunction(func) or roots[func] then roots[src] = true return end diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua index d12ceb2b..8f2ee217 100644 --- a/script/core/diagnostics/unused-local.lua +++ b/script/core/diagnostics/unused-local.lua @@ -3,6 +3,8 @@ local guide = require 'parser.guide' local define = require 'proto.define' local lang = require 'language' local vm = require 'vm.vm' +local config = require 'config.config' +local glob = require 'glob' local function hasGet(loc) if not loc.ref then @@ -63,18 +65,24 @@ local function isDocClass(source) return false end -local function isDocParam(source) - if not source.bindDocs then +---@param func parser.object +---@return boolean +local function isEmptyFunction(func) + if #func > 0 then return false end - for _, doc in ipairs(source.bindDocs) do - if doc.type == 'doc.param' then - if doc.param[1] == source[1] then - return true - end - end + local startRow = guide.rowColOf(func.start) + local finishRow = guide.rowColOf(func.finish) + return finishRow - startRow <= 1 +end + +---@param source parser.object +local function isDeclareFunctionParam(source) + if source.parent.type ~= 'funcargs' then + return false end - return false + local func = source.parent.parent + return isEmptyFunction(func) end return function (uri, callback) @@ -82,19 +90,24 @@ return function (uri, callback) if not ast then return end + local ignorePatterns = config.get(uri, 'Lua.diagnostics.unusedLocalExclude') + local ignore = glob.glob(ignorePatterns) guide.eachSourceType(ast.ast, 'local', function (source) local name = source[1] if name == '_' or name == ast.ENVMode then return end + if ignore(name) then + return + end if isToBeClosed(source) then return end if isDocClass(source) then return end - if vm.isMetaFile(uri) and isDocParam(source) then + if isDeclareFunctionParam(source) then return end local data = hasGet(source) diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua index ce033cf3..08f12c4d 100644 --- a/script/core/diagnostics/unused-vararg.lua +++ b/script/core/diagnostics/unused-vararg.lua @@ -15,6 +15,9 @@ return function (uri, callback) end guide.eachSourceType(ast.ast, 'function', function (source) + if #source == 0 then + return + end local args = source.args if not args then return diff --git a/script/core/find-source.lua b/script/core/find-source.lua index 26a411e5..99013b31 100644 --- a/script/core/find-source.lua +++ b/script/core/find-source.lua @@ -21,7 +21,7 @@ return function (ast, position, accept) end end local start, finish = guide.getStartFinish(source) - if finish - start < len and accept[source.type] then + if finish - start <= len and accept[source.type] then result = source len = finish - start end diff --git a/script/core/folding.lua b/script/core/folding.lua index 4f93aed9..7f59636e 100644 --- a/script/core/folding.lua +++ b/script/core/folding.lua @@ -66,7 +66,8 @@ local care = { ['repeat'] = function (source, text, results) local start = source.start local finish = source.keyword[#source.keyword] - if text:sub(finish - #'until' + 1, finish) ~= 'until' then + -- must end with 'until' + if #source.keyword ~= 4 then return end local folding = { @@ -143,6 +144,15 @@ local care = { } results[#results+1] = folding end, + ['doc.alias'] = function (source, text, results) + local folding = { + start = source.start, + finish = source.bindGroup[#source.bindGroup].finish, + kind = 'comment', + hideLastLine = true, + } + results[#results+1] = folding + end, } ---@async diff --git a/script/core/formatting.lua b/script/core/formatting.lua index b52854a4..fb5ca9c7 100644 --- a/script/core/formatting.lua +++ b/script/core/formatting.lua @@ -4,7 +4,10 @@ local log = require("log") return function(uri, options) local text = files.getOriginText(uri) - local ast = files.getState(uri) + local state = files.getState(uri) + if not state then + return + end local status, formattedText = codeFormat.format(uri, text, options) if not status then @@ -17,8 +20,8 @@ return function(uri, options) return { { - start = ast.ast.start, - finish = ast.ast.finish, + start = state.ast.start, + finish = state.ast.finish, text = formattedText, } } diff --git a/script/core/hint.lua b/script/core/hint.lua index f97cdcec..767e531e 100644 --- a/script/core/hint.lua +++ b/script/core/hint.lua @@ -5,6 +5,7 @@ local guide = require 'parser.guide' local await = require 'await' local define = require 'proto.define' local lang = require 'language' +local substr = require 'core.substring' ---@async local function typeHint(uri, results, start, finish) @@ -38,7 +39,7 @@ local function typeHint(uri, results, start, finish) end end await.delay() - local view = vm.getInfer(source):view() + local view = vm.getInfer(source):view(uri) if view == 'any' or view == 'unknown' or view == 'nil' then @@ -189,24 +190,44 @@ local function arrayIndex(uri, results, start, finish) end ---@async - guide.eachSourceBetween(state.ast, start, finish, function (source) - if source.type ~= 'tableexp' then + guide.eachSourceType(state.ast, 'table', function (source) + if source.finish < start or source.start > finish then return end await.delay() if option == 'Auto' then - if not isMixedOrLargeTable(source.parent) then + if not isMixedOrLargeTable(source) then return end end - results[#results+1] = { - text = ('[%d]'):format(source.tindex), - offset = source.start, - kind = define.InlayHintKind.Other, - where = 'left', - source = source.parent, - } + local list = {} + local max = 0 + for _, field in ipairs(source) do + if field.type == 'tableexp' + and field.start < finish + and field.finish > start then + list[#list+1] = field + if field.tindex > max then + max = field.tindex + end + end + end + + if #list > 0 then + local length = #tostring(max) + local fmt = '[%0' .. length .. 'd]' + for _, field in ipairs(list) do + results[#results+1] = { + text = fmt:format(field.tindex), + offset = field.start, + kind = define.InlayHintKind.Other, + where = 'left', + source = field.parent, + } + end + end end) + end ---@async @@ -238,6 +259,72 @@ local function awaitHint(uri, results, start, finish) end) end +local blockTypes = { + 'main', + 'function', + 'for', + 'loop', + 'in', + 'do', + 'repeat', + 'while', + 'ifblock', + 'elseifblock', + 'elseblock', +} + +---@async +local function semicolonHint(uri, results, start, finish) + local state = files.getState(uri) + if not state then + return + end + local mode = config.get(uri, 'Lua.hint.semicolon') + if mode == 'Disable' then + return + end + local subber = substr(state) + ---@async + guide.eachSourceTypes(state.ast, blockTypes, function (src) + await.delay() + for i = 1, #src - 1 do + local current = src[i] + local next = src[i+1] + local left = current.finish + local right = next.start + local text = subber(left, right) + if mode == 'All' then + if not text:find '[,;]' then + results[#results+1] = { + text = ';', + offset = left, + kind = define.InlayHintKind.Other, + where = 'right', + } + end + elseif mode == 'SameLine' then + if not text:find '[,;\r\n]' then + results[#results+1] = { + text = ';', + offset = left, + kind = define.InlayHintKind.Other, + where = 'right', + } + end + end + end + if mode == 'All' then + local last = src[#src] + results[#results+1] = { + text = ';', + offset = last.range or last.finish, + kind = define.InlayHintKind.Other, + where = 'right', + } + end + end) +end + ---@async return function (uri, start, finish) local results = {} @@ -245,5 +332,6 @@ return function (uri, start, finish) paramName(uri, results, start, finish) awaitHint(uri, results, start, finish) arrayIndex(uri, results, start, finish) + semicolonHint(uri, results, start, finish) return results end diff --git a/script/core/hover/args.lua b/script/core/hover/args.lua index c485d9b9..bb4d4297 100644 --- a/script/core/hover/args.lua +++ b/script/core/hover/args.lua @@ -9,7 +9,7 @@ local function asFunction(source) methodDef = true end if methodDef then - args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view 'any') + args[#args+1] = ('self: %s'):format(vm.getInfer(parent.node):view(guide.getUri(source), 'any')) end if source.args then for i = 1, #source.args do @@ -29,15 +29,15 @@ local function asFunction(source) args[#args+1] = ('%s%s: %s'):format( name, optional and '?' or '', - vm.getInfer(argNode):view('any', guide.getUri(source)) + vm.getInfer(argNode):view(guide.getUri(source), 'any') ) elseif arg.type == '...' then - args[#args+1] = ('%s: %s'):format( + args[#args+1] = ('%s%s'):format( '...', - vm.getInfer(arg):view 'any' + vm.getInfer(arg):view(guide.getUri(source), 'any') ) else - args[#args+1] = ('%s'):format(vm.getInfer(arg):view 'any') + args[#args+1] = ('%s'):format(vm.getInfer(arg):view(guide.getUri(source), 'any')) end ::CONTINUE:: end @@ -46,17 +46,17 @@ local function asFunction(source) end local function asDocFunction(source) + local args = {} if not source.args then - return '' + return args end - local args = {} for i = 1, #source.args do local arg = source.args[i] local name = arg.name[1] args[i] = ('%s%s: %s'):format( name, arg.optional and '?' or '', - arg.extends and vm.getInfer(arg.extends):view 'any' or 'any' + arg.extends and vm.getInfer(arg.extends):view(guide.getUri(source), 'any') or 'any' ) end return args diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index e9267c0f..e11dd6c8 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -6,11 +6,12 @@ local lang = require 'language' local util = require 'utility' local guide = require 'parser.guide' local rpath = require 'workspace.require-path' +local furi = require 'file-uri' local function collectRequire(mode, literal, uri) local result, searchers if mode == 'require' then - result, searchers = rpath.findUrisByRequirePath(uri, literal) + result, searchers = rpath.findUrisByRequireName(uri, literal) elseif mode == 'dofile' or mode == 'loadfile' then result = ws.findUrisByFilePath(literal) @@ -82,7 +83,53 @@ local function asString(source) or asStringView(source, literal) end -local function getBindComment(source, docGroup, base) +---@param comment string +---@param suri uri +---@return string? +local function normalizeComment(comment, suri) + if not comment then + return nil + end + if comment:sub(1, 1) == '-' then + comment = comment:sub(2) + end + if comment:sub(1, 1) == '@' then + return nil + end + comment = comment:gsub('(%[.-%]%()(.-)(%))', function (left, path, right) + local scheme = furi.split(path) + if scheme + -- strange way to check `C:/xxx.lua` + and #scheme > 1 then + return + end + local absPath = ws.getAbsolutePath(suri:gsub('/[^/]+$', ''), path) + if not absPath then + return + end + local uri = furi.encode(absPath) + return left .. uri .. right + end) + return comment +end + +local function getBindComment(source) + local uri = guide.getUri(source) + local lines = {} + for _, docComment in ipairs(source.bindComments) do + lines[#lines+1] = normalizeComment(docComment.comment.text, uri) + end + if not lines or #lines == 0 then + return nil + end + return table.concat(lines, '\n') +end + +local function lookUpDocComments(source) + local docGroup = source.bindDocs + if not docGroup then + return + end if source.type == 'setlocal' or source.type == 'getlocal' then source = source.node @@ -90,34 +137,23 @@ local function getBindComment(source, docGroup, base) if source.parent.type == 'funcargs' then return end - local continue - local lines + local uri = guide.getUri(source) + local lines = {} for _, doc in ipairs(docGroup) do if doc.type == 'doc.comment' then - if not continue then - continue = true - lines = {} + lines[#lines+1] = normalizeComment(doc.comment.text, uri) + elseif doc.type == 'doc.type' then + if doc.comment then + lines[#lines+1] = normalizeComment(doc.comment.text, uri) end - if doc.comment.text:sub(1, 1) == '-' then - lines[#lines+1] = doc.comment.text:sub(2) - else - lines[#lines+1] = doc.comment.text - end - elseif doc == base then - break - else - continue = false - if doc.type == 'doc.field' - or doc.type == 'doc.class' then - lines = nil + elseif doc.type == 'doc.class' then + for _, docComment in ipairs(doc.bindComments) do + lines[#lines+1] = normalizeComment(docComment.comment.text, uri) end end end if source.comment then - if not lines then - lines = {} - end - lines[#lines+1] = source.comment.text + lines[#lines+1] = normalizeComment(source.comment.text, uri) end if not lines or #lines == 0 then return nil @@ -128,8 +164,9 @@ end local function tryDocClassComment(source) for _, def in ipairs(vm.getDefs(source)) do if def.type == 'doc.class' - or def.type == 'doc.alias' then - local comment = getBindComment(def, def.bindGroup, def) + or def.type == 'doc.alias' + or def.type == 'doc.enum' then + local comment = getBindComment(def) if comment then return comment end @@ -144,7 +181,7 @@ local function tryDocModule(source) return collectRequire('require', source.module, guide.getUri(source)) end -local function buildEnumChunk(docType, name) +local function buildEnumChunk(docType, name, uri) if not docType then return nil end @@ -152,10 +189,11 @@ local function buildEnumChunk(docType, name) local types = {} local lines = {} for _, tp in ipairs(vm.getDefs(docType)) do - types[#types+1] = vm.getInfer(tp):view() + types[#types+1] = vm.getInfer(tp):view(guide.getUri(docType)) if tp.type == 'doc.type.string' or tp.type == 'doc.type.integer' - or tp.type == 'doc.type.boolean' then + or tp.type == 'doc.type.boolean' + or tp.type == 'doc.type.code' then enums[#enums+1] = tp end local comment = tryDocClassComment(tp) @@ -174,7 +212,7 @@ local function buildEnumChunk(docType, name) (enum.default and '->') or (enum.additional and '+>') or ' |', - vm.viewObject(enum) + vm.viewObject(enum, uri) ) if enum.comment then local first = true @@ -198,26 +236,33 @@ local function getBindEnums(source, docGroup) return end + local uri = guide.getUri(source) local mark = {} local chunks = {} local returnIndex = 0 for _, doc in ipairs(docGroup) do if doc.type == 'doc.param' then local name = doc.param[1] + if name == '...' then + name = '...(param)' + end if mark[name] then goto CONTINUE end mark[name] = true - chunks[#chunks+1] = buildEnumChunk(doc.extends, name) + chunks[#chunks+1] = buildEnumChunk(doc.extends, name, uri) elseif doc.type == 'doc.return' then for _, rtn in ipairs(doc.returns) do returnIndex = returnIndex + 1 local name = rtn.name and rtn.name[1] or ('return #%d'):format(returnIndex) + if name == '...' then + name = '...(return)' + end if mark[name] then goto CONTINUE end mark[name] = true - chunks[#chunks+1] = buildEnumChunk(rtn, name) + chunks[#chunks+1] = buildEnumChunk(rtn, name, uri) end end ::CONTINUE:: @@ -228,37 +273,38 @@ local function getBindEnums(source, docGroup) return table.concat(chunks, '\n\n') end -local function tryDocFieldUpComment(source) - if source.type ~= 'doc.field.name' then +local function tryDocFieldComment(source) + if source.type ~= 'doc.field' then return end - local docField = source.parent - if not docField.bindGroup then - return + if source.comment then + return normalizeComment(source.comment.text, guide.getUri(source)) + end + if source.bindGroup then + return getBindComment(source) end - local comment = getBindComment(docField, docField.bindGroup, docField) - return comment end local function getFunctionComment(source) local docGroup = source.bindDocs + if not docGroup then + return + end local hasReturnComment = false - for _, doc in ipairs(docGroup) do + for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.return' and doc.comment then hasReturnComment = true break end end + local uri = guide.getUri(source) local md = markdown() for _, doc in ipairs(docGroup) do if doc.type == 'doc.comment' then - if doc.comment.text:sub(1, 1) == '-' then - md:add('md', doc.comment.text:sub(2)) - else - md:add('md', doc.comment.text) - end + local comment = normalizeComment(doc.comment.text, uri) + md:add('md', comment) elseif doc.type == 'doc.param' then if doc.comment then md:add('md', ('@*param* `%s` — %s'):format( @@ -295,18 +341,36 @@ local function getFunctionComment(source) local enums = getBindEnums(source, docGroup) md:add('lua', enums) - return md + + local comment = md:string() + if comment == '' then + return nil + end + return comment end local function tryDocComment(source) - if not source.bindDocs then - return + local md = markdown() + if source.type == 'function' then + local comment = getFunctionComment(source) + md:add('md', comment) + source = source.parent end - if source.type ~= 'function' then - local comment = getBindComment(source, source.bindDocs) - return comment + local comment = lookUpDocComments(source) + md:add('md', comment) + if source.type == 'doc.alias' then + local enums = buildEnumChunk(source, source.alias[1], guide.getUri(source)) + md:add('lua', enums) end - return getFunctionComment(source) + if source.type == 'doc.enum' then + local enums = buildEnumChunk(source, source.enum[1], guide.getUri(source)) + md:add('lua', enums) + end + local result = md:string() + if result == '' then + return nil + end + return result end local function tryDocOverloadToComment(source) @@ -315,14 +379,12 @@ local function tryDocOverloadToComment(source) end local doc = source.parent if doc.type ~= 'doc.overload' - or not doc.bindSources then + or not doc.bindSource then return end - for _, src in ipairs(doc.bindSources) do - local md = tryDocComment(src) - if md then - return md - end + local md = tryDocComment(doc.bindSource) + if md then + return md end end @@ -350,6 +412,45 @@ local function tyrDocParamComment(source) end end +---@param source parser.object +local function tryDocEnum(source) + if source.type ~= 'doc.enum' then + return + end + local tbl = source.bindSource + if not tbl then + return + end + local md = markdown() + md:add('lua', '{') + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + if field.value.type == 'integer' + or field.value.type == 'string' then + md:add('lua', (' %s: %s = %s,'):format(key, field.value.type, field.value[1])) + end + if field.value.type == 'binary' + or field.value.type == 'unary' then + local number = vm.getNumber(field.value) + if number then + md:add('lua', (' %s: %s = %s,'):format(key, math.tointeger(number) and 'integer' or 'number', number)) + end + end + ::CONTINUE:: + end + end + md:add('lua', '}') + return md:string() +end + return function (source) if source.type == 'string' then return asString(source) @@ -358,9 +459,10 @@ return function (source) source = source.parent end return tryDocOverloadToComment(source) - or tryDocFieldUpComment(source) + or tryDocFieldComment(source) or tyrDocParamComment(source) or tryDocComment(source) or tryDocClassComment(source) or tryDocModule(source) + or tryDocEnum(source) end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua index 7231944a..5a65cbce 100644 --- a/script/core/hover/init.lua +++ b/script/core/hover/init.lua @@ -39,7 +39,7 @@ local function getHover(source) end local oop - if vm.getInfer(source):view() == 'function' then + if vm.getInfer(source):view(guide.getUri(source)) == 'function' then local defs = vm.getDefs(source) -- make sure `function` is before `doc.type.function` local orders = {} @@ -92,19 +92,21 @@ local function getHover(source) end local accept = { - ['local'] = true, - ['setlocal'] = true, - ['getlocal'] = true, - ['setglobal'] = true, - ['getglobal'] = true, - ['field'] = true, - ['method'] = true, - ['string'] = true, - ['number'] = true, - ['integer'] = true, - ['doc.type.name'] = true, - ['function'] = true, - ['doc.module'] = true, + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['field'] = true, + ['method'] = true, + ['string'] = true, + ['number'] = true, + ['integer'] = true, + ['doc.type.name'] = true, + ['doc.class.name'] = true, + ['doc.enum.name'] = true, + ['function'] = true, + ['doc.module'] = true, } ---@async diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua index 2bbfe806..5c502ec1 100644 --- a/script/core/hover/label.lua +++ b/script/core/hover/label.lua @@ -33,7 +33,10 @@ local function asDocTypeName(source) return '(class) ' .. doc.class[1] end if doc.type == 'doc.alias' then - return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view()) + return '(alias) ' .. doc.alias[1] .. ' ' .. lang.script('HOVER_EXTENDS', vm.getInfer(doc.extends):view(guide.getUri(source))) + end + if doc.type == 'doc.enum' then + return '(enum) ' .. doc.enum[1] end end end @@ -42,7 +45,7 @@ end local function asValue(source, title) local name = buildName(source, false) or '' local ifr = vm.getInfer(source) - local type = ifr:view() + local type = ifr:view(guide.getUri(source)) local literal = ifr:viewLiterals() local cont = buildTable(source) local pack = {} @@ -55,10 +58,11 @@ local function asValue(source, title) and ( type == 'table' or type == 'any' or type == 'unknown' - or type == 'nil') then - type = nil + or type == 'nil' + or type:sub(1, 1) == '{') then + else + pack[#pack+1] = type end - pack[#pack+1] = type if literal then pack[#pack+1] = '=' pack[#pack+1] = literal @@ -139,7 +143,7 @@ local function asDocFieldName(source) break end end - local view = vm.getInfer(source.extends):view() + local view = vm.getInfer(source.extends):view(guide.getUri(source)) if not class then return ('(field) ?.%s: %s'):format(name, view) end @@ -212,7 +216,8 @@ return function (source, oop) elseif source.type == 'number' or source.type == 'integer' then return asNumber(source) - elseif source.type == 'doc.type.name' then + elseif source.type == 'doc.type.name' + or source.type == 'doc.enum.name' then return asDocTypeName(source) elseif source.type == 'doc.field' then return asDocFieldName(source) diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua index f8473638..3fabfb89 100644 --- a/script/core/hover/name.lua +++ b/script/core/hover/name.lua @@ -20,6 +20,9 @@ local function asField(source, oop) local class if source.node.type ~= 'getglobal' then class = vm.getInfer(source.node):viewClass() + if class == 'any' or class == 'unknown' then + class = nil + end end local node = class or buildName(source.node, false) @@ -47,14 +50,12 @@ end local function asDocFunction(source, oop) local doc = guide.getParentType(source, 'doc.type') or guide.getParentType(source, 'doc.overload') - if not doc or not doc.bindSources then + if not doc or not doc.bindSource then return '' end - for _, src in ipairs(doc.bindSources) do - local name = buildName(src, oop) - if name ~= '' then - return name - end + local name = buildName(doc.bindSource, oop) + if name ~= '' then + return name end return '' end diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua index 3d8a94a5..b71b9e5d 100644 --- a/script/core/hover/return.lua +++ b/script/core/hover/return.lua @@ -1,34 +1,5 @@ local vm = require 'vm.vm' - ----@param source parser.object ----@return integer -local function countReturns(source) - local n = 0 - - local docs = source.bindDocs - if docs then - for _, doc in ipairs(docs) do - if doc.type == 'doc.return' then - for _, rtn in ipairs(doc.returns) do - if rtn.returnIndex and rtn.returnIndex > n then - n = rtn.returnIndex - end - end - end - end - end - - local returns = source.returns - if returns then - for _, rtn in ipairs(returns) do - if #rtn > n then - n = #rtn - end - end - end - - return n -end +local guide = require 'parser.guide' ---@param source parser.object ---@return parser.object[] @@ -50,7 +21,7 @@ local function getReturnDocs(source) end local function asFunction(source) - local num = countReturns(source) + local _, _, num = vm.countReturnsOfFunction(source) if num == 0 then return nil end @@ -62,11 +33,14 @@ local function asFunction(source) for i = 1, num do local rtn = vm.getReturnOfFunction(source, i) local doc = docs[i] - local name = doc and doc.name and doc.name[1] and (doc.name[1] .. ': ') - local text = ('%s%s'):format( + local name = doc and doc.name and doc.name[1] + if name and name ~= '...' then + name = name .. ': ' + end + local text = rtn and ('%s%s'):format( name or '', - vm.getInfer(rtn):view() - ) + vm.getInfer(rtn):view(guide.getUri(source)) + ) or 'unknown' if i == 1 then returns[i] = (' -> %s'):format(text) else @@ -83,7 +57,14 @@ local function asDocFunction(source) end local returns = {} for i, rtn in ipairs(source.returns) do - local rtnText = vm.getInfer(rtn):view() + local rtnText = vm.getInfer(rtn):view(guide.getUri(source)) + if rtn.name then + if rtn.name[1] == '...' then + rtnText = rtn.name[1] .. rtnText + else + rtnText = rtn.name[1] .. ': ' .. rtnText + end + end if i == 1 then returns[#returns+1] = (' -> %s'):format(rtnText) else diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua index 16874101..677fd76c 100644 --- a/script/core/hover/table.lua +++ b/script/core/hover/table.lua @@ -30,7 +30,7 @@ local function buildAsHash(uri, keys, nodeMap, reachMax) node:removeOptional() end local ifr = vm.getInfer(node) - local typeView = ifr:view('unknown', uri) + local typeView = ifr:view(uri, 'unknown') local literalView = ifr:viewLiterals() if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( @@ -75,7 +75,7 @@ local function buildAsConst(uri, keys, nodeMap, reachMax) node = node:copy() node:removeOptional() end - local typeView = vm.getInfer(node):view('unknown', uri) + local typeView = vm.getInfer(node):view(uri, 'unknown') local literalView = literalMap[key] if literalView then lines[#lines+1] = (' %s%s: %s = %s,'):format( @@ -154,7 +154,7 @@ local function getNodeMap(fields, keyMap) local nodeMap = {} for _, field in ipairs(fields) do local key = vm.getKeyName(field) - if not keyMap[key] then + if not key or not keyMap[key] then goto CONTINUE end await.delay() @@ -178,9 +178,15 @@ return function (source) return nil end - for view in vm.getInfer(source):eachView() do - if view == 'string' - or vm.isSubType(uri, view, 'string') then + local node = vm.compileNode(source) + for n in node:eachObject() do + if n.type == 'global' and n.cate == 'type' then + if n.name == 'string' + or (n.name ~= 'unknown' and n.name ~= 'any' and vm.isSubType(uri, n.name, 'string')) then + return nil + end + elseif n.type == 'doc.type.string' + or n.type == 'string' then return nil end end diff --git a/script/core/jump-source.lua b/script/core/jump-source.lua new file mode 100644 index 00000000..5ce5e048 --- /dev/null +++ b/script/core/jump-source.lua @@ -0,0 +1,62 @@ +local guide = require 'parser.guide' +local furi = require 'file-uri' +local ws = require 'workspace' + +---@param doc parser.object +---@return uri +local function parseUri(doc) + local uri + local scheme = furi.split(doc.path) + if scheme and #scheme >= 2 then + uri = doc.path + else + local suri = guide.getUri(doc):gsub('[/\\][^/\\]*$', '') + local path = ws.getAbsolutePath(suri, doc.path) + if path then + uri = furi.encode(path) + else + uri = doc.path + end + end + ---@cast uri uri + return uri +end + +---@param results table +return function (results) + for _, result in ipairs(results) do + if result.target.type == 'doc.field.name' + or result.target.type == 'doc.class.name' then + local doc = result.target.parent.source + if doc then + local uri = parseUri(doc) + result.uri = uri + result.target = { + uri = uri, + start = guide.positionOf(doc.line - 1, doc.char), + finish = guide.positionOf(doc.line - 1, doc.char), + } + end + else + local target = result.target + if target.type == 'method' + or target.type == 'field' then + target = target.parent + end + if target.bindDocs then + for _, doc in ipairs(target.bindDocs) do + if doc.type == 'doc.source' + and doc.bindSource == target then + local uri = parseUri(doc) + result.uri = uri + result.target = { + uri = uri, + start = guide.positionOf(doc.line - 1, doc.char), + finish = guide.positionOf(doc.line - 1, doc.char), + } + end + end + end + end + end +end diff --git a/script/core/look-backward.lua b/script/core/look-backward.lua index eeee6017..8d3e3439 100644 --- a/script/core/look-backward.lua +++ b/script/core/look-backward.lua @@ -81,9 +81,19 @@ function m.findTargetSymbol(text, offset, symbol) return nil end -function m.findAnyOffset(text, offset) +---@param text string +---@param offset integer +---@param inline? boolean # 必须在同一行中(排除换行符) +function m.findAnyOffset(text, offset, inline) for i = offset, 1, -1 do - if not m.isSpace(text:sub(i, i)) then + local c = text:sub(i, i) + if inline then + if c == '\r' + or c == '\n' then + return nil + end + end + if not m.isSpace(c) then return i end end diff --git a/script/core/reference.lua b/script/core/reference.lua index 4c9c193d..fa838cff 100644 --- a/script/core/reference.lua +++ b/script/core/reference.lua @@ -2,6 +2,7 @@ local guide = require 'parser.guide' local files = require 'files' local vm = require 'vm' local findSource = require 'core.find-source' +local jumpSource = require 'core.jump-source' local function sortResults(results) -- 先按照顺序排序 @@ -49,6 +50,7 @@ local accept = { ['doc.class.name'] = true, ['doc.extends.name'] = true, ['doc.alias.name'] = true, + ['doc.enum.name'] = true, } ---@async @@ -101,12 +103,17 @@ return function (uri, position) if src.type == 'doc.alias' then src = src.alias end + if src.type == 'doc.enum' then + src = src.enum + end if src.type == 'doc.class.name' or src.type == 'doc.alias.name' + or src.type == 'doc.enum.name' or src.type == 'doc.type.name' or src.type == 'doc.extends.name' then if source.type ~= 'doc.type.name' and source.type ~= 'doc.class.name' + and source.type ~= 'doc.enum.name' and source.type ~= 'doc.extends.name' and source.type ~= 'doc.see.name' then goto CONTINUE @@ -132,6 +139,7 @@ return function (uri, position) end sortResults(results) + jumpSource(results) return results end diff --git a/script/core/rename.lua b/script/core/rename.lua index 7599fad6..90e66224 100644 --- a/script/core/rename.lua +++ b/script/core/rename.lua @@ -81,6 +81,9 @@ local function renameField(source, newname, callback) local uri = guide.getUri(source) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return false + end local func = parent.value -- function mt:name () end --> mt['newname'] = function (self) end local startOffset = guide.positionToOffset(state, parent.start) + 1 @@ -183,13 +186,16 @@ local function ofField(source, newname, callback) local key = guide.getKeyName(source) local refs = vm.getRefs(source) for _, ref in ipairs(refs) do - ofFieldThen(key, ref, newname, callback) + ofFieldThen(key, ref, newname, callback) end end ---@async local function ofGlobal(source, newname, callback) local key = guide.getKeyName(source) + if not key then + return + end local global = vm.getGlobal('variable', key) if not global then return @@ -225,6 +231,9 @@ local function ofDocTypeName(source, newname, callback) if doc.type == 'doc.alias' then callback(doc, doc.alias.start, doc.alias.finish, newname) end + if doc.type == 'doc.enum' then + callback(doc, doc.enum.start, doc.enum.finish, newname) + end end for _, doc in ipairs(global:getGets(uri)) do if doc.type == 'doc.type.name' then @@ -236,16 +245,15 @@ end local function ofDocParamName(source, newname, callback) callback(source, source.start, source.finish, newname) local doc = source.parent - if doc.bindSources then - for _, src in ipairs(doc.bindSources) do - if src.type == 'local' - and src.parent.type == 'funcargs' - and src[1] == source[1] then - renameLocal(src, newname, callback) - if src.ref then - for _, ref in ipairs(src.ref) do - renameLocal(ref, newname, callback) - end + local src = doc.bindSource + if src then + if src.type == 'local' + and src.parent.type == 'funcargs' + and src[1] == source[1] then + renameLocal(src, newname, callback) + if src.ref then + for _, ref in ipairs(src.ref) do + renameLocal(ref, newname, callback) end end end @@ -271,7 +279,8 @@ local function rename(source, newname, callback) return ofGlobal(source, newname, callback) elseif source.type == 'doc.class.name' or source.type == 'doc.type.name' - or source.type == 'doc.alias.name' then + or source.type == 'doc.alias.name' + or source.type == 'doc.enum.name' then return ofDocTypeName(source, newname, callback) elseif source.type == 'doc.param.name' then return ofDocParamName(source, newname, callback) @@ -305,6 +314,7 @@ local function prepareRename(source) or source.type == 'doc.class.name' or source.type == 'doc.type.name' or source.type == 'doc.alias.name' + or source.type == 'doc.enum.name' or source.type == 'doc.param.name' then return source, source[1] elseif source.type == 'string' @@ -345,6 +355,7 @@ local accept = { ['doc.type.name'] = true, ['doc.alias.name'] = true, ['doc.param.name'] = true, + ['doc.enum.name'] = true, } local m = {} diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 33449013..5833807b 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -32,7 +32,7 @@ local Care = util.switch() end options.libGlobals[name] = isLib end - local isFunc = vm.getInfer(source):hasFunction() + local isFunc = vm.getInfer(source):hasFunction(guide.getUri(source)) local type = isFunc and define.TokenTypes['function'] or define.TokenTypes.variable local modifier = isLib and define.TokenModifiers.defaultLibrary or define.TokenModifiers.static @@ -81,7 +81,7 @@ local Care = util.switch() return end end - if vm.getInfer(source):hasFunction() then + if vm.getInfer(source):hasFunction(guide.getUri(source)) then results[#results+1] = { start = source.start, finish = source.finish, @@ -134,19 +134,16 @@ local Care = util.switch() return end local loc = source.node or source + local uri = guide.getUri(loc) -- 1. 值为函数的局部变量 | Local variable whose value is a function - if loc.ref then - for _, ref in ipairs(loc.ref) do - if ref.value and ref.value.type == 'function' then - results[#results+1] = { - start = source.start, - finish = source.finish, - type = define.TokenTypes['function'], - modifieres = define.TokenModifiers.declaration, - } - return - end - end + if vm.getInfer(source):hasFunction(uri) then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes['function'], + modifieres = define.TokenModifiers.declaration, + } + return end -- 3. 特殊变量 | Special variableif source[1] == '_ENV' then if loc[1] == '_ENV' then @@ -196,7 +193,7 @@ local Care = util.switch() end end -- 6. References to other functions - if vm.getInfer(loc):hasFunction() then + if vm.getInfer(loc):hasFunction(guide.getUri(source)) then results[#results+1] = { start = source.start, finish = source.finish, @@ -449,6 +446,7 @@ local Care = util.switch() end end) : case 'doc.alias.name' + : case 'doc.enum.name' : call(function (source, options, results) if not options.annotation then return @@ -667,6 +665,14 @@ local Care = util.switch() type = define.TokenTypes.keyword, } end) + : case 'doc.cast.block' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.keyword, + } + end) : case 'doc.cast.name' : call(function (source, options, results) results[#results+1] = { @@ -675,6 +681,23 @@ local Care = util.switch() type = define.TokenTypes.variable, } end) + : case 'doc.type.code' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.string, + modifieres = define.TokenModifiers.abstract, + } + end) + : case 'doc.operator.name' + : call(function (source, options, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.operator, + } + end) local function buildTokens(uri, results) local tokens = {} @@ -811,9 +834,13 @@ return function (uri, start, finish) keyword = config.get(uri, 'Lua.semantic.keyword'), } + local n = 0 guide.eachSourceBetween(state.ast, start, finish, function (source) ---@async Care(source.type, source, options, results) - await.delay() + n = n + 1 + if n % 100 == 0 then + await.delay() + end end) for _, comm in ipairs(state.comms) do diff --git a/script/core/signature.lua b/script/core/signature.lua index 025e70b7..21e954bf 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -8,6 +8,9 @@ local lookback = require 'core.look-backward' local function findNearCall(uri, ast, pos) local text = files.getText(uri) local state = files.getState(uri) + if not state or not text then + return nil + end local nearCall guide.eachSourceContain(ast.ast, pos, function (src) if src.type == 'call' @@ -65,27 +68,30 @@ local function makeOneSignature(source, oop, index) } end -- 不定参数 - if index > i and i > 0 then + if index and index > i and i > 0 then local lastLabel = params[i].label local text = label:sub(lastLabel[1] + 1, lastLabel[2]) if text:sub(1, 3) == '...' then index = i end end + if #params < (index or 0) then + return nil + end return { label = label, params = params, - index = index, + index = index or 1, description = hoverDesc(source), } end ---@async local function makeSignatures(text, call, pos) - local node = call.node - local oop = node.type == 'method' - or node.type == 'getmethod' - or node.type == 'setmethod' + local func = call.node + local oop = func.type == 'method' + or func.type == 'getmethod' + or func.type == 'setmethod' local index if call.args then local args = {} @@ -121,13 +127,13 @@ local function makeSignatures(text, call, pos) index = #args end end - else - index = 1 end local signs = {} - local defs = vm.getDefs(node) + local node = vm.compileNode(func) + ---@type vm.node + node = node:getData 'originNode' or node local mark = {} - for _, src in ipairs(defs) do + for src in node:eachObject() do if src.type == 'function' or src.type == 'doc.type.function' then if not mark[src] then @@ -142,10 +148,10 @@ end ---@async return function (uri, pos) local state = files.getState(uri) - if not state then + local text = files.getText(uri) + if not state or not text then return nil end - local text = files.getText(uri) local offset = guide.positionToOffset(state, pos) pos = guide.offsetToPosition(state, lookback.skipSpace(text, offset)) local call = findNearCall(uri, state, pos) @@ -156,5 +162,8 @@ return function (uri, pos) if not signs or #signs == 0 then return nil end + table.sort(signs, function (a, b) + return #a.params < #b.params + end) return signs end diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua index d8434c8c..a1c2b29f 100644 --- a/script/core/type-definition.lua +++ b/script/core/type-definition.lua @@ -4,6 +4,7 @@ local vm = require 'vm' local findSource = require 'core.find-source' local guide = require 'parser.guide' local rpath = require 'workspace.require-path' +local jumpSource = require 'core.jump-source' local function sortResults(results) -- 先按照顺序排序 @@ -51,6 +52,7 @@ local accept = { ['doc.class.name'] = true, ['doc.extends.name'] = true, ['doc.alias.name'] = true, + ['doc.enum.name'] = true, ['doc.see.name'] = true, ['doc.see.field'] = true, } @@ -74,7 +76,7 @@ local function checkRequire(source, offset) return nil end if libName == 'require' then - return rpath.findUrisByRequirePath(guide.getUri(source), literal) + return rpath.findUrisByRequireName(guide.getUri(source), literal) elseif libName == 'dofile' or libName == 'loadfile' then return workspace.findUrisByFilePath(literal) @@ -144,6 +146,9 @@ return function (uri, offset) if src.type == 'doc.alias' then src = src.alias end + if src.type == 'doc.enum' then + src = src.enum + end if src.type == 'doc.class.name' or src.type == 'doc.alias.name' or src.type == 'doc.type.function' @@ -164,6 +169,7 @@ return function (uri, offset) end sortResults(results) + jumpSource(results) return results end diff --git a/script/doctor.lua b/script/doctor.lua index 87cdcfcb..e1044689 100644 --- a/script/doctor.lua +++ b/script/doctor.lua @@ -538,7 +538,7 @@ m.exclude = private(function (...) end) --- 比较2个报告 ----@return string +---@return table m.compare = private(function (old, new) local newHash = {} local ret = {} diff --git a/script/file-uri.lua b/script/file-uri.lua index ccd47156..3e916acf 100644 --- a/script/file-uri.lua +++ b/script/file-uri.lua @@ -24,9 +24,6 @@ local m = {} ---@param path string ---@return uri uri function m.encode(path) - if not path then - return nil - end local authority = '' if platform.OS == 'Windows' then path = path:gsub('\\', '/') @@ -70,9 +67,6 @@ end ---@param uri uri ---@return string path function m.decode(uri) - if not uri then - return nil - end local scheme, authority, path = uri:match('([^:]*):?/?/?([^/]*)(.*)') if not scheme then return '' @@ -95,7 +89,27 @@ function m.decode(uri) end function m.split(uri) - return uri:match('([^:]*):?/?/?([^/]*)(.*)') + return uri:match('([^:]*):/?/?([^/]*)(.*)') +end + +---@param uri string +---@return boolean +function m.isValid(uri) + local scheme, authority, path = m.split(uri) + if not scheme or scheme == '' then + return false + end + if path == '' then + return false + end + return true +end + +function m.normalize(uri) + if uri == '' then + return uri + end + return m.encode(m.decode(uri)) end return m diff --git a/script/files.lua b/script/files.lua index 22c9ae31..91bbe570 100644 --- a/script/files.lua +++ b/script/files.lua @@ -14,16 +14,30 @@ local progress = require "progress" local encoder = require 'encoder' local scope = require 'workspace.scope' +---@class file +---@field uri uri +---@field content string +---@field _ref? integer +---@field trusted? boolean +---@field rows? integer[] +---@field originText? string +---@field text string +---@field version? integer +---@field originLines? integer[] +---@field state? parser.state +---@field _diffInfo? table[] +---@field cache table + ---@class files local m = {} m.watchList = {} m.notifyCache = {} m.assocVersion = -1 -m.assocMatcher = nil function m.reset() m.openMap = {} + ---@type table<string, file> m.fileMap = {} m.dllMap = {} m.visible = {} @@ -41,12 +55,14 @@ local uriMap = {} ---@return uri function m.getRealUri(uri) local filename = furi.decode(uri) + -- normalize uri + uri = furi.encode(filename) local path = fs.path(filename) - local suc, res = pcall(fs.exists, path) - if not suc or not res then + local suc, exists = pcall(fs.exists, path) + if not suc or not exists then return uri end - suc, res = pcall(fs.canonical, path) + local suc, res = pcall(fs.canonical, path) if not suc then return uri end @@ -123,6 +139,7 @@ function m.isLibrary(uri, excludeFolder) end --- 获取库文件的根目录 +---@return uri? function m.getLibraryUri(suri, uri) local scp = scope.getScope(suri) return scp:getLinkedUri(uri) @@ -134,6 +151,9 @@ function m.exists(uri) return m.fileMap[uri] ~= nil end +---@param file file +---@param text string +---@return string local function pluginOnSetText(file, text) local plugin = require 'plugin' file._diffInfo = nil @@ -164,7 +184,7 @@ end --- 设置文件文本 ---@param uri uri ----@param text string +---@param text? string ---@param isTrust? boolean ---@param callback? function function m.setText(uri, text, isTrust, callback) @@ -320,7 +340,7 @@ end --- 获取文件文本 ---@param uri uri ----@return string text +---@return string? text function m.getText(uri) local file = m.fileMap[uri] if not file then @@ -331,7 +351,7 @@ end --- 获取文件原始文本 ---@param uri uri ----@return string text +---@return string? text function m.getOriginText(uri) local file = m.fileMap[uri] if not file then @@ -345,9 +365,7 @@ end ---@return integer[] function m.getOriginLines(uri) local file = m.fileMap[uri] - if not file then - return nil - end + assert(file, 'file not exists:' .. uri) return file.originLines end @@ -444,7 +462,6 @@ function m.eachFile(suri) end --- Pairs dll files ----@return function function m.eachDll() local map = {} for uri, file in pairs(m.dllMap) do @@ -488,7 +505,7 @@ function m.compileState(uri, text) , { special = config.get(uri, 'Lua.runtime.special'), unicodeName = config.get(uri, 'Lua.runtime.unicodeName'), - nonstandardSymbol = config.get(uri, 'Lua.runtime.nonstandardSymbol'), + nonstandardSymbol = util.arrayToHash(config.get(uri, 'Lua.runtime.nonstandardSymbol')), } ) local passed = os.clock() - clock @@ -524,7 +541,7 @@ end --- 获取文件语法树 ---@param uri uri ----@return table state +---@return table? state function m.getState(uri) local file = m.fileMap[uri] if not file then @@ -534,7 +551,7 @@ function m.getState(uri) if not state then state = m.compileState(uri, file.text) m.astMap[uri] = state - file.ast = state + file.state = state --await.delay() end file.cacheActiveTime = timer.clock() @@ -546,7 +563,7 @@ function m.getLastState(uri) if not file then return nil end - return file.ast + return file.state end function m.getFile(uri) diff --git a/script/filewatch.lua b/script/filewatch.lua index 5e3a0322..ecc72255 100644 --- a/script/filewatch.lua +++ b/script/filewatch.lua @@ -5,13 +5,13 @@ local await = require 'await' local MODIFY = 1 << 0 local RENAME = 1 << 1 -local function exists(filename) +local function isExists(filename) local path = fs.path(filename) - local suc, res = pcall(fs.exists, path) - if not suc or not res then + local suc, exists = pcall(fs.exists, path) + if not suc or not exists then return false end - suc, res = pcall(fs.canonical, path) + local suc, res = pcall(fs.canonical, path) if not suc or res:string() ~= path:string() then return false end @@ -69,6 +69,7 @@ function m.update() if not ev then break end + log.debug('filewatch:', ev, path) if not collect then collect = {} end @@ -85,7 +86,7 @@ function m.update() for path, flag in pairs(collect) do if flag & RENAME ~= 0 then - if exists(path) then + if isExists(path) then m._callEvent('create', path) else m._callEvent('delete', path) diff --git a/script/fs-utility.lua b/script/fs-utility.lua index 08aae98a..b789177c 100644 --- a/script/fs-utility.lua +++ b/script/fs-utility.lua @@ -13,9 +13,10 @@ local tableSort = table.sort _ENV = nil +---@class fs-utility local m = {} --- 读取文件 ----@param path string +---@param path string|fs.path function m.loadFile(path, keepBom) if type(path) ~= 'string' then ---@diagnostic disable-next-line: undefined-field @@ -40,7 +41,7 @@ function m.loadFile(path, keepBom) end --- 写入文件 ----@param path string +---@param path any ---@param content string function m.saveFile(path, content) if type(path) ~= 'string' then @@ -255,6 +256,9 @@ function dfs:saveFile(path, text) dir[filename] = text end +---@param path string|fs.path +---@param option table +---@return fs.path? local function fsAbsolute(path, option) if type(path) == 'string' then local suc, res = pcall(fs.path, path) @@ -444,6 +448,9 @@ local function fileRemove(path, option) end end +---@param source fs.path? +---@param target fs.path? +---@param option table local function fileCopy(source, target, option) if not source or not target then return @@ -477,6 +484,9 @@ local function fileCopy(source, target, option) end end +---@param source fs.path? +---@param target fs.path? +---@param option table local function fileSync(source, target, option) if not source or not target then return @@ -583,29 +593,29 @@ function m.fileRemove(path, option) end --- 复制文件(夹) ----@param source string ----@param target string +---@param source string|fs.path +---@param target string|fs.path ---@return table function m.fileCopy(source, target, option) option = buildOption(option) - source = fsAbsolute(source, option) - target = fsAbsolute(target, option) + local fsSource = fsAbsolute(source, option) + local fsTarget = fsAbsolute(target, option) - fileCopy(source, target, option) + fileCopy(fsSource, fsTarget, option) return option end --- 同步文件(夹) ----@param source string ----@param target string +---@param source string|fs.path +---@param target string|fs.path ---@return table function m.fileSync(source, target, option) option = buildOption(option) - source = fsAbsolute(source, option) - target = fsAbsolute(target, option) + local fsSource = fsAbsolute(source, option) + local fsTarget = fsAbsolute(target, option) - fileSync(source, target, option) + fileSync(fsSource, fsTarget, option) return option end diff --git a/script/glob/gitignore.lua b/script/glob/gitignore.lua index 4dad2747..de8fd005 100644 --- a/script/glob/gitignore.lua +++ b/script/glob/gitignore.lua @@ -164,12 +164,10 @@ function mt:getRelativePath(path) end ---@param callback async fun(path: string) +---@param hook? async fun(ev: string, ...) ---@async -function mt:scan(path, callback) +function mt:scan(path, callback, hook) local files = {} - if type(callback) ~= 'function' then - callback = nil - end local list = {} ---@async @@ -203,6 +201,9 @@ function mt:scan(path, callback) break end list[#list] = nil + if hook then + hook('scan', current) + end if not self:simpleMatch(current) then check(current) end diff --git a/script/jsonrpc.lua b/script/jsonrpc.lua index 91d6c9dd..7411fee8 100644 --- a/script/jsonrpc.lua +++ b/script/jsonrpc.lua @@ -50,6 +50,7 @@ function m.decode(reader) if not content then return nil, 'Proto read error' end + ---@type any local null = json.null json.null = nil local suc, res = pcall(json.decode, content) diff --git a/script/lclient.lua b/script/lclient.lua index ad1fff3d..e1504e61 100644 --- a/script/lclient.lua +++ b/script/lclient.lua @@ -80,7 +80,6 @@ function mt:reportHangs() end ---@param callback async fun(client: languageClient) ----@return languageClient function mt:start(callback) CLI = true @@ -208,8 +207,10 @@ function mt:registerFakers() 'textDocument/publishDiagnostics', 'workspace/configuration', 'workspace/semanticTokens/refresh', + 'workspace/diagnostic/refresh', 'window/workDoneProgress/create', 'window/showMessage', + 'window/showMessageRequest', 'window/logMessage', } do self:register(method, function () diff --git a/script/library.lua b/script/library.lua index 66c4d364..57aac066 100644 --- a/script/library.lua +++ b/script/library.lua @@ -209,6 +209,7 @@ local function initBuiltIn(uri) local langID = lang.id local version = config.get(uri, 'Lua.runtime.version') local encoding = config.get(uri, 'Lua.runtime.fileEncoding') + ---@type fs.path local metaPath = fs.path(METAPATH) / config.get(uri, 'Lua.runtime.meta'):gsub('%$%{(.-)%}', { version = version, language = langID, @@ -243,6 +244,7 @@ local function initBuiltIn(uri) goto CONTINUE end libName = libName .. '.lua' + ---@type fs.path local libPath = templateDir / libName local metaDoc = compileSingleMetaDoc(uri, fsu.loadFile(libPath), metaLang, status) if metaDoc then @@ -260,6 +262,7 @@ local function initBuiltIn(uri) end end +---@param libraryDir fs.path local function loadSingle3rdConfig(libraryDir) local configText = fsu.loadFile(libraryDir / 'config.lua') if not configText then @@ -321,7 +324,13 @@ local function load3rdConfigInDir(dir, configs, inner) end local function load3rdConfig(uri) - local configs = {} + local scp = scope.getScope(uri) + local configs = scp:get 'thirdConfigsCache' + if configs then + return configs + end + configs = {} + scp:set('thirdConfigsCache', configs) load3rdConfigInDir(innerThirdDir, configs, true) local thirdDirs = config.get(uri, 'Lua.workspace.userThirdParty') for _, thirdDir in ipairs(thirdDirs) do @@ -400,6 +409,15 @@ local function askFor3rd(uri, cfg) uri = uri, }, }, true) + else + client.setConfig({ + { + key = 'Lua.workspace.checkThirdParty', + action = 'set', + value = false, + uri = uri, + }, + }, false) end end @@ -420,11 +438,21 @@ local function wholeMatch(a, b) return true end -local function check3rdByWords(uri, text, configs) +local function check3rdByWords(uri, configs) if hasAsked then return end + if not files.isLua(uri) then + return + end + local id = 'check3rdByWords:' .. uri + await.close(id) await.call(function () ---@async + await.sleep(0.1) + local text = files.getText(uri) + if not text then + return + end for _, cfg in ipairs(configs) do if cfg.words then for _, word in ipairs(cfg.words) do @@ -436,7 +464,7 @@ local function check3rdByWords(uri, text, configs) end end end - end) + end, id) end local function check3rdByFileName(uri, configs) @@ -447,7 +475,10 @@ local function check3rdByFileName(uri, configs) if not path then return end + local id = 'check3rdByFileName:' .. uri + await.close(id) await.call(function () ---@async + await.sleep(0.1) for _, cfg in ipairs(configs) do if cfg.files then for _, filename in ipairs(cfg.files) do @@ -459,50 +490,62 @@ local function check3rdByFileName(uri, configs) end end end - end) -end - -local lastCheckedUri = {} -local function checkedUri(uri) - if lastCheckedUri[uri] - and timer.clock() - lastCheckedUri[uri] < 5 then - return false - end - lastCheckedUri[uri] = timer.clock() - return true + end, id) end -local thirdConfigs +---@async local function check3rd(uri) if hasAsked then return end + if ws.isIgnored(uri) then + return + end if not config.get(uri, 'Lua.workspace.checkThirdParty') then return end - if thirdConfigs == nil then - thirdConfigs = load3rdConfig(uri) or false + local scp = scope.getScope(uri) + if not scp:get 'canCheckThirdParty' then + return end + local thirdConfigs = load3rdConfig(uri) or false if not thirdConfigs then return end - if checkedUri(uri) then - if files.isLua(uri) then - local text = files.getText(uri) - if text then - check3rdByWords(uri, text, thirdConfigs) - end + check3rdByWords(uri, thirdConfigs) + check3rdByFileName(uri, thirdConfigs) +end + +local function check3rdOfWorkspace(suri) + local scp = scope.getScope(suri) + scp:set('thirdConfigsCache', nil) + scp:set('canCheckThirdParty', true) + local id = 'check3rdOfWorkspace:' .. scp:getName() + await.close(id) + ---@async + await.call(function () + ws.awaitReady(suri) + for uri in files.eachFile(suri) do + check3rd(uri) end - check3rdByFileName(uri, thirdConfigs) - end + for uri in files.eachDll() do + check3rd(uri) + end + end, id) end config.watch(function (uri, key, value, oldValue) if key:find '^Lua.runtime' then initBuiltIn(uri) end + if key == 'Lua.workspace.checkThirdParty' + or key == 'Lua.workspace.userThirdParty' + or key == '' then + check3rdOfWorkspace(uri) + end end) +---@async files.watch(function (ev, uri) if ev == 'update' or ev == 'dll' then @@ -510,11 +553,10 @@ files.watch(function (ev, uri) end end) -function m.init() - initBuiltIn(nil) - for _, scp in ipairs(ws.folders) do - initBuiltIn(scp.uri) +ws.watch(function (ev, uri) + if ev == 'startReload' then + initBuiltIn(uri) end -end +end) return m diff --git a/script/linked-table.lua b/script/linked-table.lua index 4d87e943..a63a528c 100644 --- a/script/linked-table.lua +++ b/script/linked-table.lua @@ -8,10 +8,14 @@ mt._size = 0 local HEAD = {'<HEAD>'} local TAIL = {'<TAIL>'} +---@param node any +---@return boolean function mt:has(node) return self._left[node] ~= nil end +---@param node any +---@return boolean function mt:isValidNode(node) if node == nil or node == HEAD @@ -21,6 +25,9 @@ function mt:isValidNode(node) return true end +---@param node any +---@param afterWho any +---@return boolean function mt:pushAfter(node, afterWho) if not self:isValidNode(node) then return false @@ -41,6 +48,9 @@ function mt:pushAfter(node, afterWho) return true end +---@param node any +---@param beforeWho any +---@return boolean function mt:pushBefore(node, beforeWho) if node == nil then return false @@ -52,6 +62,8 @@ function mt:pushBefore(node, beforeWho) return self:pushAfter(node, left) end +---@param node any +---@return boolean function mt:pop(node) if not self:isValidNode(node) then return false @@ -71,14 +83,20 @@ function mt:pop(node) return true end +---@param node any +---@return boolean function mt:pushHead(node) return self:pushAfter(node, HEAD) end +---@param node any +---@return boolean function mt:pushTail(node) return self:pushBefore(node, TAIL) end +---@param node any +---@return any function mt:getAfter(node) if node == nil then node = HEAD @@ -90,10 +108,12 @@ function mt:getAfter(node) return right end +---@return any function mt:getHead() return self:getAfter(HEAD) end +---@return any function mt:getBefore(node) if node == nil then node = TAIL @@ -105,18 +125,24 @@ function mt:getBefore(node) return left end +---@return any function mt:getTail() return self:getBefore(TAIL) end +---@return boolean function mt:popHead() return self:pop(self:getHead()) end +---@return boolean function mt:popTail() return self:pop(self:getTail()) end +---@param old any +---@param new any +---@return boolean function mt:replace(old, new) if not self:isValidNode(old) or not self:isValidNode(new) then @@ -137,10 +163,14 @@ function mt:replace(old, new) return true end +---@return integer function mt:getSize() return self._size end +---@param start any +---@param revert? boolean +---@return fun():any function mt:pairs(start, revert) if revert then if start == nil then @@ -171,6 +201,9 @@ function mt:pairs(start, revert) end end +---@param start any +---@param revert? boolean +---@return string function mt:dump(start, revert) local t = {} for node in self:pairs(start, revert) do @@ -186,6 +219,7 @@ function mt:reset() self._size = 0 end +---@return linked-table return function () local self = setmetatable({}, mt) self:reset() diff --git a/script/meta/bee/filesystem.lua b/script/meta/bee/filesystem.lua new file mode 100644 index 00000000..f6cdff79 --- /dev/null +++ b/script/meta/bee/filesystem.lua @@ -0,0 +1,91 @@ +---@class fs.path +---@operator div: fs.path +local fsPath = {} + +---@return string +function fsPath:string() +end + +---@return fs.path +function fsPath:parent_path() +end + +---@return boolean +function fsPath:is_relative() +end + +---@return fs.path +function fsPath:filename() +end + +---@class fs.status +local fsStatus = {} + +---@return string +function fsStatus:type() +end + +---@class fs +local fs = {} + +---@class fs.copy_options +---@field overwrite_existing integer +local copy_options + +fs.copy_options = copy_options + +---@param path string +---@return fs.path +function fs.path(path) +end + +---@return fs.path +function fs.exe_path() +end + +---@param path fs.path +---@return boolean +function fs.exists(path) +end + +---@param path fs.path +---@return boolean +function fs.is_directory(path) +end + +---@param path fs.path +---@return fun():fs.path +function fs.pairs(path) +end + +---@param path fs.path +---@return fs.path +function fs.canonical(path) +end + +---@param path fs.path +---@return fs.path +function fs.absolute(path) +end + +---@param path fs.path +function fs.create_directories(path) +end + +---@param path fs.path +---@return fs.status +function fs.symlink_status(path) +end + +---@param path fs.path +---@return boolean +function fs.remove(path) +end + +---@param source fs.path +---@param target fs.path +---@param options? `fs.copy_options.overwrite_existing` +function fs.copy_file(source, target, options) +end + +return fs diff --git a/script/parser/ast.lua b/script/parser/ast.lua deleted file mode 100644 index 648a6890..00000000 --- a/script/parser/ast.lua +++ /dev/null @@ -1,1997 +0,0 @@ -local tonumber = tonumber -local stringChar = string.char -local utf8Char = utf8.char -local tableUnpack = table.unpack -local mathType = math.type -local tableRemove = table.remove -local tableSort = table.sort -local print = print -local tostring = tostring - -_ENV = nil - -local DefaultState = { - lua = '', - options = {}, -} - -local State = DefaultState -local PushError -local PushDiag -local PushComment - --- goto 单独处理 -local RESERVED = { - ['and'] = true, - ['break'] = true, - ['do'] = true, - ['else'] = true, - ['elseif'] = true, - ['end'] = true, - ['false'] = true, - ['for'] = true, - ['function'] = true, - ['if'] = true, - ['in'] = true, - ['local'] = true, - ['nil'] = true, - ['not'] = true, - ['or'] = true, - ['repeat'] = true, - ['return'] = true, - ['then'] = true, - ['true'] = true, - ['until'] = true, - ['while'] = true, -} - -local VersionOp = { - ['&'] = {'Lua 5.3', 'Lua 5.4'}, - ['~'] = {'Lua 5.3', 'Lua 5.4'}, - ['|'] = {'Lua 5.3', 'Lua 5.4'}, - ['<<'] = {'Lua 5.3', 'Lua 5.4'}, - ['>>'] = {'Lua 5.3', 'Lua 5.4'}, - ['//'] = {'Lua 5.3', 'Lua 5.4'}, -} - -local SymbolAlias = { - ['||'] = 'or', - ['&&'] = 'and', - ['!='] = '~=', - ['!'] = 'not', -} - -local function checkOpVersion(op) - local versions = VersionOp[op.type] - if not versions then - return - end - for i = 1, #versions do - if versions[i] == State.version then - return - end - end - PushError { - type = 'UNSUPPORT_SYMBOL', - start = op.start, - finish = op.finish, - version = versions, - info = { - version = State.version, - } - } -end - -local function checkMissEnd(start) - if not State.MissEndErr then - return - end - local err = State.MissEndErr - State.MissEndErr = nil - local _, finish = State.lua:find('[%w_]+', start) - if not finish then - return - end - err.info.related = { - { - start = start, - finish = finish, - } - } - PushError { - type = 'MISS_END', - start = start, - finish = finish, - } -end - -local function getSelect(vararg, index) - return { - type = 'select', - start = vararg.start, - finish = vararg.finish, - vararg = vararg, - sindex = index, - } -end - -local function getValue(values, i) - if not values then - return nil, nil - end - local value = values[i] - if not value then - local last = values[#values] - if not last then - return nil, nil - end - if last.type == 'call' or last.type == 'varargs' then - return getSelect(last, i - #values + 1) - end - return nil, nil - end - if value.type == 'call' or value.type == 'varargs' then - value = getSelect(value, 1) - end - return value -end - -local function createLocal(key, effect, value, attrs) - if not key then - return nil - end - key.type = 'local' - key.effect = effect - key.value = value - key.attrs = attrs - if value then - key.range = value.finish - end - return key -end - -local function createCall(args, start, finish) - if args then - args.type = 'callargs' - args.start = start - args.finish = finish - end - return { - type = 'call', - start = start, - finish = finish, - args = args, - } -end - -local function packList(start, list, finish) - local lastFinish = start - local wantName = true - local count = 0 - for i = 1, #list do - local ast = list[i] - if ast.type == ',' then - if wantName or i == #list then - PushError { - type = 'UNEXPECT_SYMBOL', - start = ast.start, - finish = ast.finish, - info = { - symbol = ',', - } - } - end - wantName = true - else - if not wantName then - PushError { - type = 'MISS_SYMBOL', - start = lastFinish, - finish = ast.start - 1, - info = { - symbol = ',', - } - } - end - wantName = false - count = count + 1 - list[count] = list[i] - end - lastFinish = ast.finish + 1 - end - for i = count + 1, #list do - list[i] = nil - end - list.type = 'list' - list.start = start - list.finish = finish - 1 - return list -end - -local BinaryLevel = { - ['or'] = 1, - ['and'] = 2, - ['<='] = 3, - ['>='] = 3, - ['<'] = 3, - ['>'] = 3, - ['~='] = 3, - ['=='] = 3, - ['|'] = 4, - ['~'] = 5, - ['&'] = 6, - ['<<'] = 7, - ['>>'] = 7, - ['..'] = 8, - ['+'] = 9, - ['-'] = 9, - ['*'] = 10, - ['//'] = 10, - ['/'] = 10, - ['%'] = 10, - ['^'] = 11, -} - -local BinaryForward = { - [01] = true, - [02] = true, - [03] = true, - [04] = true, - [05] = true, - [06] = true, - [07] = true, - [08] = false, - [09] = true, - [10] = true, - [11] = false, -} - -local Defs = { - Nil = function (pos) - return { - type = 'nil', - start = pos, - finish = pos + 2, - } - end, - True = function (pos) - return { - type = 'boolean', - start = pos, - finish = pos + 3, - [1] = true, - } - end, - False = function (pos) - return { - type = 'boolean', - start = pos, - finish = pos + 4, - [1] = false, - } - end, - ShortComment = function (start, text, finish) - PushComment { - type = 'comment.short', - start = start, - finish = finish - 1, - text = text, - } - end, - LongComment = function (start, beforeEq, afterEq, str, close, finish) - PushComment { - type = 'comment.long', - start = start, - finish = finish - 1, - text = str, - } - if not close then - local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']' - local s, _, w = str:find('(%][%=]*%])[%c%s]*$') - if s then - PushError { - type = 'ERR_LCOMMENT_END', - start = finish - #str + s - 1, - finish = finish - #str + s + #w - 2, - info = { - symbol = endSymbol, - }, - fix = { - title = 'FIX_LCOMMENT_END', - { - start = finish - #str + s - 1, - finish = finish - #str + s + #w - 2, - text = endSymbol, - } - }, - } - end - PushError { - type = 'MISS_SYMBOL', - start = finish, - finish = finish, - info = { - symbol = endSymbol, - }, - fix = { - title = 'ADD_LCOMMENT_END', - { - start = finish, - finish = finish, - text = endSymbol, - } - }, - } - end - end, - CLongComment = function (start1, finish1, str, start2, finish2) - if State.options.nonstandardSymbol and State.options.nonstandardSymbol['/**/'] then - else - PushError { - type = 'ERR_C_LONG_COMMENT', - start = start1, - finish = finish2 - 1, - fix = { - title = 'FIX_C_LONG_COMMENT', - { - start = start1, - finish = finish1 - 1, - text = '--[[', - }, - { - start = start2, - finish = finish2 - 1, - text = '--]]' - }, - } - } - end - PushComment { - type = 'comment.clong', - start = start1, - finish = finish2 - 1, - text = str, - } - end, - CCommentPrefix = function (start, finish, commentFinish) - if State.options.nonstandardSymbol and State.options.nonstandardSymbol['//'] then - else - PushError { - type = 'ERR_COMMENT_PREFIX', - start = start, - finish = finish - 1, - fix = { - title = 'FIX_COMMENT_PREFIX', - { - start = start, - finish = finish - 1, - text = '--', - }, - } - } - end - PushComment { - type = 'comment.cshort', - start = start, - finish = commentFinish - 1, - text = '', - } - end, - String = function (start, quote, str, finish) - if quote == '`' then - if State.options.nonstandardSymbol and State.options.nonstandardSymbol['`'] then - else - PushError { - type = 'ERR_NONSTANDARD_SYMBOL', - start = start, - finish = finish - 1, - info = { - symbol = '"', - }, - fix = { - title = 'FIX_NONSTANDARD_SYMBOL', - symbol = '"', - { - start = start, - finish = start, - text = '"', - }, - { - start = finish - 1, - finish = finish - 1, - text = '"', - }, - } - } - end - end - return { - type = 'string', - start = start, - finish = finish - 1, - [1] = str, - [2] = quote, - } - end, - LongString = function (beforeEq, afterEq, str, missPos) - if missPos then - local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']' - local s, _, w = str:find('(%][%=]*%])[%c%s]*$') - if s then - PushError { - type = 'ERR_LSTRING_END', - start = missPos - #str + s - 1, - finish = missPos - #str + s + #w - 2, - info = { - symbol = endSymbol, - }, - fix = { - title = 'FIX_LSTRING_END', - { - start = missPos - #str + s - 1, - finish = missPos - #str + s + #w - 2, - text = endSymbol, - } - }, - } - end - PushError { - type = 'MISS_SYMBOL', - start = missPos, - finish = missPos, - info = { - symbol = endSymbol, - }, - fix = { - title = 'ADD_LSTRING_END', - { - start = missPos, - finish = missPos, - text = endSymbol, - } - }, - } - end - return '[' .. ('='):rep(afterEq-beforeEq) .. '[', str - end, - Char10 = function (char) - char = tonumber(char) - if not char or char < 0 or char > 255 then - return '' - end - return stringChar(char) - end, - Char16 = function (pos, char) - if State.version == 'Lua 5.1' then - PushError { - type = 'ERR_ESC', - start = pos-1, - finish = pos, - version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, - info = { - version = State.version, - } - } - return char - end - return stringChar(tonumber(char, 16)) - end, - CharUtf8 = function (pos, char) - if State.version ~= 'Lua 5.3' - and State.version ~= 'Lua 5.4' - and State.version ~= 'LuaJIT' - then - PushError { - type = 'ERR_ESC', - start = pos-3, - finish = pos-2, - version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, - info = { - version = State.version, - } - } - return char - end - if #char == 0 then - PushError { - type = 'UTF8_SMALL', - start = pos-3, - finish = pos, - } - return '' - end - local v = tonumber(char, 16) - if not v then - for i = 1, #char do - if not tonumber(char:sub(i, i), 16) then - PushError { - type = 'MUST_X16', - start = pos + i - 1, - finish = pos + i - 1, - } - end - end - return '' - end - if State.version == 'Lua 5.4' then - if v < 0 or v > 0x7FFFFFFF then - PushError { - type = 'UTF8_MAX', - start = pos-3, - finish = pos+#char, - info = { - min = '00000000', - max = '7FFFFFFF', - } - } - end - else - if v < 0 or v > 0x10FFFF then - PushError { - type = 'UTF8_MAX', - start = pos-3, - finish = pos+#char, - version = v <= 0x7FFFFFFF and 'Lua 5.4' or nil, - info = { - min = '000000', - max = '10FFFF', - } - } - end - end - if v >= 0 and v <= 0x10FFFF then - return utf8Char(v) - end - return '' - end, - Number = function (start, number, finish) - local n = tonumber(number) - if n then - State.LastNumber = { - type = mathType(n) == 'integer' and 'integer' or 'number', - start = start, - finish = finish - 1, - [1] = n, - } - State.LastRaw = number - return State.LastNumber - else - PushError { - type = 'MALFORMED_NUMBER', - start = start, - finish = finish - 1, - } - State.LastNumber = { - type = 'number', - start = start, - finish = finish - 1, - [1] = 0, - } - State.LastRaw = number - return State.LastNumber - end - end, - FFINumber = function (start, symbol) - local lastNumber = State.LastNumber - if State.LastRaw:find('.', 1, true) then - PushError { - type = 'UNKNOWN_SYMBOL', - start = start, - finish = start + #symbol - 1, - info = { - symbol = symbol, - } - } - lastNumber[1] = 0 - return - end - if State.version ~= 'LuaJIT' then - PushError { - type = 'UNSUPPORT_SYMBOL', - start = start, - finish = start + #symbol - 1, - version = 'LuaJIT', - info = { - version = State.version, - } - } - lastNumber[1] = 0 - end - end, - ImaginaryNumber = function (start, symbol) - local lastNumber = State.LastNumber - if State.version ~= 'LuaJIT' then - PushError { - type = 'UNSUPPORT_SYMBOL', - start = start, - finish = start + #symbol - 1, - version = 'LuaJIT', - info = { - version = State.version, - } - } - end - lastNumber[1] = 0 - end, - Integer2 = function (start, word) - if State.version ~= 'LuaJIT' then - PushError { - type = 'UNSUPPORT_SYMBOL', - start = start, - finish = start + 1, - version = 'LuaJIT', - info = { - version = State.version, - } - } - end - local num = 0 - for i = 1, #word do - if word:sub(i, i) == '1' then - num = num | (1 << (i - 1)) - end - end - return tostring(num) - end, - Name = function (start, str, finish) - local isKeyWord - if RESERVED[str] then - isKeyWord = true - elseif str == 'goto' then - if State.version ~= 'Lua 5.1' and State.version ~= 'LuaJIT' then - isKeyWord = true - end - end - if isKeyWord then - PushError { - type = 'KEYWORD', - start = start, - finish = finish - 1, - } - end - if not State.options.unicodeName and str:find '[\x80-\xff]' then - PushError { - type = 'UNICODE_NAME', - start = start, - finish = finish - 1, - } - end - return { - type = 'name', - start = start, - finish = finish - 1, - [1] = str, - } - end, - GetField = function (dot, field) - local obj = { - type = 'getfield', - field = field, - dot = dot, - start = dot.start, - finish = (field or dot).finish, - } - if field then - field.type = 'field' - field.parent = obj - end - return obj - end, - GetIndex = function (start, index, finish) - local obj = { - type = 'getindex', - bstart = start, - start = start, - finish = finish - 1, - index = index, - } - if index then - index.parent = obj - end - return obj - end, - GetMethod = function (colon, method) - local obj = { - type = 'getmethod', - method = method, - colon = colon, - start = colon.start, - finish = (method or colon).finish, - } - if method then - method.type = 'method' - method.parent = obj - end - return obj - end, - Single = function (unit) - unit.type = 'getname' - return unit - end, - Simple = function (units) - local last = units[1] - for i = 2, #units do - local current = units[i] - current.node = last - current.start = last.start - last.next = current - last = units[i] - end - return last - end, - SimpleCall = function (call) - if call.type ~= 'call' and call.type ~= 'getmethod' then - PushError { - type = 'EXP_IN_ACTION', - start = call.start, - finish = call.finish, - } - end - return call - end, - BinaryOp = function (start, op) - if SymbolAlias[op] then - if State.options.nonstandardSymbol and State.options.nonstandardSymbol[op] then - else - PushError { - type = 'ERR_NONSTANDARD_SYMBOL', - start = start, - finish = start + #op - 1, - info = { - symbol = SymbolAlias[op], - }, - fix = { - title = 'FIX_NONSTANDARD_SYMBOL', - symbol = SymbolAlias[op], - { - start = start, - finish = start + #op - 1, - text = SymbolAlias[op], - }, - } - } - end - op = SymbolAlias[op] - end - return { - type = op, - start = start, - finish = start + #op - 1, - } - end, - UnaryOp = function (start, op) - if SymbolAlias[op] then - if State.options.nonstandardSymbol and State.options.nonstandardSymbol[op] then - else - PushError { - type = 'ERR_NONSTANDARD_SYMBOL', - start = start, - finish = start + #op - 1, - info = { - symbol = SymbolAlias[op], - }, - fix = { - title = 'FIX_NONSTANDARD_SYMBOL', - symbol = SymbolAlias[op], - { - start = start, - finish = start + #op - 1, - text = SymbolAlias[op], - }, - } - } - end - op = SymbolAlias[op] - end - return { - type = op, - start = start, - finish = start + #op - 1, - } - end, - Unary = function (first, ...) - if not ... then - return nil - end - local list = {first, ...} - local e = list[#list] - for i = #list - 1, 1, -1 do - local op = list[i] - checkOpVersion(op) - e = { - type = 'unary', - op = op, - start = op.start, - finish = e.finish, - [1] = e, - } - end - return e - end, - SubBinary = function (op, symb) - if symb then - return op, symb - end - PushError { - type = 'MISS_EXP', - start = op.start, - finish = op.finish, - } - end, - Binary = function (first, op, second, ...) - if not first then - return second - end - if not op then - return first - end - if not ... then - checkOpVersion(op) - return { - type = 'binary', - op = op, - start = first.start, - finish = second.finish, - [1] = first, - [2] = second, - } - end - local list = {first, op, second, ...} - local ops = {} - for i = 2, #list, 2 do - ops[#ops+1] = i - end - tableSort(ops, function (a, b) - local op1 = list[a] - local op2 = list[b] - local lv1 = BinaryLevel[op1.type] - local lv2 = BinaryLevel[op2.type] - if lv1 == lv2 then - local forward = BinaryForward[lv1] - if forward then - return op1.start > op2.start - else - return op1.start < op2.start - end - else - return lv1 < lv2 - end - end) - local final - for i = #ops, 1, -1 do - local n = ops[i] - local op = list[n] - local left = list[n-1] - local right = list[n+1] - local exp = { - type = 'binary', - op = op, - start = left.start, - finish = right and right.finish or op.finish, - [1] = left, - [2] = right, - } - local leftIndex, rightIndex - if list[left] then - leftIndex = list[left[1]] - else - leftIndex = n - 1 - end - if list[right] then - rightIndex = list[right[2]] - else - rightIndex = n + 1 - end - - list[leftIndex] = exp - list[rightIndex] = exp - list[left] = leftIndex - list[right] = rightIndex - list[exp] = n - final = exp - - checkOpVersion(op) - end - return final - end, - Paren = function (start, exp, finish) - if exp and exp.type == 'paren' then - exp.start = start - exp.finish = finish - 1 - return exp - end - return { - type = 'paren', - start = start, - finish = finish - 1, - exp = exp - } - end, - VarArgs = function (dots) - dots.type = 'varargs' - return dots - end, - PackLoopArgs = function (start, list, finish) - local list = packList(start, list, finish) - if #list == 0 then - PushError { - type = 'MISS_LOOP_MIN', - start = finish, - finish = finish, - } - elseif #list == 1 then - PushError { - type = 'MISS_LOOP_MAX', - start = finish, - finish = finish, - } - end - return list - end, - PackInNameList = function (start, list, finish) - local list = packList(start, list, finish) - if #list == 0 then - PushError { - type = 'MISS_NAME', - start = start, - finish = finish, - } - end - return list - end, - PackInExpList = function (start, list, finish) - local list = packList(start, list, finish) - if #list == 0 then - PushError { - type = 'MISS_EXP', - start = start, - finish = finish, - } - end - return list - end, - PackExpList = function (start, list, finish) - local list = packList(start, list, finish) - return list - end, - PackNameList = function (start, list, finish) - local list = packList(start, list, finish) - return list - end, - Call = function (start, args, finish) - return createCall(args, start, finish-1) - end, - COMMA = function (start) - return { - type = ',', - start = start, - finish = start, - } - end, - SEMICOLON = function (start) - return { - type = ';', - start = start, - finish = start, - } - end, - DOTS = function (start) - return { - type = '...', - start = start, - finish = start + 2, - } - end, - COLON = function (start) - return { - type = ':', - start = start, - finish = start, - } - end, - ASSIGN = function (start, symbol) - if State.options.nonstandardSymbol and State.options.nonstandardSymbol[symbol] then - else - PushError { - type = 'UNSUPPORT_SYMBOL', - start = start, - finish = start + #symbol - 1, - info = { - version = 'Lua', - } - } - end - end, - DOT = function (start) - return { - type = '.', - start = start, - finish = start, - } - end, - Function = function (functionStart, functionFinish, name, args, actions, endStart, endFinish) - actions.type = 'function' - actions.start = functionStart - actions.finish = endFinish - 1 - actions.args = args - actions.keyword= { - functionStart, functionFinish - 1, - endStart, endFinish - 1, - } - checkMissEnd(functionStart) - if not name then - return actions - end - if name.type == 'getname' then - name.type = 'setname' - name.value = actions - elseif name.type == 'getfield' then - name.type = 'setfield' - name.value = actions - elseif name.type == 'getmethod' then - name.type = 'setmethod' - name.value = actions - elseif name.type == 'getindex' then - name.type = 'setfield' - name.value = actions - PushError { - type = 'INDEX_IN_FUNC_NAME', - start = name.bstart, - finish = name.finish, - } - end - name.range = actions.finish - name.vstart = functionStart - return name - end, - LocalFunction = function (start, name) - if name.type == 'function' then - PushError { - type = 'MISS_NAME', - start = name.keyword[2] + 1, - finish = name.keyword[2] + 1, - } - return name - end - if name.type ~= 'setname' then - PushError { - type = 'UNEXPECT_LFUNC_NAME', - start = name.start, - finish = name.finish, - } - return name - end - - local loc = createLocal(name, name.start, name.value) - loc.localfunction = true - loc.vstart = name.value.start - return name - end, - NamedFunction = function (name) - if name.type == 'function' then - PushError { - type = 'MISS_NAME', - start = name.keyword[2] + 1, - finish = name.keyword[2] + 1, - } - end - return name - end, - ExpFunction = function (func) - if func.type ~= 'function' then - PushError { - type = 'UNEXPECT_EFUNC_NAME', - start = func.start, - finish = func.finish, - } - return func.value - end - return func - end, - Table = function (start, tbl, finish) - tbl.type = 'table' - tbl.start = start - tbl.finish = finish - 1 - local wantField = true - local lastStart = start + 1 - local fieldCount = 0 - local n = 0 - for i = 1, #tbl do - local field = tbl[i] - if field.type == ',' or field.type == ';' then - if wantField then - PushError { - type = 'MISS_EXP', - start = lastStart, - finish = field.start - 1, - } - end - wantField = true - lastStart = field.finish + 1 - else - if not wantField then - PushError { - type = 'MISS_SEP_IN_TABLE', - start = lastStart, - finish = field.start - 1, - } - end - wantField = false - lastStart = field.finish + 1 - fieldCount = fieldCount + 1 - tbl[fieldCount] = field - if field.type == 'tableexp' then - n = n + 1 - field.tindex = n - end - end - end - for i = fieldCount + 1, #tbl do - tbl[i] = nil - end - return tbl - end, - NewField = function (start, field, value, finish) - local obj = { - type = 'tablefield', - start = start, - finish = finish-1, - field = field, - value = value, - } - if field then - field.type = 'field' - field.parent = obj - end - return obj - end, - NewIndex = function (start, index, value, finish) - local obj = { - type = 'tableindex', - start = start, - finish = finish-1, - index = index, - value = value, - } - if index then - index.parent = obj - end - return obj - end, - TableExp = function (start, value, finish) - if not value then - return - end - local obj = { - type = 'tableexp', - start = start, - finish = finish-1, - value = value, - } - return obj - end, - FuncArgs = function (start, args, finish) - args.type = 'funcargs' - args.start = start - args.finish = finish - 1 - local lastStart = start + 1 - local wantName = true - local argCount = 0 - for i = 1, #args do - local arg = args[i] - local argAst = arg - if argAst.type == ',' then - if wantName then - PushError { - type = 'MISS_NAME', - start = lastStart, - finish = argAst.start-1, - } - end - wantName = true - else - if not wantName then - PushError { - type = 'MISS_SYMBOL', - start = lastStart-1, - finish = argAst.start-1, - info = { - symbol = ',', - } - } - end - wantName = false - argCount = argCount + 1 - - if argAst.type == '...' then - args[argCount] = arg - if i < #args then - local a = args[i+1] - local b = args[#args] - PushError { - type = 'ARGS_AFTER_DOTS', - start = a.start, - finish = b.finish, - } - end - break - else - args[argCount] = createLocal(arg, arg.start) - end - end - lastStart = argAst.finish + 1 - end - for i = argCount + 1, #args do - args[i] = nil - end - if wantName and argCount > 0 then - PushError { - type = 'MISS_NAME', - start = lastStart, - finish = finish - 1, - } - end - return args - end, - Set = function (start, keys, eqFinish, values, finish) - for i = 1, #keys do - local key = keys[i] - if key.type == 'getname' then - key.type = 'setname' - key.value = getValue(values, i) - elseif key.type == 'getfield' then - key.type = 'setfield' - key.value = getValue(values, i) - elseif key.type == 'getindex' then - key.type = 'setindex' - key.value = getValue(values, i) - else - PushError { - type = 'UNEXPECT_SYMBOL', - start = eqFinish - 1, - finish = eqFinish - 1, - info = { - symbol = '=', - } - } - end - if key.value then - key.range = key.value.finish - end - end - if values then - for i = #keys+1, #values do - local value = values[i] - PushDiag('redundant-value', { - start = value.start, - finish = value.finish, - max = #keys, - passed = #values, - }) - end - end - return tableUnpack(keys) - end, - LocalAttr = function (attrs) - if #attrs == 0 then - return nil - end - for i = 1, #attrs do - local attr = attrs[i] - local attrAst = attr - attrAst.type = 'localattr' - if State.version ~= 'Lua 5.4' then - PushError { - type = 'UNSUPPORT_SYMBOL', - start = attrAst.start, - finish = attrAst.finish, - version = 'Lua 5.4', - info = { - version = State.version, - } - } - elseif attrAst[1] ~= 'const' and attrAst[1] ~= 'close' then - PushError { - type = 'UNKNOWN_TAG', - start = attrAst.start, - finish = attrAst.finish, - info = { - tag = attrAst[1], - } - } - elseif i > 1 then - PushError { - type = 'MULTI_TAG', - start = attrAst.start, - finish = attrAst.finish, - info = { - tag = attrAst[1], - } - } - end - end - attrs.start = attrs[1].start - attrs.finish = attrs[#attrs].finish - return attrs - end, - LocalName = function (name, attrs) - if not name then - return - end - name.attrs = attrs - return name - end, - Local = function (start, keys, values, finish) - for i = 1, #keys do - local key = keys[i] - local attrs = key.attrs - key.attrs = nil - local value = getValue(values, i) - createLocal(key, finish, value, attrs) - end - if values then - for i = #keys+1, #values do - local value = values[i] - PushDiag('redundant-value', { - start = value.start, - finish = value.finish, - max = #keys, - passed = #values, - }) - end - end - return tableUnpack(keys) - end, - Do = function (start, actions, endA, endB) - actions.type = 'do' - actions.start = start - actions.finish = endB - 1 - actions.keyword= { - start, start + #'do' - 1, - endA , endB - 1, - } - checkMissEnd(start) - return actions - end, - Break = function (start, finish) - return { - type = 'break', - start = start, - finish = finish - 1, - } - end, - Return = function (start, exps, finish) - exps.type = 'return' - exps.start = start - exps.finish = finish - 1 - return exps - end, - Label = function (start, name, finish) - if State.version == 'Lua 5.1' then - PushError { - type = 'UNSUPPORT_SYMBOL', - start = start, - finish = finish - 1, - version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, - info = { - version = State.version, - } - } - return - end - if not name then - return - end - name.type = 'label' - return name - end, - GoTo = function (start, name, finish) - if State.version == 'Lua 5.1' then - PushError { - type = 'UNSUPPORT_SYMBOL', - start = start, - finish = finish - 1, - version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, - info = { - version = State.version, - } - } - return - end - if not name then - return - end - name.type = 'goto' - return name - end, - IfBlock = function (ifStart, ifFinish, exp, thenStart, thenFinish, actions, finish) - actions.type = 'ifblock' - actions.start = ifStart - actions.finish = finish - 1 - actions.filter = exp - actions.keyword= { - ifStart, ifFinish - 1, - thenStart, thenFinish - 1, - } - return actions - end, - ElseIfBlock = function (elseifStart, elseifFinish, exp, thenStart, thenFinish, actions, finish) - actions.type = 'elseifblock' - actions.start = elseifStart - actions.finish = finish - 1 - actions.filter = exp - actions.keyword= { - elseifStart, elseifFinish - 1, - thenStart, thenFinish - 1, - } - return actions - end, - ElseBlock = function (elseStart, elseFinish, actions, finish) - actions.type = 'elseblock' - actions.start = elseStart - actions.finish = finish - 1 - actions.keyword= { - elseStart, elseFinish - 1, - } - return actions - end, - If = function (start, blocks, endStart, endFinish) - blocks.type = 'if' - blocks.start = start - blocks.finish = endFinish - 1 - local hasElse - for i = 1, #blocks do - local block = blocks[i] - if i == 1 and block.type ~= 'ifblock' then - PushError { - type = 'MISS_SYMBOL', - start = block.start, - finish = block.start, - info = { - symbol = 'if', - } - } - end - if hasElse then - PushError { - type = 'BLOCK_AFTER_ELSE', - start = block.start, - finish = block.finish, - } - end - if block.type == 'elseblock' then - hasElse = true - end - end - checkMissEnd(start) - return blocks - end, - Loop = function (forA, forB, arg, steps, doA, doB, blockStart, block, endA, endB) - local loc = createLocal(arg, blockStart, steps[1]) - block.type = 'loop' - block.start = forA - block.finish = endB - 1 - block.loc = loc - block.max = steps[2] - block.step = steps[3] - block.keyword= { - forA, forB - 1, - doA , doB - 1, - endA, endB - 1, - } - checkMissEnd(forA) - return block - end, - In = function (forA, forB, keys, inA, inB, exp, doA, doB, blockStart, block, endA, endB) - local func = tableRemove(exp, 1) - block.type = 'in' - block.start = forA - block.finish = endB - 1 - block.keys = keys - block.keyword= { - forA, forB - 1, - inA , inB - 1, - doA , doB - 1, - endA, endB - 1, - } - - local values - if func then - local call = createCall(exp, func.finish + 1, exp.finish) - if #exp == 0 then - exp[1] = getSelect(func, 2) - exp[2] = getSelect(func, 3) - exp[3] = getSelect(func, 4) - end - call.node = func - call.start = inA - call.finish = doB - 1 - func.next = call - func.iterator = true - values = { call } - keys.range = call.finish - end - for i = 1, #keys do - local loc = keys[i] - if values then - createLocal(loc, blockStart, getValue(values, i)) - else - createLocal(loc, blockStart) - end - end - checkMissEnd(forA) - return block - end, - While = function (whileA, whileB, filter, doA, doB, block, endA, endB) - block.type = 'while' - block.start = whileA - block.finish = endB - 1 - block.filter = filter - block.keyword= { - whileA, whileB - 1, - doA , doB - 1, - endA , endB - 1, - } - checkMissEnd(whileA) - return block - end, - Repeat = function (repeatA, repeatB, block, untilA, untilB, filter, finish) - block.type = 'repeat' - block.start = repeatA - block.finish = finish - block.filter = filter - block.keyword= { - repeatA, repeatB - 1, - untilA , untilB - 1, - } - return block - end, - RTContinue = function (_, pos, ...) - if State.options.nonstandardSymbol and State.options.nonstandardSymbol['continue'] then - return pos, ... - else - return false - end - end, - Continue = function (start, finish) - return { - type = 'nonstandardSymbol.continue', - start = start, - finish = finish - 1, - } - end, - Lua = function (start, actions, finish) - actions.type = 'main' - actions.start = start - actions.finish = finish - 1 - return actions - end, - - -- 捕获错误 - UnknownSymbol = function (start, symbol) - PushError { - type = 'UNKNOWN_SYMBOL', - start = start, - finish = start + #symbol - 1, - info = { - symbol = symbol, - } - } - end, - UnknownAction = function (start, symbol) - PushError { - type = 'UNKNOWN_SYMBOL', - start = start, - finish = start + #symbol - 1, - info = { - symbol = symbol, - } - } - end, - DirtyName = function (pos) - PushError { - type = 'MISS_NAME', - start = pos, - finish = pos, - } - return nil - end, - DirtyExp = function (pos) - PushError { - type = 'MISS_EXP', - start = pos, - finish = pos, - } - return nil - end, - MissExp = function (pos) - PushError { - type = 'MISS_EXP', - start = pos, - finish = pos, - } - end, - MissExponent = function (start, finish) - PushError { - type = 'MISS_EXPONENT', - start = start, - finish = finish - 1, - } - end, - MissQuote1 = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = '"' - } - } - end, - MissQuote2 = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = "'" - } - } - end, - MissQuote3 = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = "`" - } - } - end, - MissEscX = function (pos) - PushError { - type = 'MISS_ESC_X', - start = pos-2, - finish = pos+1, - } - end, - MissTL = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = '{', - } - } - end, - MissTR = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = '}', - } - } - end, - MissBR = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = ']', - } - } - end, - MissPL = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = '(', - } - } - end, - MissPR = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = ')', - } - } - end, - ErrEsc = function (pos) - PushError { - type = 'ERR_ESC', - start = pos-1, - finish = pos, - } - end, - MustX16 = function (pos, str) - PushError { - type = 'MUST_X16', - start = pos, - finish = pos + #str - 1, - } - end, - MissAssign = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = '=', - } - } - end, - MissTableSep = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = ',' - } - } - end, - MissField = function (pos) - PushError { - type = 'MISS_FIELD', - start = pos, - finish = pos, - } - end, - MissMethod = function (pos) - PushError { - type = 'MISS_METHOD', - start = pos, - finish = pos, - } - end, - MissLabel = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = '::', - } - } - end, - MissEnd = function (pos) - State.MissEndErr = PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = 'end', - } - } - return pos, pos - end, - MissDo = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = 'do', - } - } - return pos, pos - end, - MissComma = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = ',', - } - } - end, - MissIn = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = 'in', - } - } - return pos, pos - end, - MissUntil = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = 'until', - } - } - return pos, pos - end, - MissThen = function (pos) - PushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = 'then', - } - } - return pos, pos - end, - MissName = function (pos) - PushError { - type = 'MISS_NAME', - start = pos, - finish = pos, - } - end, - ExpInAction = function (start, exp, finish) - PushError { - type = 'EXP_IN_ACTION', - start = start, - finish = finish - 1, - } - -- 当exp为nil时,不能返回任何值,否则会产生带洞的actionlist - if exp then - return exp - else - return - end - end, - MissIf = function (start, block) - PushError { - type = 'MISS_SYMBOL', - start = start, - finish = start, - info = { - symbol = 'if', - } - } - return block - end, - MissGT = function (start) - PushError { - type = 'MISS_SYMBOL', - start = start, - finish = start, - info = { - symbol = '>' - } - } - end, - ErrAssign = function (start, finish) - PushError { - type = 'ERR_ASSIGN_AS_EQ', - start = start, - finish = finish - 1, - fix = { - title = 'FIX_ASSIGN_AS_EQ', - { - start = start, - finish = finish - 1, - text = '=', - } - } - } - end, - ErrEQ = function (start, finish) - PushError { - type = 'ERR_EQ_AS_ASSIGN', - start = start, - finish = finish - 1, - fix = { - title = 'FIX_EQ_AS_ASSIGN', - { - start = start, - finish = finish - 1, - text = '==', - } - } - } - return '==' - end, - ErrUEQ = function (start, finish) - PushError { - type = 'ERR_UEQ', - start = start, - finish = finish - 1, - fix = { - title = 'FIX_UEQ', - { - start = start, - finish = finish - 1, - text = '~=', - } - } - } - return '==' - end, - ErrThen = function (start, finish) - PushError { - type = 'ERR_THEN_AS_DO', - start = start, - finish = finish - 1, - fix = { - title = 'FIX_THEN_AS_DO', - { - start = start, - finish = finish - 1, - text = 'then', - } - } - } - return start, finish - end, - ErrDo = function (start, finish) - PushError { - type = 'ERR_DO_AS_THEN', - start = start, - finish = finish - 1, - fix = { - title = 'FIX_DO_AS_THEN', - { - start = start, - finish = finish - 1, - text = 'do', - } - } - } - return start, finish - end, - MissSpaceBetween = function (start) - PushError { - type = 'MISS_SPACE_BETWEEN', - start = start, - finish = start + 1, - fix = { - title = 'FIX_INSERT_SPACE', - { - start = start + 1, - finish = start, - text = ' ', - } - } - } - end, - CallArgSnip = function (name, tailStart, tailSymbol) - PushError { - type = 'UNEXPECT_SYMBOL', - start = tailStart, - finish = tailStart, - info = { - symbol = tailSymbol, - } - } - return name - end -} - -local function init(state) - State = state - PushError = state.pushError - PushDiag = state.pushDiag - PushComment = state.pushComment -end - -local function close() - State = DefaultState - PushError = function (...) end - PushDiag = function (...) end - PushComment = function (...) end -end - -return { - defs = Defs, - init = init, - close = close, -} diff --git a/script/parser/calcline.lua b/script/parser/calcline.lua deleted file mode 100644 index 2e944167..00000000 --- a/script/parser/calcline.lua +++ /dev/null @@ -1,94 +0,0 @@ -local m = require 'lpeglabel' -local util = require 'utility' - -local row -local fl -local NL = (m.P'\r\n' + m.S'\r\n') * m.Cp() / function (pos) - row = row + 1 - fl = pos -end -local ROWCOL = (NL + m.P(1))^0 -local function rowcol(str, n) - row = 1 - fl = 1 - ROWCOL:match(str:sub(1, n)) - local col = n - fl + 1 - return row, col -end - -local function rowcol_utf8(str, n) - row = 1 - fl = 1 - ROWCOL:match(str:sub(1, n)) - return row, util.utf8Len(str, fl, n) -end - -local function position(str, _row, _col) - local cur = 1 - local row = 1 - while true do - if row == _row then - return cur + _col - 1 - elseif row > _row then - return cur - 1 - end - local pos = str:find('[\r\n]', cur) - if not pos then - return #str - end - row = row + 1 - if str:sub(pos, pos+1) == '\r\n' then - cur = pos + 2 - else - cur = pos + 1 - end - end -end - -local function position_utf8(str, _row, _col) - local cur = 1 - local row = 1 - while true do - if row == _row then - return utf8.offset(str, _col, cur) - elseif row > _row then - return cur - 1 - end - local pos = str:find('[\r\n]', cur) - if not pos then - return #str - end - row = row + 1 - if str:sub(pos, pos+1) == '\r\n' then - cur = pos + 2 - else - cur = pos + 1 - end - end -end - -local NL = m.P'\r\n' + m.S'\r\n' - -local function line(str, row) - local count = 0 - local res - local LINE = m.Cmt((1 - NL)^0, function (_, _, c) - count = count + 1 - if count == row then - res = c - return false - end - return true - end) - local MATCH = (LINE * NL)^0 * LINE - MATCH:match(str) - return res -end - -return { - rowcol = rowcol, - rowcol_utf8 = rowcol_utf8, - position = position, - position_utf8 = position_utf8, - line = line, -} diff --git a/script/parser/compile.lua b/script/parser/compile.lua index 752728d1..915a2764 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -1,12 +1,109 @@ -local guide = require 'parser.guide' -local parse = require 'parser.parse' -local newparser = require 'parser.newparser' -local type = type -local tableInsert = table.insert -local pairs = pairs -local os = os - -local specials = { +local tokens = require 'parser.tokens' +local guide = require 'parser.guide' + +local sbyte = string.byte +local sfind = string.find +local smatch = string.match +local sgsub = string.gsub +local ssub = string.sub +local schar = string.char +local supper = string.upper +local uchar = utf8.char +local tconcat = table.concat +local tinsert = table.insert +local tointeger = math.tointeger +local tonumber = tonumber +local maxinteger = math.maxinteger +local assert = assert + +_ENV = nil + +---@alias parser.position integer + +---@param str string +---@return table<integer, boolean> +local function stringToCharMap(str) + local map = {} + local pos = 1 + while pos <= #str do + local byte = sbyte(str, pos, pos) + map[schar(byte)] = true + pos = pos + 1 + if ssub(str, pos, pos) == '-' + and pos < #str then + pos = pos + 1 + local byte2 = sbyte(str, pos, pos) + assert(byte < byte2) + for b = byte + 1, byte2 do + map[schar(b)] = true + end + pos = pos + 1 + end + end + return map +end + +local CharMapNumber = stringToCharMap '0-9' +local CharMapN16 = stringToCharMap 'xX' +local CharMapN2 = stringToCharMap 'bB' +local CharMapE10 = stringToCharMap 'eE' +local CharMapE16 = stringToCharMap 'pP' +local CharMapSign = stringToCharMap '+-' +local CharMapSB = stringToCharMap 'ao|~&=<>.*/%^+-' +local CharMapSU = stringToCharMap 'n#~!-' +local CharMapSimple = stringToCharMap '.:([\'"{' +local CharMapStrSH = stringToCharMap '\'"`' +local CharMapStrLH = stringToCharMap '[' +local CharMapTSep = stringToCharMap ',;' +local CharMapWord = stringToCharMap '_a-zA-Z\x80-\xff' + +local EscMap = { + ['a'] = '\a', + ['b'] = '\b', + ['f'] = '\f', + ['n'] = '\n', + ['r'] = '\r', + ['t'] = '\t', + ['v'] = '\v', + ['\\'] = '\\', + ['\''] = '\'', + ['\"'] = '\"', +} + +local NLMap = { + ['\n'] = true, + ['\r'] = true, + ['\r\n'] = true, +} + +local LineMulti = 10000 + +-- goto 单独处理 +local KeyWord = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +local Specials = { ['_G'] = true, ['rawset'] = true, ['rawget'] = true, @@ -18,491 +115,622 @@ local specials = { ['xpcall'] = true, ['pairs'] = true, ['ipairs'] = true, + ['assert'] = true, + ['error'] = true, + ['type'] = true, } -_ENV = nil +local UnarySymbol = { + ['not'] = 11, + ['#'] = 11, + ['~'] = 11, + ['-'] = 11, +} + +local BinarySymbol = { + ['or'] = 1, + ['and'] = 2, + ['<='] = 3, + ['>='] = 3, + ['<'] = 3, + ['>'] = 3, + ['~='] = 3, + ['=='] = 3, + ['|'] = 4, + ['~'] = 5, + ['&'] = 6, + ['<<'] = 7, + ['>>'] = 7, + ['..'] = 8, + ['+'] = 9, + ['-'] = 9, + ['*'] = 10, + ['//'] = 10, + ['/'] = 10, + ['%'] = 10, + ['^'] = 12, +} + +local BinaryAlias = { + ['&&'] = 'and', + ['||'] = 'or', + ['!='] = '~=', +} + +local BinaryActionAlias = { + ['='] = '==', +} + +local UnaryAlias = { + ['!'] = 'not', +} + +local SymbolForward = { + [01] = true, + [02] = true, + [03] = true, + [04] = true, + [05] = true, + [06] = true, + [07] = true, + [08] = false, + [09] = true, + [10] = true, + [11] = true, + [12] = false, +} + +local GetToSetMap = { + ['getglobal'] = 'setglobal', + ['getlocal'] = 'setlocal', + ['getfield'] = 'setfield', + ['getindex'] = 'setindex', + ['getmethod'] = 'setmethod', +} + +local ChunkFinishMap = { + ['end'] = true, + ['else'] = true, + ['elseif'] = true, + ['in'] = true, + ['then'] = true, + ['until'] = true, + [';'] = true, + [']'] = true, + [')'] = true, + ['}'] = true, +} + +local ChunkStartMap = { + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['for'] = true, + ['function'] = true, + ['if'] = true, + ['local'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['until'] = true, + ['while'] = true, +} + +local ListFinishMap = { + ['end'] = true, + ['else'] = true, + ['elseif'] = true, + ['in'] = true, + ['then'] = true, + ['do'] = true, + ['until'] = true, + ['for'] = true, + ['if'] = true, + ['local'] = true, + ['repeat'] = true, + ['return'] = true, + ['while'] = true, +} + +local State, Lua, Line, LineOffset, Chunk, Tokens, Index, LastTokenFinish, Mode, LocalCount local LocalLimit = 200 -local pushError, Compile, CompileBlock, Block, GoToTag, ENVMode, Compiled, LocalCount, Version, Root, Options -local function addRef(node, obj) - if not node.ref then - node.ref = {} - end - node.ref[#node.ref+1] = obj - obj.node = node -end +local parseExp, parseAction + +local pushError local function addSpecial(name, obj) - if not Root.specials then - Root.specials = {} + if not State.specials then + State.specials = {} end - if not Root.specials[name] then - Root.specials[name] = {} + if not State.specials[name] then + State.specials[name] = {} end - Root.specials[name][#Root.specials[name]+1] = obj + State.specials[name][#State.specials[name]+1] = obj obj.special = name end -local vmMap = { - ['getname'] = function (obj) - local loc = guide.getLocal(obj, obj[1], obj.start) - if loc then - obj.type = 'getlocal' - obj.loc = loc - addRef(loc, obj) - if loc.special then - addSpecial(loc.special, obj) - end - else - obj.type = 'getglobal' - local node = guide.getLocal(obj, ENVMode, obj.start) - if node then - addRef(node, obj) - end - local name = obj[1] - if specials[name] then - addSpecial(name, obj) - elseif Options and Options.special then - local asName = Options.special[name] - if specials[asName] then - addSpecial(asName, obj) - end - end +---@param offset integer +---@param leftOrRight '"left"'|'"right"' +local function getPosition(offset, leftOrRight) + if not offset or offset > #Lua then + return LineMulti * Line + #Lua - LineOffset + 1 + end + if leftOrRight == 'left' then + return LineMulti * Line + offset - LineOffset + else + return LineMulti * Line + offset - LineOffset + 1 + end +end + +---@return string? word +---@return parser.position? startPosition +---@return parser.position? finishPosition +local function peekWord() + local word = Tokens[Index + 1] + if not word then + return nil + end + if not CharMapWord[ssub(word, 1, 1)] then + return nil + end + local startPos = getPosition(Tokens[Index] , 'left') + local finishPos = getPosition(Tokens[Index] + #word - 1, 'right') + return word, startPos, finishPos +end + +local function lastRightPosition() + if Index < 2 then + return 0 + end + local token = Tokens[Index - 1] + if NLMap[token] then + return LastTokenFinish + elseif token then + return getPosition(Tokens[Index - 2] + #token - 1, 'right') + else + return getPosition(#Lua, 'right') + end +end + +local function missSymbol(symbol, start, finish) + pushError { + type = 'MISS_SYMBOL', + start = start or lastRightPosition(), + finish = finish or start or lastRightPosition(), + info = { + symbol = symbol, + } + } +end + +local function missExp() + pushError { + type = 'MISS_EXP', + start = lastRightPosition(), + finish = lastRightPosition(), + } +end + +local function missName(pos) + pushError { + type = 'MISS_NAME', + start = pos or lastRightPosition(), + finish = pos or lastRightPosition(), + } +end + +local function missEnd(relatedStart, relatedFinish) + pushError { + type = 'MISS_SYMBOL', + start = lastRightPosition(), + finish = lastRightPosition(), + info = { + symbol = 'end', + related = { + { + start = relatedStart, + finish = relatedFinish, + } + } + } + } + pushError { + type = 'MISS_END', + start = relatedStart, + finish = relatedFinish, + } +end + +local function unknownSymbol(start, finish, word) + local token = word or Tokens[Index + 1] + if not token then + return false + end + pushError { + type = 'UNKNOWN_SYMBOL', + start = start or getPosition(Tokens[Index], 'left'), + finish = finish or getPosition(Tokens[Index] + #token - 1, 'right'), + info = { + symbol = token, + } + } + return true +end + +local function skipUnknownSymbol(stopSymbol) + if unknownSymbol() then + Index = Index + 2 + return true + end + return false +end + +local function skipNL() + local token = Tokens[Index + 1] + if NLMap[token] then + if Index >= 2 and not NLMap[Tokens[Index - 1]] then + LastTokenFinish = getPosition(Tokens[Index - 2] + #Tokens[Index - 1] - 1, 'right') end - return obj - end, - ['getfield'] = function (obj) - Compile(obj.node, obj) - end, - ['call'] = function (obj) - Compile(obj.node, obj) - if obj.node and obj.node.type == 'getmethod' then - if not obj.args then - obj.args = { - type = 'callargs', - start = obj.start, - finish = obj.finish, - parent = obj, - } - end - local newNode = {} - for k, v in pairs(obj.node.node) do - newNode[k] = v - end - newNode.mirror = obj.node.node - newNode.dummy = true - newNode.parent = obj.args - obj.node.node.mirror = newNode - tableInsert(obj.args, 1, newNode) - Compiled[newNode] = true - end - Compile(obj.args, obj) - end, - ['callargs'] = function (obj) - for i = 1, #obj do - Compile(obj[i], obj) - end - end, - ['binary'] = function (obj) - Compile(obj[1], obj) - Compile(obj[2], obj) - end, - ['unary'] = function (obj) - Compile(obj[1], obj) - end, - ['varargs'] = function (obj) - local func = guide.getParentFunction(obj) - if func then - local index, vararg = guide.getFunctionVarArgs(func) - if not index then - pushError { - type = 'UNEXPECT_DOTS', - start = obj.start, - finish = obj.finish, + Line = Line + 1 + LineOffset = Tokens[Index] + #token + Index = Index + 2 + State.lines[Line] = LineOffset + return true + end + return false +end + +local function getSavePoint() + local index = Index + local line = Line + local lineOffset = LineOffset + local errs = State.errs + local errCount = #errs + return function () + Index = index + Line = line + LineOffset = lineOffset + for i = errCount + 1, #errs do + errs[i] = nil + end + end +end + +local function fastForwardToken(offset) + while true do + local myOffset = Tokens[Index] + if not myOffset + or myOffset >= offset then + break + end + local token = Tokens[Index + 1] + if NLMap[token] then + Line = Line + 1 + LineOffset = Tokens[Index] + #token + State.lines[Line] = LineOffset + end + Index = Index + 2 + end +end + +local function resolveLongString(finishMark) + skipNL() + local miss + local start = Tokens[Index] + local finishOffset = sfind(Lua, finishMark, start, true) + if not finishOffset then + finishOffset = #Lua + 1 + miss = true + end + local stringResult = start and ssub(Lua, start, finishOffset - 1) or '' + local lastLN = stringResult:find '[\r\n][^\r\n]*$' + if lastLN then + local result = stringResult + : gsub('\r\n?', '\n') + stringResult = result + end + fastForwardToken(finishOffset + #finishMark) + if miss then + local pos = getPosition(finishOffset - 1, 'right') + pushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = finishMark, + }, + fix = { + title = 'ADD_LSTRING_END', + { + start = pos, + finish = pos, + text = finishMark, } - end - if vararg then - if not vararg.ref then - vararg.ref = {} - end - vararg.ref[#vararg.ref+1] = obj - obj.node = vararg - end - end - end, - ['paren'] = function (obj) - Compile(obj.exp, obj) - end, - ['getindex'] = function (obj) - Compile(obj.node, obj) - Compile(obj.index, obj) - end, - ['setindex'] = function (obj) - Compile(obj.node, obj) - Compile(obj.index, obj) - Compile(obj.value, obj) - end, - ['getmethod'] = function (obj) - Compile(obj.node, obj) - Compile(obj.method, obj) - end, - ['setmethod'] = function (obj) - Compile(obj.node, obj) - Compile(obj.method, obj) - local value = obj.value - local localself = { - type = 'local', - start = value.start, - finish = value.start, - method = obj, - effect = obj.finish, - tag = 'self', - dummy = true, - [1] = 'self', + }, } - if not value.args then - value.args = { - type = 'funcargs', - start = obj.start, - finish = obj.finish, - } + end + return stringResult, getPosition(finishOffset + #finishMark - 1, 'right') +end + +local function parseLongString() + local start, finish, mark = sfind(Lua, '^(%[%=*%[)', Tokens[Index]) + if not mark then + return nil + end + fastForwardToken(finish + 1) + local startPos = getPosition(start, 'left') + local finishMark = sgsub(mark, '%[', ']') + local stringResult, finishPos = resolveLongString(finishMark) + return { + type = 'string', + start = startPos, + finish = finishPos, + [1] = stringResult, + [2] = mark, + } +end + +local function pushCommentHeadError(left) + if State.options.nonstandardSymbol['//'] then + return + end + pushError { + type = 'ERR_COMMENT_PREFIX', + start = left, + finish = left + 2, + fix = { + title = 'FIX_COMMENT_PREFIX', + { + start = left, + finish = left + 2, + text = '--', + }, + } + } +end + +local function pushLongCommentError(left, right) + if State.options.nonstandardSymbol['/**/'] then + return + end + pushError { + type = 'ERR_C_LONG_COMMENT', + start = left, + finish = right, + fix = { + title = 'FIX_C_LONG_COMMENT', + { + start = left, + finish = left + 2, + text = '--[[', + }, + { + start = right - 2, + finish = right, + text = '--]]' + }, + } + } +end + +local function skipComment(isAction) + local token = Tokens[Index + 1] + if token == '--' + or ( + token == '//' + and ( + isAction + or State.options.nonstandardSymbol['//'] + ) + ) then + local start = Tokens[Index] + local left = getPosition(start, 'left') + local chead = false + if token == '//' then + chead = true + pushCommentHeadError(left) end - tableInsert(value.args, 1, localself) - Compile(value, obj) - end, - ['function'] = function (obj) - local lastBlock = Block - local LastLocalCount = LocalCount - Block = obj - LocalCount = 0 - Compile(obj.args, obj) - for i = 1, #obj do - Compile(obj[i], obj) - end - Block = lastBlock - LocalCount = LastLocalCount - end, - ['funcargs'] = function (obj) - for i = 1, #obj do - Compile(obj[i], obj) - end - end, - ['table'] = function (obj) - for i = 1, #obj do - Compile(obj[i], obj) - end - end, - ['tablefield'] = function (obj) - Compile(obj.value, obj) - end, - ['tableindex'] = function (obj) - Compile(obj.index, obj) - Compile(obj.value, obj) - end, - ['tableexp'] = function (obj) - Compile(obj.value, obj) - end, - ['index'] = function (obj) - Compile(obj.index, obj) - end, - ['select'] = function (obj) - local vararg = obj.vararg - if vararg.parent then - if not vararg.extParent then - vararg.extParent = {} - end - vararg.extParent[#vararg.extParent+1] = obj - else - Compile(vararg, obj) - end - end, - ['setname'] = function (obj) - Compile(obj.value, obj) - local loc = guide.getLocal(obj, obj[1], obj.start) - if loc then - obj.type = 'setlocal' - obj.loc = loc - addRef(loc, obj) - if loc.attrs then - local const - for i = 1, #loc.attrs do - local attr = loc.attrs[i][1] - if attr == 'const' - or attr == 'close' then - const = true - break - end - end - if const then - pushError { - type = 'SET_CONST', - start = obj.start, - finish = obj.finish, - } - end - end - else - obj.type = 'setglobal' - local node = guide.getLocal(obj, ENVMode, obj.start) - if node then - addRef(node, obj) - end - local name = obj[1] - if specials[name] then - addSpecial(name, obj) - elseif Options and Options.special then - local asName = Options.special[name] - if specials[asName] then - addSpecial(asName, obj) - end - end + Index = Index + 2 + local longComment = start + 2 == Tokens[Index] and parseLongString() + if longComment then + longComment.type = 'comment.long' + longComment.text = longComment[1] + longComment.mark = longComment[2] + longComment[1] = nil + longComment[2] = nil + State.comms[#State.comms+1] = longComment + return true end - end, - ['local'] = function (obj) - local attrs = obj.attrs - if attrs then - for i = 1, #attrs do - Compile(attrs[i], obj) + while true do + local nl = Tokens[Index + 1] + if not nl or NLMap[nl] then + break end + Index = Index + 2 end - if Block then - if not Block.locals then - Block.locals = {} - end - Block.locals[#Block.locals+1] = obj - LocalCount = LocalCount + 1 - if LocalCount > LocalLimit then - pushError { - type = 'LOCAL_LIMIT', - start = obj.start, - finish = obj.finish, + local right = Tokens[Index] and (Tokens[Index] - 1) or #Lua + State.comms[#State.comms+1] = { + type = chead and 'comment.cshort' or 'comment.short', + start = left, + finish = getPosition(right, 'right'), + text = ssub(Lua, start + 2, right), + } + return true + end + if token == '/*' then + local start = Tokens[Index] + local left = getPosition(start, 'left') + Index = Index + 2 + local result, right = resolveLongString '*/' + pushLongCommentError(left, right) + State.comms[#State.comms+1] = { + type = 'comment.long', + start = left, + finish = right, + text = result, + } + return true + end + return false +end + +local function skipSpace(isAction) + repeat until not skipNL() + and not skipComment(isAction) +end + +local function expectAssign(isAction) + local token = Tokens[Index + 1] + if token == '=' then + Index = Index + 2 + return true + end + if token == '==' then + local left = getPosition(Tokens[Index], 'left') + local right = getPosition(Tokens[Index] + #token - 1, 'right') + pushError { + type = 'ERR_ASSIGN_AS_EQ', + start = left, + finish = right, + fix = { + title = 'FIX_ASSIGN_AS_EQ', + { + start = left, + finish = right, + text = '=', } + } + } + Index = Index + 2 + return true + end + if isAction then + if token == '+=' + or token == '-=' + or token == '*=' + or token == '/=' + or token == '%=' + or token == '^=' + or token == '//=' + or token == '|=' + or token == '&=' + or token == '>>=' + or token == '<<=' then + if not State.options.nonstandardSymbol[token] then + unknownSymbol() end + Index = Index + 2 + return true end - if obj.localfunction then - obj.localfunction = nil + end + return false +end + +local function parseLocalAttrs() + local attrs + while true do + skipSpace() + local token = Tokens[Index + 1] + if token ~= '<' then + break end - Compile(obj.value, obj) - if obj.value and obj.value.special then - addSpecial(obj.value.special, obj) + if not attrs then + attrs = { + type = 'localattrs', + } end - end, - ['setfield'] = function (obj) - Compile(obj.node, obj) - Compile(obj.value, obj) - end, - ['do'] = function (obj) - local lastBlock = Block - Block = obj - CompileBlock(obj, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals + local attr = { + type = 'localattr', + parent = attrs, + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index], 'right'), + } + attrs[#attrs+1] = attr + Index = Index + 2 + skipSpace() + local word, wstart, wfinish = peekWord() + if word then + attr[1] = word + attr.finish = wfinish + Index = Index + 2 + if word ~= 'const' + and word ~= 'close' then + pushError { + type = 'UNKNOWN_ATTRIBUTE', + start = wstart, + finish = wfinish, + } + end + else + missName() end - Block = lastBlock - end, - ['return'] = function (obj) - for i = 1, #obj do - Compile(obj[i], obj) + attr.finish = lastRightPosition() + skipSpace() + if Tokens[Index + 1] == '>' then + attr.finish = getPosition(Tokens[Index], 'right') + Index = Index + 2 + elseif Tokens[Index + 1] == '>=' then + attr.finish = getPosition(Tokens[Index], 'right') + pushError { + type = 'MISS_SPACE_BETWEEN', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + 1, 'right'), + } + Index = Index + 2 + else + missSymbol '>' end - if Block and Block[#Block] ~= obj then + if State.version ~= 'Lua 5.4' then pushError { - type = 'ACTION_AFTER_RETURN', - start = obj.start, - finish = obj.finish, + type = 'UNSUPPORT_SYMBOL', + start = attr.start, + finish = attr.finish, + version = 'Lua 5.4', + info = { + version = State.version + } } end - local func = guide.getParentFunction(obj) - if func then - if not func.returns then - func.returns = {} - end - func.returns[#func.returns+1] = obj + end + return attrs +end + +local function createLocal(obj, attrs) + obj.type = 'local' + obj.effect = obj.finish + + if attrs then + obj.attrs = attrs + attrs.parent = obj + end + + local chunk = Chunk[#Chunk] + if chunk then + local locals = chunk.locals + if not locals then + locals = {} + chunk.locals = locals end - end, - ['label'] = function (obj) - local block = guide.getBlock(obj) - if block then - if not block.labels then - block.labels = {} - end - local name = obj[1] - local label = guide.getLabel(block, name) - if label then - if Version == 'Lua 5.4' - or block == guide.getBlock(label) then - pushError { - type = 'REDEFINED_LABEL', - start = obj.start, - finish = obj.finish, - relative = { - { - label.start, - label.finish, - } - } - } - end - end - block.labels[name] = obj - end - end, - ['goto'] = function (obj) - GoToTag[#GoToTag+1] = obj - end, - ['if'] = function (obj) - for i = 1, #obj do - Compile(obj[i], obj) - end - end, - ['ifblock'] = function (obj) - local lastBlock = Block - Block = obj - Compile(obj.filter, obj) - CompileBlock(obj, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals - end - Block = lastBlock - end, - ['elseifblock'] = function (obj) - local lastBlock = Block - Block = obj - Compile(obj.filter, obj) - CompileBlock(obj, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals - end - Block = lastBlock - end, - ['elseblock'] = function (obj) - local lastBlock = Block - Block = obj - CompileBlock(obj, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals - end - Block = lastBlock - end, - ['loop'] = function (obj) - local lastBlock = Block - Block = obj - Compile(obj.loc, obj) - Compile(obj.max, obj) - Compile(obj.step, obj) - CompileBlock(obj, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals - end - Block = lastBlock - end, - ['in'] = function (obj) - local lastBlock = Block - Block = obj - local keys = obj.keys - for i = 1, #keys do - Compile(keys[i], obj) - end - CompileBlock(obj, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals - end - Block = lastBlock - end, - ['while'] = function (obj) - local lastBlock = Block - Block = obj - Compile(obj.filter, obj) - CompileBlock(obj, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals - end - Block = lastBlock - end, - ['repeat'] = function (obj) - local lastBlock = Block - Block = obj - CompileBlock(obj, obj) - Compile(obj.filter, obj) - if Block.locals then - LocalCount = LocalCount - #Block.locals - end - Block = lastBlock - end, - ['break'] = function (obj) - local block = guide.getBreakBlock(obj) - if block then - if not block.breaks then - block.breaks = {} - end - block.breaks[#block.breaks+1] = obj - else + locals[#locals+1] = obj + LocalCount = LocalCount + 1 + if LocalCount > LocalLimit then pushError { - type = 'BREAK_OUTSIDE', + type = 'LOCAL_LIMIT', start = obj.start, finish = obj.finish, } end - end, - ['main'] = function (obj) - Block = obj - Compile({ - type = 'local', - start = 0, - finish = 0, - effect = 0, - tag = '_ENV', - special= '_G', - [1] = ENVMode, - }, obj) - --- _ENV 是上值,不计入局部变量计数 - LocalCount = 0 - CompileBlock(obj, obj) - Block = nil - end, -} - -function CompileBlock(obj, parent) - for i = 1, #obj do - local act = obj[i] - act.parent = parent - local f = vmMap[act.type] - if f then - f(act) - end end + return obj end -function Compile(obj, parent) - if not obj then - return nil - end - if Compiled[obj] then - return - end - Compiled[obj] = true - obj.parent = parent - local f = vmMap[obj.type] - if not f then - return - end - f(obj) +local function pushChunk(chunk) + Chunk[#Chunk+1] = chunk end -local function compileGoTo(obj) - local name = obj[1] - local label = guide.getLabel(obj, name) - if not label then - pushError { - type = 'NO_VISIBLE_LABEL', - start = obj.start, - finish = obj.finish, - info = { - label = name, - } - } - return - end +local function resolveLable(label, obj) if not label.ref then label.ref = {} end @@ -538,10 +766,10 @@ local function compileGoTo(obj) local ref = refs[j] if ref.finish > label.finish then pushError { - type = 'JUMP_LOCAL_SCOPE', - start = obj.start, - finish = obj.finish, - info = { + type = 'JUMP_LOCAL_SCOPE', + start = obj.start, + finish = obj.finish, + info = { loc = loc[1], }, relative = { @@ -562,42 +790,3129 @@ local function compileGoTo(obj) end end -local function PostCompile() - for i = 1, #GoToTag do - compileGoTo(GoToTag[i]) +local function resolveGoTo(gotos) + for i = 1, #gotos do + local action = gotos[i] + local label = guide.getLabel(action, action[1]) + if label then + resolveLable(label, action) + else + pushError { + type = 'NO_VISIBLE_LABEL', + start = action.start, + finish = action.finish, + info = { + label = action[1], + } + } + end end end -return function (lua, mode, version, options) - do - local state, err = newparser(lua, mode, version, options) - return state, err - end - local state, err = parse(lua, mode, version, options) - if not state then - return nil, err - end - --if options and options.delay then - -- options.delay() - --end - local clock = os.clock() - pushError = state.pushError - ENVMode = state.ENVMode - Compiled = {} - GoToTag = {} +local function popChunk() + local chunk = Chunk[#Chunk] + if chunk.gotos then + resolveGoTo(chunk.gotos) + chunk.gotos = nil + end + local lastAction = chunk[#chunk] + if lastAction then + chunk.finish = lastAction.finish + end + Chunk[#Chunk] = nil +end + +local function parseNil() + if Tokens[Index + 1] ~= 'nil' then + return nil + end + local offset = Tokens[Index] + Index = Index + 2 + return { + type = 'nil', + start = getPosition(offset, 'left'), + finish = getPosition(offset + 2, 'right'), + } +end + +local function parseBoolean() + local word = Tokens[Index+1] + if word ~= 'true' + and word ~= 'false' then + return nil + end + local start = getPosition(Tokens[Index], 'left') + local finish = getPosition(Tokens[Index] + #word - 1, 'right') + Index = Index + 2 + return { + type = 'boolean', + start = start, + finish = finish, + [1] = word == 'true' and true or false, + } +end + +local function parseStringUnicode() + local offset = Tokens[Index] + 1 + if ssub(Lua, offset, offset) ~= '{' then + local pos = getPosition(offset, 'left') + missSymbol('{', pos) + return nil, offset + end + local leftPos = getPosition(offset, 'left') + local x16 = smatch(Lua, '^%w*', offset + 1) + local rightPos = getPosition(offset + #x16, 'right') + offset = offset + #x16 + 1 + if ssub(Lua, offset, offset) == '}' then + offset = offset + 1 + rightPos = rightPos + 1 + else + missSymbol('}', rightPos) + end + offset = offset + 1 + if #x16 == 0 then + pushError { + type = 'UTF8_SMALL', + start = leftPos, + finish = rightPos, + } + return '', offset + end + if State.version ~= 'Lua 5.3' + and State.version ~= 'Lua 5.4' + and State.version ~= 'LuaJIT' + then + pushError { + type = 'ERR_ESC', + start = leftPos - 2, + finish = rightPos, + version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return nil, offset + end + local byte = tonumber(x16, 16) + if not byte then + for i = 1, #x16 do + if not tonumber(ssub(x16, i, i), 16) then + pushError { + type = 'MUST_X16', + start = leftPos + i, + finish = leftPos + i + 1, + } + end + end + return nil, offset + end + if State.version == 'Lua 5.4' then + if byte < 0 or byte > 0x7FFFFFFF then + pushError { + type = 'UTF8_MAX', + start = leftPos, + finish = rightPos, + info = { + min = '00000000', + max = '7FFFFFFF', + } + } + return nil, offset + end + else + if byte < 0 or byte > 0x10FFFF then + pushError { + type = 'UTF8_MAX', + start = leftPos, + finish = rightPos, + version = byte <= 0x7FFFFFFF and 'Lua 5.4' or nil, + info = { + min = '000000', + max = '10FFFF', + } + } + end + end + if byte >= 0 and byte <= 0x10FFFF then + return uchar(byte), offset + end + return '', offset +end + +local stringPool = {} +local function parseShortString() + local mark = Tokens[Index+1] + local startOffset = Tokens[Index] + local startPos = getPosition(startOffset, 'left') + Index = Index + 2 + local stringIndex = 0 + local currentOffset = startOffset + 1 + local escs = {} + while true do + local token = Tokens[Index + 1] + if token == mark then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1) + Index = Index + 2 + break + end + if NLMap[token] then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1) + missSymbol(mark) + break + end + if not token then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = ssub(Lua, currentOffset or -1) + missSymbol(mark) + break + end + if token == '\\' then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1) + currentOffset = Tokens[Index] + Index = Index + 2 + if not Tokens[Index] then + goto CONTINUE + end + local escLeft = getPosition(currentOffset, 'left') + -- has space? + if Tokens[Index] - currentOffset > 1 then + local right = getPosition(currentOffset + 1, 'right') + pushError { + type = 'ERR_ESC', + start = escLeft, + finish = right, + } + escs[#escs+1] = escLeft + escs[#escs+1] = right + escs[#escs+1] = 'err' + goto CONTINUE + end + local nextToken = ssub(Tokens[Index + 1], 1, 1) + if EscMap[nextToken] then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = EscMap[nextToken] + currentOffset = Tokens[Index] + #nextToken + Index = Index + 2 + escs[#escs+1] = escLeft + escs[#escs+1] = escLeft + 2 + escs[#escs+1] = 'normal' + goto CONTINUE + end + if nextToken == mark then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = mark + currentOffset = Tokens[Index] + #nextToken + Index = Index + 2 + escs[#escs+1] = escLeft + escs[#escs+1] = escLeft + 2 + escs[#escs+1] = 'normal' + goto CONTINUE + end + if nextToken == 'z' then + Index = Index + 2 + repeat until not skipNL() + currentOffset = Tokens[Index] + escs[#escs+1] = escLeft + escs[#escs+1] = escLeft + 2 + escs[#escs+1] = 'normal' + goto CONTINUE + end + if CharMapNumber[nextToken] then + local numbers = smatch(Tokens[Index + 1], '^%d+') + if #numbers > 3 then + numbers = ssub(numbers, 1, 3) + end + currentOffset = Tokens[Index] + #numbers + fastForwardToken(currentOffset) + local right = getPosition(currentOffset - 1, 'right') + local byte = tointeger(numbers) + if byte and byte <= 255 then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = schar(byte) + else + pushError { + type = 'ERR_ESC', + start = escLeft, + finish = right, + } + end + escs[#escs+1] = escLeft + escs[#escs+1] = right + escs[#escs+1] = 'byte' + goto CONTINUE + end + if nextToken == 'x' then + local left = getPosition(Tokens[Index] - 1, 'left') + local x16 = ssub(Tokens[Index + 1], 2, 3) + local byte = tonumber(x16, 16) + if byte then + currentOffset = Tokens[Index] + 3 + stringIndex = stringIndex + 1 + stringPool[stringIndex] = schar(byte) + else + currentOffset = Tokens[Index] + 1 + pushError { + type = 'MISS_ESC_X', + start = getPosition(currentOffset, 'left'), + finish = getPosition(currentOffset + 1, 'right'), + } + end + local right = getPosition(currentOffset + 1, 'right') + escs[#escs+1] = escLeft + escs[#escs+1] = right + escs[#escs+1] = 'byte' + if State.version == 'Lua 5.1' then + pushError { + type = 'ERR_ESC', + start = left, + finish = left + 4, + version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + end + Index = Index + 2 + goto CONTINUE + end + if nextToken == 'u' then + local str, newOffset = parseStringUnicode() + if str then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = str + end + currentOffset = newOffset + fastForwardToken(currentOffset - 1) + local right = getPosition(currentOffset + 1, 'right') + escs[#escs+1] = escLeft + escs[#escs+1] = right + escs[#escs+1] = 'unicode' + goto CONTINUE + end + if NLMap[nextToken] then + stringIndex = stringIndex + 1 + stringPool[stringIndex] = '\n' + currentOffset = Tokens[Index] + #nextToken + skipNL() + local right = getPosition(currentOffset + 1, 'right') + escs[#escs+1] = escLeft + escs[#escs+1] = escLeft + 1 + escs[#escs+1] = 'normal' + goto CONTINUE + end + local right = getPosition(currentOffset + 1, 'right') + pushError { + type = 'ERR_ESC', + start = escLeft, + finish = right, + } + escs[#escs+1] = escLeft + escs[#escs+1] = right + escs[#escs+1] = 'err' + end + Index = Index + 2 + ::CONTINUE:: + end + local stringResult = tconcat(stringPool, '', 1, stringIndex) + local str = { + type = 'string', + start = startPos, + finish = lastRightPosition(), + escs = #escs > 0 and escs or nil, + [1] = stringResult, + [2] = mark, + } + if mark == '`' then + if not State.options.nonstandardSymbol[mark] then + pushError { + type = 'ERR_NONSTANDARD_SYMBOL', + start = startPos, + finish = str.finish, + info = { + symbol = '"', + }, + fix = { + title = 'FIX_NONSTANDARD_SYMBOL', + symbol = '"', + { + start = startPos, + finish = startPos + 1, + text = '"', + }, + { + start = str.finish - 1, + finish = str.finish, + text = '"', + }, + } + } + end + end + return str +end + +local function parseString() + local c = Tokens[Index + 1] + if CharMapStrSH[c] then + return parseShortString() + end + if CharMapStrLH[c] then + return parseLongString() + end + return nil +end + +local function parseNumber10(start) + local integer = true + local integerPart = smatch(Lua, '^%d*', start) + local offset = start + #integerPart + -- float part + if ssub(Lua, offset, offset) == '.' then + local floatPart = smatch(Lua, '^%d*', offset + 1) + integer = false + offset = offset + #floatPart + 1 + end + -- exp part + local echar = ssub(Lua, offset, offset) + if CharMapE10[echar] then + integer = false + offset = offset + 1 + local nextChar = ssub(Lua, offset, offset) + if CharMapSign[nextChar] then + offset = offset + 1 + end + local exp = smatch(Lua, '^%d*', offset) + offset = offset + #exp + if #exp == 0 then + pushError { + type = 'MISS_EXPONENT', + start = getPosition(offset - 1, 'right'), + finish = getPosition(offset - 1, 'right'), + } + end + end + return tonumber(ssub(Lua, start, offset - 1)), offset, integer +end + +local function parseNumber16(start) + local integerPart = smatch(Lua, '^[%da-fA-F]*', start) + local offset = start + #integerPart + local integer = true + -- float part + if ssub(Lua, offset, offset) == '.' then + local floatPart = smatch(Lua, '^[%da-fA-F]*', offset + 1) + integer = false + offset = offset + #floatPart + 1 + if #integerPart == 0 and #floatPart == 0 then + pushError { + type = 'MUST_X16', + start = getPosition(offset - 1, 'right'), + finish = getPosition(offset - 1, 'right'), + } + end + else + if #integerPart == 0 then + pushError { + type = 'MUST_X16', + start = getPosition(offset - 1, 'right'), + finish = getPosition(offset - 1, 'right'), + } + return 0, offset + end + end + -- exp part + local echar = ssub(Lua, offset, offset) + if CharMapE16[echar] then + integer = false + offset = offset + 1 + local nextChar = ssub(Lua, offset, offset) + if CharMapSign[nextChar] then + offset = offset + 1 + end + local exp = smatch(Lua, '^%d*', offset) + offset = offset + #exp + end + local n = tonumber(ssub(Lua, start - 2, offset - 1)) + return n, offset, integer +end + +local function parseNumber2(start) + local bins = smatch(Lua, '^[01]*', start) + local offset = start + #bins + if State.version ~= 'LuaJIT' then + pushError { + type = 'UNSUPPORT_SYMBOL', + start = getPosition(start - 2, 'left'), + finish = getPosition(offset - 1, 'right'), + version = 'LuaJIT', + info = { + version = 'Lua 5.4', + } + } + end + return tonumber(bins, 2), offset +end + +local function dropNumberTail(offset, integer) + local _, finish, word = sfind(Lua, '^([%.%w_\x80-\xff]+)', offset) + if not finish then + return offset + end + if integer then + if supper(ssub(word, 1, 2)) == 'LL' then + if State.version ~= 'LuaJIT' then + pushError { + type = 'UNSUPPORT_SYMBOL', + start = getPosition(offset, 'left'), + finish = getPosition(offset + 1, 'right'), + version = 'LuaJIT', + info = { + version = State.version, + } + } + end + offset = offset + 2 + word = ssub(word, offset) + elseif supper(ssub(word, 1, 3)) == 'ULL' then + if State.version ~= 'LuaJIT' then + pushError { + type = 'UNSUPPORT_SYMBOL', + start = getPosition(offset, 'left'), + finish = getPosition(offset + 2, 'right'), + version = 'LuaJIT', + info = { + version = State.version, + } + } + end + offset = offset + 3 + word = ssub(word, offset) + end + end + if supper(ssub(word, 1, 1)) == 'I' then + if State.version ~= 'LuaJIT' then + pushError { + type = 'UNSUPPORT_SYMBOL', + start = getPosition(offset, 'left'), + finish = getPosition(offset, 'right'), + version = 'LuaJIT', + info = { + version = State.version, + } + } + end + offset = offset + 1 + word = ssub(word, offset) + end + if #word > 0 then + pushError { + type = 'MALFORMED_NUMBER', + start = getPosition(offset, 'left'), + finish = getPosition(finish, 'right'), + } + end + return finish + 1 +end + +local function parseNumber() + local offset = Tokens[Index] + if not offset then + return nil + end + local startPos = getPosition(offset, 'left') + local neg + if ssub(Lua, offset, offset) == '-' then + neg = true + offset = offset + 1 + end + local number, integer + local firstChar = ssub(Lua, offset, offset) + if firstChar == '.' then + number, offset = parseNumber10(offset) + integer = false + elseif firstChar == '0' then + local nextChar = ssub(Lua, offset + 1, offset + 1) + if CharMapN16[nextChar] then + number, offset, integer = parseNumber16(offset + 2) + elseif CharMapN2[nextChar] then + number, offset = parseNumber2(offset + 2) + integer = true + else + number, offset, integer = parseNumber10(offset) + end + elseif CharMapNumber[firstChar] then + number, offset, integer = parseNumber10(offset) + else + return nil + end + if not number then + number = 0 + end + if neg then + number = - number + end + local result = { + type = integer and 'integer' or 'number', + start = startPos, + finish = getPosition(offset - 1, 'right'), + [1] = number, + } + offset = dropNumberTail(offset, integer) + fastForwardToken(offset) + return result +end + +local function isKeyWord(word) + if KeyWord[word] then + return true + end + if word == 'goto' then + return State.version ~= 'Lua 5.1' + end + return false +end + +local function parseName(asAction) + local word = peekWord() + if not word then + return nil + end + if ChunkFinishMap[word] then + return nil + end + if asAction and ChunkStartMap[word] then + return nil + end + local startPos = getPosition(Tokens[Index], 'left') + local finishPos = getPosition(Tokens[Index] + #word - 1, 'right') + Index = Index + 2 + if not State.options.unicodeName and word:find '[\x80-\xff]' then + pushError { + type = 'UNICODE_NAME', + start = startPos, + finish = finishPos, + } + end + if isKeyWord(word) then + pushError { + type = 'KEYWORD', + start = startPos, + finish = finishPos, + } + end + return { + type = 'name', + start = startPos, + finish = finishPos, + [1] = word, + } +end + +local function parseNameOrList(parent) + local first = parseName() + if not first then + return nil + end + skipSpace() + local list + while true do + if Tokens[Index + 1] ~= ',' then + break + end + Index = Index + 2 + skipSpace() + local name = parseName(true) + if not name then + missName() + break + end + if not list then + list = { + type = 'list', + start = first.start, + finish = first.finish, + parent = parent, + [1] = first + } + end + list[#list+1] = name + list.finish = name.finish + end + return list or first +end + +local function dropTail() + local token = Tokens[Index + 1] + if token ~= '?' + and token ~= ':' then + return + end + local pl, pt, pp = 0, 0, 0 + while true do + local token = Tokens[Index + 1] + if not token then + break + end + if NLMap[token] then + break + end + if token == ',' then + if pl > 0 + or pt > 0 + or pp > 0 then + goto CONTINUE + else + break + end + end + if token == '<' then + pl = pl + 1 + goto CONTINUE + end + if token == '{' then + pt = pt + 1 + goto CONTINUE + end + if token == '(' then + pp = pp + 1 + goto CONTINUE + end + if token == '>' then + if pl <= 0 then + break + end + pl = pl - 1 + goto CONTINUE + end + if token == '}' then + if pt <= 0 then + break + end + pt = pt - 1 + goto CONTINUE + end + if token == ')' then + if pp <= 0 then + break + end + pp = pp - 1 + goto CONTINUE + end + ::CONTINUE:: + Index = Index + 2 + end +end + +local function parseExpList(mini) + local list + local wantSep = false + while true do + skipSpace() + local token = Tokens[Index + 1] + if not token then + break + end + if ListFinishMap[token] then + break + end + if token == ',' then + local sepPos = getPosition(Tokens[Index], 'right') + if not wantSep then + pushError { + type = 'UNEXPECT_SYMBOL', + start = getPosition(Tokens[Index], 'left'), + finish = sepPos, + info = { + symbol = ',', + } + } + end + wantSep = false + Index = Index + 2 + goto CONTINUE + else + if mini then + if wantSep then + break + end + local nextToken = peekWord() + if isKeyWord(nextToken) + and nextToken ~= 'function' + and nextToken ~= 'true' + and nextToken ~= 'false' + and nextToken ~= 'nil' + and nextToken ~= 'not' then + break + end + end + local exp = parseExp() + if not exp then + break + end + dropTail() + if wantSep then + missSymbol(',', list[#list].finish, exp.start) + end + wantSep = true + if not list then + list = { + type = 'list', + start = exp.start, + } + end + list[#list+1] = exp + list.finish = exp.finish + exp.parent = list + end + ::CONTINUE:: + end + if not list then + return nil + end + if not wantSep then + missExp() + end + return list +end + +local function parseIndex() + local start = getPosition(Tokens[Index], 'left') + Index = Index + 2 + skipSpace() + local exp = parseExp() + local index = { + type = 'index', + start = start, + finish = exp and exp.finish or (start + 1), + index = exp + } + if exp then + exp.parent = index + else + missExp() + end + skipSpace() + if Tokens[Index + 1] == ']' then + index.finish = getPosition(Tokens[Index], 'right') + Index = Index + 2 + else + missSymbol ']' + end + return index +end + +local function parseTable() + local tbl = { + type = 'table', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index], 'right'), + } + Index = Index + 2 + local index = 0 + local tindex = 0 + local wantSep = false + while true do + skipSpace(true) + local token = Tokens[Index + 1] + if token == '}' then + Index = Index + 2 + break + end + if CharMapTSep[token] then + if not wantSep then + missExp() + end + wantSep = false + Index = Index + 2 + goto CONTINUE + end + local lastRight = lastRightPosition() + + if peekWord() then + local savePoint = getSavePoint() + local name = parseName() + if name then + skipSpace() + if Tokens[Index + 1] == '=' then + Index = Index + 2 + if wantSep then + pushError { + type = 'MISS_SEP_IN_TABLE', + start = lastRight, + finish = getPosition(Tokens[Index], 'left'), + } + end + wantSep = true + skipSpace() + local fvalue = parseExp() + local tfield = { + type = 'tablefield', + start = name.start, + finish = name.finish, + range = fvalue and fvalue.finish, + node = tbl, + parent = tbl, + field = name, + value = fvalue, + } + name.type = 'field' + name.parent = tfield + if fvalue then + fvalue.parent = tfield + else + missExp() + end + index = index + 1 + tbl[index] = tfield + goto CONTINUE + end + end + savePoint() + end + + local exp = parseExp(true) + if exp then + if wantSep then + pushError { + type = 'MISS_SEP_IN_TABLE', + start = lastRight, + finish = exp.start, + } + end + wantSep = true + if exp.type == 'varargs' then + index = index + 1 + tbl[index] = exp + exp.parent = tbl + goto CONTINUE + end + index = index + 1 + tindex = tindex + 1 + local texp = { + type = 'tableexp', + start = exp.start, + finish = exp.finish, + tindex = tindex, + parent = tbl, + value = exp, + } + exp.parent = texp + tbl[index] = texp + goto CONTINUE + end + + if token == '[' then + if wantSep then + pushError { + type = 'MISS_SEP_IN_TABLE', + start = lastRight, + finish = getPosition(Tokens[Index], 'left'), + } + end + wantSep = true + local tindex = parseIndex() + skipSpace() + tindex.type = 'tableindex' + tindex.node = tbl + tindex.parent = tbl + index = index + 1 + tbl[index] = tindex + if expectAssign() then + skipSpace() + local ivalue = parseExp() + if ivalue then + ivalue.parent = tindex + tindex.range = ivalue.finish + tindex.value = ivalue + else + missExp() + end + else + missSymbol '=' + end + goto CONTINUE + end + + missSymbol '}' + break + ::CONTINUE:: + end + tbl.finish = lastRightPosition() + return tbl +end + +local function addDummySelf(node, call) + if node.type ~= 'getmethod' then + return + end + -- dummy param `self` + if not call.args then + call.args = { + type = 'callargs', + start = call.start, + finish = call.finish, + parent = call, + } + end + local self = { + type = 'self', + start = node.colon.start, + finish = node.colon.finish, + parent = call.args, + [1] = 'self', + } + tinsert(call.args, 1, self) +end + +local function checkAmbiguityCall(call, parenPos) + if State.version ~= 'Lua 5.1' then + return + end + local node = call.node + local nodeRow = guide.rowColOf(node.finish) + local callRow = guide.rowColOf(parenPos) + if nodeRow == callRow then + return + end + pushError { + type = 'AMBIGUOUS_SYNTAX', + start = parenPos, + finish = call.finish, + } +end + +local function parseSimple(node, funcName) + local lastMethod + while true do + if lastMethod and node.node == lastMethod then + if node.type ~= 'call' then + missSymbol('(', node.node.finish, node.node.finish) + end + lastMethod = nil + end + skipSpace() + local token = Tokens[Index + 1] + if token == '.' then + local dot = { + type = token, + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index], 'right'), + } + Index = Index + 2 + skipSpace() + local field = parseName(true) + local getfield = { + type = 'getfield', + start = node.start, + finish = lastRightPosition(), + node = node, + dot = dot, + field = field + } + if field then + field.parent = getfield + field.type = 'field' + else + pushError { + type = 'MISS_FIELD', + start = lastRightPosition(), + finish = lastRightPosition(), + } + end + node.parent = getfield + node.next = getfield + node = getfield + elseif token == ':' then + local colon = { + type = token, + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index], 'right'), + } + Index = Index + 2 + skipSpace() + local method = parseName(true) + local getmethod = { + type = 'getmethod', + start = node.start, + finish = lastRightPosition(), + node = node, + colon = colon, + method = method + } + if method then + method.parent = getmethod + method.type = 'method' + else + pushError { + type = 'MISS_METHOD', + start = lastRightPosition(), + finish = lastRightPosition(), + } + end + node.parent = getmethod + node.next = getmethod + node = getmethod + if lastMethod then + missSymbol('(', node.node.finish, node.node.finish) + end + lastMethod = getmethod + elseif token == '(' then + if funcName then + break + end + local startPos = getPosition(Tokens[Index], 'left') + local call = { + type = 'call', + start = node.start, + node = node, + } + Index = Index + 2 + local args = parseExpList() + if Tokens[Index + 1] == ')' then + call.finish = getPosition(Tokens[Index], 'right') + Index = Index + 2 + else + call.finish = lastRightPosition() + missSymbol ')' + end + if args then + args.type = 'callargs' + args.start = startPos + args.finish = call.finish + args.parent = call + call.args = args + end + addDummySelf(node, call) + checkAmbiguityCall(call, startPos) + node.parent = call + node = call + elseif token == '{' then + if funcName then + break + end + local tbl = parseTable() + local call = { + type = 'call', + start = node.start, + finish = tbl.finish, + node = node, + } + local args = { + type = 'callargs', + start = tbl.start, + finish = tbl.finish, + parent = call, + [1] = tbl, + } + call.args = args + addDummySelf(node, call) + tbl.parent = args + node.parent = call + node = call + elseif CharMapStrSH[token] then + if funcName then + break + end + local str = parseShortString() + local call = { + type = 'call', + start = node.start, + finish = str.finish, + node = node, + } + local args = { + type = 'callargs', + start = str.start, + finish = str.finish, + parent = call, + [1] = str, + } + call.args = args + addDummySelf(node, call) + str.parent = args + node.parent = call + node = call + elseif CharMapStrLH[token] then + local str = parseLongString() + if str then + if funcName then + break + end + local call = { + type = 'call', + start = node.start, + finish = str.finish, + node = node, + } + local args = { + type = 'callargs', + start = str.start, + finish = str.finish, + parent = call, + [1] = str, + } + call.args = args + addDummySelf(node, call) + str.parent = args + node.parent = call + node = call + else + local index = parseIndex() + local bstart = index.start + index.type = 'getindex' + index.start = node.start + index.node = node + node.next = index + node.parent = index + node = index + if funcName then + pushError { + type = 'INDEX_IN_FUNC_NAME', + start = bstart, + finish = index.finish, + } + end + end + else + break + end + end + if node.type == 'call' + and node.node == lastMethod then + lastMethod = nil + end + if node == lastMethod then + if funcName then + lastMethod = nil + end + end + if lastMethod then + missSymbol('(', lastMethod.finish) + end + return node +end + +local function parseVarargs() + local varargs = { + type = 'varargs', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + 2, 'right'), + } + Index = Index + 2 + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + if chunk.vararg then + if not chunk.vararg.ref then + chunk.vararg.ref = {} + end + chunk.vararg.ref[#chunk.vararg.ref+1] = varargs + varargs.node = chunk.vararg + break + end + if chunk.type == 'main' then + break + end + if chunk.type == 'function' then + pushError { + type = 'UNEXPECT_DOTS', + start = varargs.start, + finish = varargs.finish, + } + break + end + end + return varargs +end + +local function parseParen() + local pl = Tokens[Index] + local paren = { + type = 'paren', + start = getPosition(pl, 'left'), + finish = getPosition(pl, 'right') + } + Index = Index + 2 + skipSpace() + local exp = parseExp() + if exp then + paren.exp = exp + paren.finish = exp.finish + exp.parent = paren + else + missExp() + end + skipSpace() + if Tokens[Index + 1] == ')' then + paren.finish = getPosition(Tokens[Index], 'right') + Index = Index + 2 + else + missSymbol ')' + end + return paren +end + +local function getLocal(name, pos) + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + local locals = chunk.locals + if locals then + local res + for n = 1, #locals do + local loc = locals[n] + if loc.effect > pos then + break + end + if loc[1] == name then + if not res or res.effect < loc.effect then + res = loc + end + end + end + if res then + return res + end + end + end +end + +local function resolveName(node) + if not node then + return nil + end + local loc = getLocal(node[1], node.start) + if loc then + node.type = 'getlocal' + node.node = loc + if not loc.ref then + loc.ref = {} + end + loc.ref[#loc.ref+1] = node + if loc.special then + addSpecial(loc.special, node) + end + else + node.type = 'getglobal' + local env = getLocal(State.ENVMode, node.start) + if env then + node.node = env + if not env.ref then + env.ref = {} + end + env.ref[#env.ref+1] = node + end + end + local name = node[1] + if Specials[name] then + addSpecial(name, node) + else + local ospeicals = State.options.special + if ospeicals and ospeicals[name] then + addSpecial(ospeicals[name], node) + end + end + return node +end + +local function isChunkFinishToken(token) + local currentChunk = Chunk[#Chunk] + if not currentChunk then + return false + end + local tp = currentChunk.type + if tp == 'main' then + return false + end + if tp == 'for' + or tp == 'in' + or tp == 'loop' + or tp == 'function' then + return token == 'end' + end + if tp == 'if' + or tp == 'ifblock' + or tp == 'elseifblock' + or tp == 'elseblock' then + return token == 'then' + or token == 'end' + or token == 'else' + or token == 'elseif' + end + if tp == 'repeat' then + return token == 'until' + end + return true +end + +local function parseActions() + local rtn, last + while true do + skipSpace(true) + local token = Tokens[Index + 1] + if token == ';' then + Index = Index + 2 + goto CONTINUE + end + if ChunkFinishMap[token] + and isChunkFinishToken(token) then + break + end + local action, failed = parseAction() + if failed then + if not skipUnknownSymbol() then + break + end + end + if action then + if not rtn and action.type == 'return' then + rtn = action + end + last = action + end + ::CONTINUE:: + end + if rtn and rtn ~= last then + pushError { + type = 'ACTION_AFTER_RETURN', + start = rtn.start, + finish = rtn.finish, + } + end +end + +local function parseParams(params) + local lastSep + local hasDots + while true do + skipSpace() + local token = Tokens[Index + 1] + if not token or token == ')' then + if lastSep then + missName() + end + break + end + if token == ',' then + if lastSep or lastSep == nil then + missName() + else + lastSep = true + end + Index = Index + 2 + goto CONTINUE + end + if token == '...' then + if lastSep == false then + missSymbol ',' + end + lastSep = false + if not params then + params = {} + end + local vararg = { + type = '...', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + 2, 'right'), + parent = params, + [1] = '...', + } + local chunk = Chunk[#Chunk] + chunk.vararg = vararg + params[#params+1] = vararg + if hasDots then + pushError { + type = 'ARGS_AFTER_DOTS', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + 2, 'right'), + } + end + hasDots = true + Index = Index + 2 + goto CONTINUE + end + if CharMapWord[ssub(token, 1, 1)] then + if lastSep == false then + missSymbol ',' + end + lastSep = false + if not params then + params = {} + end + params[#params+1] = createLocal { + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + #token - 1, 'right'), + parent = params, + [1] = token, + } + if hasDots then + pushError { + type = 'ARGS_AFTER_DOTS', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + #token - 1, 'right'), + } + end + if isKeyWord(token) then + pushError { + type = 'KEYWORD', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + #token - 1, 'right'), + } + end + Index = Index + 2 + goto CONTINUE + end + skipUnknownSymbol '%,%)%.' + ::CONTINUE:: + end + return params +end + +local function parseFunction(isLocal, isAction) + local funcLeft = getPosition(Tokens[Index], 'left') + local funcRight = getPosition(Tokens[Index] + 7, 'right') + local func = { + type = 'function', + start = funcLeft, + finish = funcRight, + keyword = { + [1] = funcLeft, + [2] = funcRight, + }, + } + Index = Index + 2 + local LastLocalCount = LocalCount LocalCount = 0 - Version = version - Root = state.ast - if Root then - Root.state = state - end - Options = options - if type(state.ast) == 'table' then - Compile(state.ast) - end - PostCompile() - state.compileClock = os.clock() - clock - Compiled = nil - GoToTag = nil - return state + skipSpace(true) + local hasLeftParen = Tokens[Index + 1] == '(' + if not hasLeftParen then + local name = parseName() + if name then + local simple = parseSimple(name, true) + if isLocal then + if simple == name then + createLocal(name) + else + resolveName(name) + pushError { + type = 'UNEXPECT_LFUNC_NAME', + start = simple.start, + finish = simple.finish, + } + end + else + resolveName(name) + end + func.name = simple + func.finish = simple.finish + if not isAction then + simple.parent = func + pushError { + type = 'UNEXPECT_EFUNC_NAME', + start = simple.start, + finish = simple.finish, + } + end + skipSpace(true) + hasLeftParen = Tokens[Index + 1] == '(' + end + end + pushChunk(func) + local params + if func.name and func.name.type == 'getmethod' then + if func.name.type == 'getmethod' then + params = {} + params[1] = createLocal { + start = funcRight, + finish = funcRight, + parent = params, + [1] = 'self', + } + params[1].type = 'self' + end + end + if hasLeftParen then + local parenLeft = getPosition(Tokens[Index], 'left') + Index = Index + 2 + params = parseParams(params) + if params then + params.type = 'funcargs' + params.start = parenLeft + params.finish = lastRightPosition() + params.parent = func + func.args = params + end + skipSpace(true) + if Tokens[Index + 1] == ')' then + local parenRight = getPosition(Tokens[Index], 'right') + func.finish = parenRight + if params then + params.finish = parenRight + end + Index = Index + 2 + skipSpace(true) + else + func.finish = lastRightPosition() + if params then + params.finish = func.finish + end + missSymbol ')' + end + else + missSymbol '(' + end + parseActions() + popChunk() + if Tokens[Index + 1] == 'end' then + local endLeft = getPosition(Tokens[Index], 'left') + local endRight = getPosition(Tokens[Index] + 2, 'right') + func.keyword[3] = endLeft + func.keyword[4] = endRight + func.finish = endRight + Index = Index + 2 + else + func.finish = lastRightPosition() + missEnd(funcLeft, funcRight) + end + LocalCount = LastLocalCount + return func +end + +local function parseExpUnit() + local token = Tokens[Index + 1] + if token == '(' then + local paren = parseParen() + return parseSimple(paren, false) + end + + if token == '...' then + local varargs = parseVarargs() + return varargs + end + + if token == '{' then + local table = parseTable() + return table + end + + if CharMapStrSH[token] then + local string = parseShortString() + return string + end + + if CharMapStrLH[token] then + local string = parseLongString() + return string + end + + local number = parseNumber() + if number then + return number + end + + if ChunkFinishMap[token] then + return nil + end + + if token == 'nil' then + return parseNil() + end + + if token == 'true' + or token == 'false' then + return parseBoolean() + end + + if token == 'function' then + return parseFunction() + end + + local node = parseName() + if node then + return parseSimple(resolveName(node), false) + end + + return nil +end + +local function parseUnaryOP() + local token = Tokens[Index + 1] + local symbol = UnarySymbol[token] and token or UnaryAlias[token] + if not symbol then + return nil + end + local myLevel = UnarySymbol[symbol] + local op = { + type = symbol, + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + #symbol - 1, 'right'), + } + Index = Index + 2 + return op, myLevel +end + +---@param level integer # op level must greater than this level +local function parseBinaryOP(asAction, level) + local token = Tokens[Index + 1] + local symbol = (BinarySymbol[token] and token) + or BinaryAlias[token] + or (not asAction and BinaryActionAlias[token]) + if not symbol then + return nil + end + if symbol == '//' and State.options.nonstandardSymbol['//'] then + return nil + end + local myLevel = BinarySymbol[symbol] + if level and myLevel < level then + return nil + end + local op = { + type = symbol, + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + #token - 1, 'right'), + } + if not asAction then + if token == '=' then + pushError { + type = 'ERR_EQ_AS_ASSIGN', + start = op.start, + finish = op.finish, + fix = { + title = 'FIX_EQ_AS_ASSIGN', + { + start = op.start, + finish = op.finish, + text = '==', + } + } + } + end + end + if BinaryAlias[token] then + if not State.options.nonstandardSymbol[token] then + pushError { + type = 'ERR_NONSTANDARD_SYMBOL', + start = op.start, + finish = op.finish, + info = { + symbol = symbol, + }, + fix = { + title = 'FIX_NONSTANDARD_SYMBOL', + symbol = symbol, + { + start = op.start, + finish = op.finish, + text = symbol, + }, + } + } + end + end + if token == '//' + or token == '<<' + or token == '>>' then + if State.version ~= 'Lua 5.3' + and State.version ~= 'Lua 5.4' then + pushError { + type = 'UNSUPPORT_SYMBOL', + version = {'Lua 5.3', 'Lua 5.4'}, + start = op.start, + finish = op.finish, + info = { + version = State.version, + } + } + end + end + Index = Index + 2 + return op, myLevel +end + +function parseExp(asAction, level) + local exp + local uop, uopLevel = parseUnaryOP() + if uop then + skipSpace() + local child = parseExp(asAction, uopLevel) + -- 预计算负数 + if uop.type == '-' + and child + and (child.type == 'number' or child.type == 'integer') then + child.start = uop.start + child[1] = - child[1] + exp = child + else + exp = { + type = 'unary', + op = uop, + start = uop.start, + finish = child and child.finish or uop.finish, + [1] = child, + } + if child then + child.parent = exp + else + missExp() + end + end + else + exp = parseExpUnit() + if not exp then + return nil + end + end + + while true do + skipSpace() + local bop, bopLevel = parseBinaryOP(asAction, level) + if not bop then + break + end + + ::AGAIN:: + skipSpace() + local isForward = SymbolForward[bopLevel] + local child = parseExp(asAction, isForward and (bopLevel + 0.5) or bopLevel) + if not child then + if skipUnknownSymbol() then + goto AGAIN + else + missExp() + end + end + local bin = { + type = 'binary', + start = exp.start, + finish = child and child.finish or bop.finish, + op = bop, + [1] = exp, + [2] = child + } + exp.parent = bin + if child then + child.parent = bin + end + exp = bin + end + + return exp +end + +local function skipSeps() + while true do + skipSpace() + if Tokens[Index + 1] == ',' then + missExp() + Index = Index + 2 + else + break + end + end +end + +---@return parser.object? first +---@return parser.object? second +---@return parser.object[]? rest +local function parseSetValues() + skipSpace() + local first = parseExp() + if not first then + return nil + end + skipSpace() + if Tokens[Index + 1] ~= ',' then + return first + end + Index = Index + 2 + skipSeps() + local second = parseExp() + if not second then + missExp() + return first + end + skipSpace() + if Tokens[Index + 1] ~= ',' then + return first, second + end + Index = Index + 2 + skipSeps() + local third = parseExp() + if not third then + missExp() + return first, second + end + + local rest = { third } + while true do + skipSpace() + if Tokens[Index + 1] ~= ',' then + return first, second, rest + end + Index = Index + 2 + skipSeps() + local exp = parseExp() + if not exp then + missExp() + return first, second, rest + end + rest[#rest+1] = exp + end +end + +local function pushActionIntoCurrentChunk(action) + local chunk = Chunk[#Chunk] + if chunk then + chunk[#chunk+1] = action + action.parent = chunk + end +end + +---@return parser.object? second +---@return parser.object[]? rest +local function parseVarTails(parser, isLocal) + if Tokens[Index + 1] ~= ',' then + return nil + end + Index = Index + 2 + skipSpace() + local second = parser(true) + if not second then + missName() + return nil + end + if isLocal then + createLocal(second, parseLocalAttrs()) + end + skipSpace() + if Tokens[Index + 1] ~= ',' then + return second + end + Index = Index + 2 + skipSeps() + local third = parser(true) + if not third then + missName() + return second + end + if isLocal then + createLocal(third, parseLocalAttrs()) + end + local rest = { third } + while true do + skipSpace() + if Tokens[Index + 1] ~= ',' then + return second, rest + end + Index = Index + 2 + skipSeps() + local name = parser(true) + if not name then + missName() + return second, rest + end + if isLocal then + createLocal(name, parseLocalAttrs()) + end + rest[#rest+1] = name + end +end + +local function bindValue(n, v, index, lastValue, isLocal, isSet) + if isLocal then + if v and v.special then + addSpecial(v.special, n) + end + elseif isSet then + n.type = GetToSetMap[n.type] or n.type + if n.type == 'setlocal' then + local loc = n.node + if loc.attrs then + pushError { + type = 'SET_CONST', + start = n.start, + finish = n.finish, + } + end + end + end + if not v and lastValue then + if lastValue.type == 'call' + or lastValue.type == 'varargs' then + v = lastValue + if not v.extParent then + v.extParent = {} + end + end + end + if v then + if v.type == 'call' + or v.type == 'varargs' then + local select = { + type = 'select', + sindex = index, + start = v.start, + finish = v.finish, + vararg = v + } + if v.parent then + v.extParent[#v.extParent+1] = select + else + v.parent = select + end + v = select + end + n.value = v + n.range = v.finish + v.parent = n + end +end + +local function parseMultiVars(n1, parser, isLocal) + local n2, nrest = parseVarTails(parser, isLocal) + skipSpace() + local v1, v2, vrest + local isSet + local max = 1 + if expectAssign(not isLocal) then + v1, v2, vrest = parseSetValues() + isSet = true + if not v1 then + missExp() + end + end + local index = 1 + bindValue(n1, v1, index, nil, isLocal, isSet) + local lastValue = v1 + local lastVar = n1 + if n2 then + max = 2 + if not v2 then + index = 2 + end + bindValue(n2, v2, index, lastValue, isLocal, isSet) + lastValue = v2 or lastValue + lastVar = n2 + pushActionIntoCurrentChunk(n2) + end + if nrest then + for i = 1, #nrest do + local n = nrest[i] + local v = vrest and vrest[i] + max = i + 2 + if not v then + index = index + 1 + end + bindValue(n, v, index, lastValue, isLocal, isSet) + lastValue = v or lastValue + lastVar = n + pushActionIntoCurrentChunk(n) + end + end + + if isLocal then + local effect = lastValue and lastValue.finish or lastVar.finish + n1.effect = effect + if n2 then + n2.effect = effect + end + if nrest then + for i = 1, #nrest do + nrest[i].effect = effect + end + end + end + + if v2 and not n2 then + v2.redundant = { + max = max, + passed = 2, + } + pushActionIntoCurrentChunk(v2) + end + if vrest then + for i = 1, #vrest do + local v = vrest[i] + if not nrest or not nrest[i] then + v.redundant = { + max = max, + passed = i + 2, + } + pushActionIntoCurrentChunk(v) + end + end + end + + return n1, isSet +end + +local function compileExpAsAction(exp) + pushActionIntoCurrentChunk(exp) + if GetToSetMap[exp.type] then + skipSpace() + local isLocal + if exp.type == 'getlocal' and exp[1] == State.ENVMode then + exp.special = nil + local loc = createLocal(exp, parseLocalAttrs()) + loc.locPos = exp.start + loc.effect = maxinteger + isLocal = true + skipSpace() + end + local action, isSet = parseMultiVars(exp, parseExp, isLocal) + if isSet + or action.type == 'getmethod' then + return action + end + end + + if exp.type == 'call' then + if exp.node.special == 'error' then + for i = #Chunk, 1, -1 do + local block = Chunk[i] + if block.type == 'ifblock' + or block.type == 'elseifblock' + or block.type == 'elseblock' + or block.type == 'function' then + block.hasError = true + break + end + end + end + return exp + end + + if exp.type == 'binary' then + if GetToSetMap[exp[1].type] then + local op = exp.op + if op.type == '==' then + pushError { + type = 'ERR_ASSIGN_AS_EQ', + start = op.start, + finish = op.finish, + fix = { + title = 'FIX_ASSIGN_AS_EQ', + { + start = op.start, + finish = op.finish, + text = '=', + } + } + } + return + end + end + end + + pushError { + type = 'EXP_IN_ACTION', + start = exp.start, + finish = exp.finish, + } + + return exp +end + +local function parseLocal() + local locPos = getPosition(Tokens[Index], 'left') + Index = Index + 2 + skipSpace() + local word = peekWord() + if not word then + missName() + return nil + end + + if word == 'function' then + local func = parseFunction(true, true) + local name = func.name + if name then + func.name = nil + name.value = func + name.vstart = func.start + name.range = func.finish + name.locPos = locPos + func.parent = name + pushActionIntoCurrentChunk(name) + return name + else + missName(func.keyword[2]) + pushActionIntoCurrentChunk(func) + return func + end + end + + local name = parseName(true) + if not name then + missName() + return nil + end + local loc = createLocal(name, parseLocalAttrs()) + loc.locPos = locPos + loc.effect = maxinteger + pushActionIntoCurrentChunk(loc) + skipSpace() + parseMultiVars(loc, parseName, true) + + return loc +end + +local function parseDo() + local doLeft = getPosition(Tokens[Index], 'left') + local doRight = getPosition(Tokens[Index] + 1, 'right') + local obj = { + type = 'do', + start = doLeft, + finish = doRight, + keyword = { + [1] = doLeft, + [2] = doRight, + }, + } + Index = Index + 2 + pushActionIntoCurrentChunk(obj) + pushChunk(obj) + parseActions() + popChunk() + if Tokens[Index + 1] == 'end' then + obj.finish = getPosition(Tokens[Index] + 2, 'right') + obj.keyword[3] = getPosition(Tokens[Index], 'left') + obj.keyword[4] = getPosition(Tokens[Index] + 2, 'right') + Index = Index + 2 + else + missEnd(doLeft, doRight) + end + if obj.locals then + LocalCount = LocalCount - #obj.locals + end + + return obj +end + +local function parseReturn() + local returnLeft = getPosition(Tokens[Index], 'left') + local returnRight = getPosition(Tokens[Index] + 5, 'right') + Index = Index + 2 + skipSpace() + local rtn = parseExpList(true) + if rtn then + rtn.type = 'return' + rtn.start = returnLeft + else + rtn = { + type = 'return', + start = returnLeft, + finish = returnRight, + } + end + pushActionIntoCurrentChunk(rtn) + for i = #Chunk, 1, -1 do + local block = Chunk[i] + if block.type == 'function' + or block.type == 'main' then + if not block.returns then + block.returns = {} + end + block.returns[#block.returns+1] = rtn + break + end + end + for i = #Chunk, 1, -1 do + local block = Chunk[i] + if block.type == 'ifblock' + or block.type == 'elseifblock' + or block.type == 'elseblock' + or block.type == 'function' then + block.hasReturn = true + break + end + end + + return rtn +end + +local function parseLabel() + local left = getPosition(Tokens[Index], 'left') + Index = Index + 2 + skipSpace() + local label = parseName() + skipSpace() + + if not label then + missName() + end + + if Tokens[Index + 1] == '::' then + Index = Index + 2 + else + if label then + missSymbol '::' + end + end + + if not label then + return nil + end + + label.type = 'label' + pushActionIntoCurrentChunk(label) + + local block = guide.getBlock(label) + if block then + if not block.labels then + block.labels = {} + end + local name = label[1] + local olabel = guide.getLabel(block, name) + if olabel then + if State.version == 'Lua 5.4' + or block == guide.getBlock(olabel) then + pushError { + type = 'REDEFINED_LABEL', + start = label.start, + finish = label.finish, + relative = { + { + olabel.start, + olabel.finish, + } + } + } + end + end + block.labels[name] = label + end + + if State.version == 'Lua 5.1' then + pushError { + type = 'UNSUPPORT_SYMBOL', + start = left, + finish = lastRightPosition(), + version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return + end + return label +end + +local function parseGoTo() + local start = getPosition(Tokens[Index], 'left') + Index = Index + 2 + skipSpace() + + local action = parseName() + if not action then + missName() + return nil + end + + action.type = 'goto' + action.keyStart = start + + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + if chunk.type == 'function' + or chunk.type == 'main' then + if not chunk.gotos then + chunk.gotos = {} + end + chunk.gotos[#chunk.gotos+1] = action + break + end + end + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + if chunk.type == 'ifblock' + or chunk.type == 'elseifblock' + or chunk.type == 'elseblock' then + chunk.hasGoTo = true + break + end + end + + pushActionIntoCurrentChunk(action) + return action +end + +local function parseIfBlock(parent) + local ifLeft = getPosition(Tokens[Index], 'left') + local ifRight = getPosition(Tokens[Index] + 1, 'right') + Index = Index + 2 + local ifblock = { + type = 'ifblock', + parent = parent, + start = ifLeft, + finish = ifRight, + keyword = { + [1] = ifLeft, + [2] = ifRight, + } + } + skipSpace() + local filter = parseExp() + if filter then + ifblock.filter = filter + ifblock.finish = filter.finish + filter.parent = ifblock + else + missExp() + end + skipSpace() + local thenToken = Tokens[Index + 1] + if thenToken == 'then' + or thenToken == 'do' then + ifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right') + ifblock.keyword[3] = getPosition(Tokens[Index], 'left') + ifblock.keyword[4] = ifblock.finish + if thenToken == 'do' then + pushError { + type = 'ERR_THEN_AS_DO', + start = ifblock.keyword[3], + finish = ifblock.keyword[4], + fix = { + title = 'FIX_THEN_AS_DO', + { + start = ifblock.keyword[3], + finish = ifblock.keyword[4], + text = 'then', + } + } + } + end + Index = Index + 2 + else + missSymbol 'then' + end + pushChunk(ifblock) + parseActions() + popChunk() + ifblock.finish = lastRightPosition() + if ifblock.locals then + LocalCount = LocalCount - #ifblock.locals + end + return ifblock +end + +local function parseElseIfBlock(parent) + local ifLeft = getPosition(Tokens[Index], 'left') + local ifRight = getPosition(Tokens[Index] + 5, 'right') + local elseifblock = { + type = 'elseifblock', + parent = parent, + start = ifLeft, + finish = ifRight, + keyword = { + [1] = ifLeft, + [2] = ifRight, + } + } + Index = Index + 2 + skipSpace() + local filter = parseExp() + if filter then + elseifblock.filter = filter + elseifblock.finish = filter.finish + filter.parent = elseifblock + else + missExp() + end + skipSpace() + local thenToken = Tokens[Index + 1] + if thenToken == 'then' + or thenToken == 'do' then + elseifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right') + elseifblock.keyword[3] = getPosition(Tokens[Index], 'left') + elseifblock.keyword[4] = elseifblock.finish + if thenToken == 'do' then + pushError { + type = 'ERR_THEN_AS_DO', + start = elseifblock.keyword[3], + finish = elseifblock.keyword[4], + fix = { + title = 'FIX_THEN_AS_DO', + { + start = elseifblock.keyword[3], + finish = elseifblock.keyword[4], + text = 'then', + } + } + } + end + Index = Index + 2 + else + missSymbol 'then' + end + pushChunk(elseifblock) + parseActions() + popChunk() + elseifblock.finish = lastRightPosition() + if elseifblock.locals then + LocalCount = LocalCount - #elseifblock.locals + end + return elseifblock +end + +local function parseElseBlock(parent) + local ifLeft = getPosition(Tokens[Index], 'left') + local ifRight = getPosition(Tokens[Index] + 3, 'right') + local elseblock = { + type = 'elseblock', + parent = parent, + start = ifLeft, + finish = ifRight, + keyword = { + [1] = ifLeft, + [2] = ifRight, + } + } + Index = Index + 2 + skipSpace() + pushChunk(elseblock) + parseActions() + popChunk() + elseblock.finish = lastRightPosition() + if elseblock.locals then + LocalCount = LocalCount - #elseblock.locals + end + return elseblock +end + +local function parseIf() + local token = Tokens[Index + 1] + local left = getPosition(Tokens[Index], 'left') + local action = { + type = 'if', + start = left, + finish = getPosition(Tokens[Index] + #token - 1, 'right'), + } + pushActionIntoCurrentChunk(action) + if token ~= 'if' then + missSymbol('if', left, left) + end + local hasElse + while true do + local word = Tokens[Index + 1] + local child + if word == 'if' then + child = parseIfBlock(action) + elseif word == 'elseif' then + child = parseElseIfBlock(action) + elseif word == 'else' then + child = parseElseBlock(action) + end + if not child then + break + end + if hasElse then + pushError { + type = 'BLOCK_AFTER_ELSE', + start = child.start, + finish = child.finish, + } + end + if word == 'else' then + hasElse = true + end + action[#action+1] = child + action.finish = child.finish + skipSpace() + end + + if Tokens[Index + 1] == 'end' then + action.finish = getPosition(Tokens[Index] + 2, 'right') + Index = Index + 2 + else + missEnd(action[1].keyword[1], action[1].keyword[2]) + end + + return action +end + +local function parseFor() + local action = { + type = 'for', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + 2, 'right'), + keyword = {}, + } + action.keyword[1] = action.start + action.keyword[2] = action.finish + Index = Index + 2 + pushActionIntoCurrentChunk(action) + pushChunk(action) + skipSpace() + local nameOrList = parseNameOrList(action) + if not nameOrList then + missName() + end + skipSpace() + -- for i = + if expectAssign() then + action.type = 'loop' + + skipSpace() + local expList = parseExpList() + local name + if nameOrList then + if nameOrList.type == 'name' then + name = nameOrList + else + name = nameOrList[1] + end + end + if name then + local loc = createLocal(name) + loc.parent = action + action.finish = name.finish + action.loc = loc + end + if expList then + expList.parent = action + local value = expList[1] + if value then + value.parent = expList + action.init = value + action.finish = expList[#expList].finish + end + local max = expList[2] + if max then + max.parent = expList + action.max = max + action.finish = max.finish + else + pushError { + type = 'MISS_LOOP_MAX', + start = lastRightPosition(), + finish = lastRightPosition(), + } + end + local step = expList[3] + if step then + step.parent = expList + action.step = step + action.finish = step.finish + end + else + pushError { + type = 'MISS_LOOP_MIN', + start = lastRightPosition(), + finish = lastRightPosition(), + } + end + + if action.loc then + action.loc.effect = action.finish + end + elseif Tokens[Index + 1] == 'in' then + action.type = 'in' + local inLeft = getPosition(Tokens[Index], 'left') + local inRight = getPosition(Tokens[Index] + 1, 'right') + Index = Index + 2 + skipSpace() + + local exps = parseExpList() + + action.finish = inRight + action.keyword[3] = inLeft + action.keyword[4] = inRight + + local list + if nameOrList and nameOrList.type == 'name' then + list = { + type = 'list', + start = nameOrList.start, + finish = nameOrList.finish, + parent = action, + [1] = nameOrList, + } + else + list = nameOrList + end + + if exps then + local lastExp = exps[#exps] + if lastExp then + action.finish = lastExp.finish + end + + action.exps = exps + exps.parent = action + for i = 1, #exps do + local exp = exps[i] + exp.parent = exps + end + else + missExp() + end + + if list then + local lastName = list[#list] + list.range = lastName and lastName.range or inRight + action.keys = list + for i = 1, #list do + local loc = createLocal(list[i]) + loc.parent = action + loc.effect = action.finish + end + end + else + missSymbol 'in' + end + + skipSpace() + local doToken = Tokens[Index + 1] + if doToken == 'do' + or doToken == 'then' then + local left = getPosition(Tokens[Index], 'left') + local right = getPosition(Tokens[Index] + #doToken - 1, 'right') + action.finish = left + action.keyword[#action.keyword+1] = left + action.keyword[#action.keyword+1] = right + if doToken == 'then' then + pushError { + type = 'ERR_DO_AS_THEN', + start = left, + finish = right, + fix = { + title = 'FIX_DO_AS_THEN', + { + start = left, + finish = right, + text = 'do', + } + } + } + end + Index = Index + 2 + else + missSymbol 'do' + end + + skipSpace() + parseActions() + popChunk() + + skipSpace() + if Tokens[Index + 1] == 'end' then + action.finish = getPosition(Tokens[Index] + 2, 'right') + action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left') + action.keyword[#action.keyword+1] = action.finish + Index = Index + 2 + else + missEnd(action.keyword[1], action.keyword[2]) + end + + if action.locals then + LocalCount = LocalCount - #action.locals + end + + return action +end + +local function parseWhile() + local action = { + type = 'while', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + 4, 'right'), + keyword = {}, + } + action.keyword[1] = action.start + action.keyword[2] = action.finish + Index = Index + 2 + + skipSpace() + local nextToken = Tokens[Index + 1] + local filter = nextToken ~= 'do' + and nextToken ~= 'then' + and parseExp() + if filter then + action.filter = filter + action.finish = filter.finish + filter.parent = action + else + missExp() + end + + skipSpace() + local doToken = Tokens[Index + 1] + if doToken == 'do' + or doToken == 'then' then + local left = getPosition(Tokens[Index], 'left') + local right = getPosition(Tokens[Index] + #doToken - 1, 'right') + action.finish = left + action.keyword[#action.keyword+1] = left + action.keyword[#action.keyword+1] = right + if doToken == 'then' then + pushError { + type = 'ERR_DO_AS_THEN', + start = left, + finish = right, + fix = { + title = 'FIX_DO_AS_THEN', + { + start = left, + finish = right, + text = 'do', + } + } + } + end + Index = Index + 2 + else + missSymbol 'do' + end + + pushActionIntoCurrentChunk(action) + pushChunk(action) + skipSpace() + parseActions() + popChunk() + + skipSpace() + if Tokens[Index + 1] == 'end' then + action.finish = getPosition(Tokens[Index] + 2, 'right') + action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left') + action.keyword[#action.keyword+1] = action.finish + Index = Index + 2 + else + missEnd(action.keyword[1], action.keyword[2]) + end + + if action.locals then + LocalCount = LocalCount - #action.locals + end + + return action +end + +local function parseRepeat() + local action = { + type = 'repeat', + start = getPosition(Tokens[Index], 'left'), + finish = getPosition(Tokens[Index] + 5, 'right'), + keyword = {}, + } + action.keyword[1] = action.start + action.keyword[2] = action.finish + Index = Index + 2 + + pushActionIntoCurrentChunk(action) + pushChunk(action) + skipSpace() + parseActions() + + skipSpace() + if Tokens[Index + 1] == 'until' then + action.finish = getPosition(Tokens[Index] + 4, 'right') + action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left') + action.keyword[#action.keyword+1] = action.finish + Index = Index + 2 + + skipSpace() + local filter = parseExp() + if filter then + action.filter = filter + filter.parent = action + else + missExp() + end + + else + missSymbol 'until' + end + + popChunk() + if action.filter then + action.finish = action.filter.finish + end + + if action.locals then + LocalCount = LocalCount - #action.locals + end + + return action +end + +local function parseBreak() + local returnLeft = getPosition(Tokens[Index], 'left') + local returnRight = getPosition(Tokens[Index] + #Tokens[Index + 1] - 1, 'right') + Index = Index + 2 + skipSpace() + local action = { + type = 'break', + start = returnLeft, + finish = returnRight, + } + + local ok + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + if chunk.type == 'function' then + break + end + if chunk.type == 'while' + or chunk.type == 'in' + or chunk.type == 'loop' + or chunk.type == 'repeat' + or chunk.type == 'for' then + if not chunk.breaks then + chunk.breaks = {} + end + chunk.breaks[#chunk.breaks+1] = action + ok = true + break + end + end + for i = #Chunk, 1, -1 do + local chunk = Chunk[i] + if chunk.type == 'ifblock' + or chunk.type == 'elseifblock' + or chunk.type == 'elseblock' then + chunk.hasBreak = true + break + end + end + if not ok and Mode == 'Lua' then + pushError { + type = 'BREAK_OUTSIDE', + start = action.start, + finish = action.finish, + } + end + + pushActionIntoCurrentChunk(action) + return action +end + +function parseAction() + local token = Tokens[Index + 1] + + if token == '::' then + return parseLabel() + end + + if token == 'local' then + return parseLocal() + end + + if token == 'if' + or token == 'elseif' + or token == 'else' then + return parseIf() + end + + if token == 'for' then + return parseFor() + end + + if token == 'do' then + return parseDo() + end + + if token == 'return' then + return parseReturn() + end + + if token == 'break' then + return parseBreak() + end + + if token == 'continue' and State.options.nonstandardSymbol['continue'] then + return parseBreak() + end + + if token == 'while' then + return parseWhile() + end + + if token == 'repeat' then + return parseRepeat() + end + + if token == 'goto' and isKeyWord 'goto' then + return parseGoTo() + end + + if token == 'function' then + local exp = parseFunction(false, true) + local name = exp.name + if name then + exp.name = nil + name.type = GetToSetMap[name.type] + name.value = exp + name.vstart = exp.start + name.range = exp.finish + exp.parent = name + if name.type == 'setlocal' then + local loc = name.node + if loc.attrs then + pushError { + type = 'SET_CONST', + start = name.start, + finish = name.finish, + } + end + end + pushActionIntoCurrentChunk(name) + return name + else + pushActionIntoCurrentChunk(exp) + missName(exp.keyword[2]) + return exp + end + end + + local exp = parseExp(true) + if exp then + local action = compileExpAsAction(exp) + if action then + return action + end + end + return nil, true +end + +local function skipFirstComment() + if Tokens[Index + 1] ~= '#' then + return + end + while true do + Index = Index + 2 + local token = Tokens[Index + 1] + if not token then + break + end + if NLMap[token] then + skipNL() + break + end + end +end + +local function parseLua() + local main = { + type = 'main', + start = 0, + finish = 0, + } + pushChunk(main) + createLocal{ + type = 'local', + start = -1, + finish = -1, + effect = -1, + parent = main, + tag = '_ENV', + special= '_G', + [1] = State.ENVMode, + } + LocalCount = 0 + skipFirstComment() + while true do + parseActions() + if Index <= #Tokens then + unknownSymbol() + Index = Index + 2 + else + break + end + end + popChunk() + main.finish = getPosition(#Lua, 'right') + + return main +end + +local function initState(lua, version, options) + Lua = lua + Line = 0 + LineOffset = 1 + LastTokenFinish = 0 + LocalCount = 0 + Chunk = {} + Tokens = tokens(lua) + Index = 1 + ---@class parser.state + ---@field uri uri + local state = { + version = version, + lua = lua, + ast = {}, + errs = {}, + comms = {}, + lines = { + [0] = 1, + }, + options = options or {}, + } + if not state.options.nonstandardSymbol then + state.options.nonstandardSymbol = {} + end + State = state + if version == 'Lua 5.1' or version == 'LuaJIT' then + state.ENVMode = '@fenv' + else + state.ENVMode = '_ENV' + end + + pushError = function (err) + local errs = state.errs + if err.finish < err.start then + err.finish = err.start + end + local last = errs[#errs] + if last then + if last.start <= err.start and last.finish >= err.finish then + return + end + end + err.level = err.level or 'Error' + errs[#errs+1] = err + return err + end + + state.pushError = pushError +end + +return function (lua, mode, version, options) + Mode = mode + initState(lua, version, options) + skipSpace() + if mode == 'Lua' then + State.ast = parseLua() + elseif mode == 'Nil' then + State.ast = parseNil() + elseif mode == 'Boolean' then + State.ast = parseBoolean() + elseif mode == 'String' then + State.ast = parseString() + elseif mode == 'Number' then + State.ast = parseNumber() + elseif mode == 'Name' then + State.ast = parseName() + elseif mode == 'Exp' then + State.ast = parseExp() + elseif mode == 'Action' then + State.ast = parseAction() + end + + if State.ast then + State.ast.state = State + end + + while true do + if Index <= #Tokens then + unknownSymbol() + Index = Index + 2 + else + break + end + end + + return State end diff --git a/script/parser/grammar.lua b/script/parser/grammar.lua deleted file mode 100644 index a28b7950..00000000 --- a/script/parser/grammar.lua +++ /dev/null @@ -1,573 +0,0 @@ -local re = require 'parser.relabel' -local m = require 'lpeglabel' -local ast = require 'parser.ast' - -local scriptBuf = '' -local compiled = {} -local defs = ast.defs - --- goto 可以作为名字,合法性之后处理 -local RESERVED = { - ['and'] = true, - ['break'] = true, - ['do'] = true, - ['else'] = true, - ['elseif'] = true, - ['end'] = true, - ['false'] = true, - ['for'] = true, - ['function'] = true, - ['if'] = true, - ['in'] = true, - ['local'] = true, - ['nil'] = true, - ['not'] = true, - ['or'] = true, - ['repeat'] = true, - ['return'] = true, - ['then'] = true, - ['true'] = true, - ['until'] = true, - ['while'] = true, -} - -defs.nl = (m.P'\r\n' + m.S'\r\n') -defs.s = m.S' \t' -defs.S = - defs.s -defs.ea = '\a' -defs.eb = '\b' -defs.ef = '\f' -defs.en = '\n' -defs.er = '\r' -defs.et = '\t' -defs.ev = '\v' -defs['nil'] = m.Cp() / function () return nil end -defs['false'] = m.Cp() / function () return false end - -defs.NotReserved = function (_, _, str) - if RESERVED[str] then - return false - end - return true -end -defs.Reserved = function (_, _, str) - if RESERVED[str] then - return true - end - return false -end -defs.None = function () end -defs.np = m.Cp() / function (n) return n+1 end -defs.NameBody = m.R('az', 'AZ', '__', '\x80\xff') * m.R('09', 'az', 'AZ', '__', '\x80\xff')^0 -defs.NoNil = function (o) - if o == nil then - return - end - return o -end - -m.setmaxstack(1000) - -local eof = re.compile '!. / %{SYNTAX_ERROR}' - -local function grammar(tag) - return function (script) - scriptBuf = script .. '\r\n' .. scriptBuf - compiled[tag] = re.compile(scriptBuf, defs) * eof - end -end - -local function errorpos(pos, err) - return { - type = 'UNKNOWN', - start = pos or 0, - finish = pos or 0, - err = err, - } -end - -grammar 'Comment' [[ -Comment <- LongComment - / '--' ShortComment -LongComment <- ({} '--[' {} {:eq: '='* :} {} '[' %nl? - {(!CommentClose .)*} - ((CommentClose / %nil) {})) - -> LongComment - / ( - {} '/*' {} %nl? - {(!'*/' .)*} - {} '*/' {} - ) - -> CLongComment -CommentClose <- {']' =eq ']'} -ShortComment <- ({} {(!%nl .)*} {}) - -> ShortComment -]] - -grammar 'Sp' [[ -Sp <- (Comment / %nl / %s)* -Sps <- (Comment / %nl / %s)+ -]] - -grammar 'Common' [[ -Word <- [a-zA-Z0-9_] -Cut <- !Word -X16 <- [a-fA-F0-9] -Rest <- (!%nl .)* - -AND <- Sp {'and'} Cut -BREAK <- Sp 'break' Cut -FALSE <- Sp 'false' Cut -GOTO <- Sp 'goto' Cut -LOCAL <- Sp 'local' Cut -NIL <- Sp 'nil' Cut -NOT <- Sp 'not' Cut -OR <- Sp {'or'} Cut -RETURN <- Sp 'return' Cut -TRUE <- Sp 'true' Cut -CONTINUE <- Sp 'continue' Cut - -DO <- Sp {} 'do' {} Cut - / Sp({} 'then' {} Cut) -> ErrDo -IF <- Sp {} 'if' {} Cut -ELSE <- Sp {} 'else' {} Cut -ELSEIF <- Sp {} 'elseif' {} Cut -END <- Sp {} 'end' {} Cut -FOR <- Sp {} 'for' {} Cut -FUNCTION <- Sp {} 'function' {} Cut -IN <- Sp {} 'in' {} Cut -REPEAT <- Sp {} 'repeat' {} Cut -THEN <- Sp {} 'then' {} Cut - / Sp({} 'do' {} Cut) -> ErrThen -UNTIL <- Sp {} 'until' {} Cut -WHILE <- Sp {} 'while' {} Cut - - -Esc <- '\' -> '' - EChar -EChar <- 'a' -> ea - / 'b' -> eb - / 'f' -> ef - / 'n' -> en - / 'r' -> er - / 't' -> et - / 'v' -> ev - / '\' - / '"' - / "'" - / %nl - / ('z' (%nl / %s)*) -> '' - / ({} 'x' {X16 X16}) -> Char16 - / ([0-9] [0-9]? [0-9]?) -> Char10 - / ('u{' {} {Word*} '}') -> CharUtf8 - -- 错误处理 - / 'x' {} -> MissEscX - / 'u' !'{' {} -> MissTL - / 'u{' Word* !'}' {} -> MissTR - / {} -> ErrEsc - -BOR <- Sp {'|'} -BXOR <- Sp {'~'} !'=' -BAND <- Sp {'&'} -Bshift <- Sp {BshiftList} -BshiftList <- '<<' - / '>>' -Concat <- Sp {'..'} -Adds <- Sp {AddsList} -AddsList <- '+' - / '-' -Muls <- Sp {MulsList} -MulsList <- '*' - / '//' - / '/' - / '%' -Unary <- Sp {} {UnaryList} -UnaryList <- NOT - / '#' - / '-' - / '~' !'=' -POWER <- Sp {'^'} - -BinaryOp <-( Sp {} {'or' / '||'} Cut - / Sp {} {'and' / '&&'} Cut - / Sp {} {'<=' / '>=' / '<'!'<' / '>'!'>' / '~=' / '==' / '!='} - / Sp {} ({} '=' {}) -> ErrEQ - / Sp {} ({} '!=' {}) -> ErrUEQ - / Sp {} {'|'} - / Sp {} {'~'} - / Sp {} {'&'} - / Sp {} {'<<' / '>>'} - / Sp {} {'..'} !'.' - / Sp {} {'+' / '-'} - / Sp {} {'*' / '//' / '/' / '%'} - / Sp {} {'^'} - )-> BinaryOp -UnaryOp <-( Sp {} {'not' Cut / '#' / '~' !'=' / '-' !'-' / '!' !'='} - )-> UnaryOp - -PL <- Sp '(' -PR <- Sp ')' -BL <- Sp '[' !'[' !'=' -BR <- Sp ']' -TL <- Sp '{' -TR <- Sp '}' -COMMA <- Sp ({} ',') - -> COMMA -SEMICOLON <- Sp ({} ';') - -> SEMICOLON -DOTS <- Sp ({} '...') - -> DOTS -DOT <- Sp ({} '.' !'.') - -> DOT -COLON <- Sp ({} ':' !':') - -> COLON -LABEL <- Sp '::' -ASSIGN <- Sp '=' !'=' - / Sp ({} {'+=' / '-=' / '*=' / '\='}) - -> ASSIGN -AssignOrEQ <- Sp ({} '==' {}) - -> ErrAssign - / ASSIGN - -DirtyBR <- BR / {} -> MissBR -DirtyTR <- TR / {} -> MissTR -DirtyPR <- PR / {} -> MissPR -DirtyLabel <- LABEL / {} -> MissLabel -NeedEnd <- END / {} -> MissEnd -NeedDo <- DO / {} -> MissDo -NeedAssign <- ASSIGN / {} -> MissAssign -NeedComma <- COMMA / {} -> MissComma -NeedIn <- IN / {} -> MissIn -NeedUntil <- UNTIL / {} -> MissUntil -NeedThen <- THEN / {} -> MissThen -]] - -grammar 'Nil' [[ -Nil <- Sp ({} -> Nil) NIL -]] - -grammar 'Boolean' [[ -Boolean <- Sp ({} -> True) TRUE - / Sp ({} -> False) FALSE -]] - -grammar 'String' [[ -String <- Sp ({} StringDef {}) - -> String -StringDef <- {'"'} - {~(Esc / !%nl !'"' .)*~} -> 1 - ('"' / {} -> MissQuote1) - / {"'"} - {~(Esc / !%nl !"'" .)*~} -> 1 - ("'" / {} -> MissQuote2) - / {'`'} - {(!%nl !'`' .)*} -> 1 - ('`' / {} -> MissQuote3) - / ('[' {} {:eq: '='* :} {} '[' %nl? - {(!StringClose .)*} -> 1 - (StringClose / {})) - -> LongString -StringClose <- ']' =eq ']' -]] - -grammar 'Number' [[ -Number <- Sp ({} {~ '-'? NumberDef ~} {}) -> Number - NumberSuffix? - ErrNumber? -NumberDef <- Number16 / Integer2 / Number10 -NumberSuffix<- ({} {[uU]? [lL] [lL]}) -> FFINumber - / ({} {[iI]}) -> ImaginaryNumber -ErrNumber <- ({} {([0-9a-zA-Z] / '.')+}) -> UnknownSymbol - -Number10 <- Float10 Float10Exp? - / Integer10 Float10? Float10Exp? -Integer10 <- [0-9]+ ('.' [0-9]*)? -Float10 <- '.' [0-9]+ -Float10Exp <- [eE] [+-]? [0-9]+ - / ({} [eE] [+-]? {}) -> MissExponent - -Number16 <- '0' [xX] Float16 Float16Exp? - / '0' [xX] Integer16 Float16? Float16Exp? -Integer16 <- X16+ ('.' X16*)? - / ({} {Word*}) -> MustX16 -Float16 <- '.' X16+ - / '.' ({} {Word*}) -> MustX16 -Float16Exp <- [pP] [+-]? [0-9]+ - / ({} [pP] [+-]? {}) -> MissExponent - -Integer2 <- ({} '0' [bB] {[01]+}) - -> Integer2 -]] - -grammar 'Name' [[ -Name <- Sp ({} NameBody {}) - -> Name -NameBody <- {%NameBody} -KeyWord <- Sp NameBody=>Reserved -MustName <- Name / DirtyName -DirtyName <- {} -> DirtyName -]] - -grammar 'DocType' [[ -DocType <- (!%nl !')' !',' DocChar)+ -DocChar <- '(' (!%nl !')' .)+ ')'? - / '<' (!%nl !'>' .)+ '>'? - / . -]] - -grammar 'Exp' [[ -Exp <- (UnUnit BinUnit*) - -> Binary -BinUnit <- (BinaryOp UnUnit?) - -> SubBinary -UnUnit <- Number - / (UnaryOp+ (ExpUnit / MissExp)) - -> Unary - / ExpUnit -ExpUnit <- Nil - / Boolean - / String - / Number - / Dots - / Table - / ExpFunction - / Simple - -Simple <- {| Prefix (Sp Suffix)* |} - -> Simple -Prefix <- Sp ({} PL DirtyExp DirtyPR {}) - -> Paren - / Single -Single <- !FUNCTION Name - -> Single -Suffix <- SuffixWithoutCall - / ({} PL SuffixCall DirtyPR {}) - -> Call -SuffixCall <- Sp ({} {| (COMMA / CallArg)+ |} {}) - -> PackExpList - / %nil -CallArg <- Sp (Name {} {'?'? ':'} Sps DocType) - -> CallArgSnip - / Exp->NoNil -SuffixWithoutCall - <- (DOT (Name / MissField)) - -> GetField - / ({} BL DirtyExp DirtyBR {}) - -> GetIndex - / (COLON (Name / MissMethod) NeedCall) - -> GetMethod - / ({} {| Table |} {}) - -> Call - / ({} {| String |} {}) - -> Call -NeedCall <- (!(Sp CallStart) {} -> MissPL)? -MissField <- {} -> MissField -MissMethod <- {} -> MissMethod -CallStart <- PL - / TL - / '"' - / "'" - / '[' '='* '[' - -DirtyExp <- !THEN !DO !END Exp - / {} -> DirtyExp -MaybeExp <- Exp / MissExp -MissExp <- {} -> MissExp -ExpList <- Sp {| MaybeExp (Sp ',' MaybeExp)* |} - -Dots <- DOTS - -> VarArgs - -Table <- Sp ({} TL {| TableField* |} DirtyTR {}) - -> Table -TableField <- COMMA - / SEMICOLON - / Dots - / NewIndex - / NewField - / TableExp -Index <- BL DirtyExp DirtyBR -NewIndex <- Sp ({} Index NeedAssign DirtyExp {}) - -> NewIndex -NewField <- Sp ({} MustName ASSIGN DirtyExp {}) - -> NewField -TableExp <- Sp ({} Exp {}) - -> TableExp - -ExpFunction <- Function - -> ExpFunction -Function <- FunctionBody - -> Function -FunctionBody - <- FUNCTION FuncName FuncArgs - {| (!END Action)* |} - NeedEnd - / FUNCTION FuncName FuncArgsMiss - {| %nil |} - NeedEnd -FuncName <- !END {| Single (Sp SuffixWithoutCall)* |} - -> Simple - / %nil - -FuncArgs <- Sp ({} PL {| FuncArg+ |} DirtyPR {}) - -> FuncArgs - / PL DirtyPR %nil -FuncArgsMiss<- {} -> MissPL DirtyPR %nil -FuncArg <- DOTS - / Name - / COMMA - --- 纯占位,修改了 `relabel.lua` 使重复定义不抛错 -Action <- !END . -]] - -grammar 'Action' [[ -Action <- Sp (CrtAction / UnkAction) -CrtAction <- Semicolon - / Do - / Break - / Return - / Label - / GoTo - / If - / For - / While - / Repeat - / NamedFunction - / LocalFunction - / Local - / Set - / Continue - / Call - / ExpInAction -UnkAction <- ({} {Word+}) - -> UnknownAction - / ({} '//' {} (LongComment / ShortComment) {}) - -> CCommentPrefix - / ({} {. (!Sps !CrtAction .)*}) - -> UnknownAction -ExpInAction <- Sp ({} Exp {}) - -> ExpInAction - -Semicolon <- Sp ';' -SimpleList <- {| Simple (Sp ',' Simple)* |} - -Do <- Sp ({} - 'do' Cut - {| (!END Action)* |} - NeedEnd) - -> Do - -Break <- Sp ({} BREAK {}) - -> Break - -Continue <- Sp ({} CONTINUE {}) - => RTContinue - -> Continue - -Return <- Sp ({} RETURN ReturnExpList {}) - -> Return -ReturnExpList - <- Sp !END !ELSEIF !ELSE {| Exp (Sp ',' MaybeExp)* |} - / Sp {| %nil |} - -Label <- Sp ({} LABEL MustName DirtyLabel {}) - -> Label - -GoTo <- Sp ({} GOTO MustName {}) - -> GoTo - -If <- Sp ({} {| IfHead IfBody* |} NeedEnd) - -> If - -IfHead <- Sp (IfPart {}) -> IfBlock - / Sp (ElseIfPart {}) -> ElseIfBlock - / Sp (ElsePart {}) -> ElseBlock -IfBody <- Sp (ElseIfPart {}) -> ElseIfBlock - / Sp (ElsePart {}) -> ElseBlock -IfPart <- IF DirtyExp NeedThen - {| (!ELSEIF !ELSE !END Action)* |} -ElseIfPart <- ELSEIF DirtyExp NeedThen - {| (!ELSEIF !ELSE !END Action)* |} -ElsePart <- ELSE - {| (!ELSEIF !ELSE !END Action)* |} - -For <- Loop / In - -Loop <- LoopBody - -> Loop -LoopBody <- FOR LoopArgs NeedDo - {} {| (!END Action)* |} - NeedEnd -LoopArgs <- MustName AssignOrEQ - ({} {| (COMMA / !DO !END Exp->NoNil)* |} {}) - -> PackLoopArgs - -In <- InBody - -> In -InBody <- FOR InNameList NeedIn InExpList NeedDo - {} {| (!END Action)* |} - NeedEnd -InNameList <- ({} {| (COMMA / !IN !DO !END Name->NoNil)* |} {}) - -> PackInNameList -InExpList <- ({} {| (COMMA / !DO !DO !END Exp->NoNil)* |} {}) - -> PackInExpList - -While <- WhileBody - -> While -WhileBody <- WHILE DirtyExp NeedDo - {| (!END Action)* |} - NeedEnd - -Repeat <- (RepeatBody {}) - -> Repeat -RepeatBody <- REPEAT - {| (!UNTIL Action)* |} - NeedUntil DirtyExp - -LocalAttr <- {| (Sp '<' Sp MustName Sp LocalAttrEnd)+ |} - -> LocalAttr -LocalAttrEnd<- ({} '>' &'=') -> MissSpaceBetween - / '>' - / {} -> MissGT -Local <- Sp ({} LOCAL LocalNameList ((AssignOrEQ ExpList) / %nil) {}) - -> Local -Set <- Sp ({} SimpleList AssignOrEQ {} ExpList {}) - -> Set -LocalNameList - <- {| LocalName (Sp ',' LocalName)* |} -LocalName <- (MustName LocalAttr?) - -> LocalName - -NamedFunction - <- Function - -> NamedFunction - -Call <- Simple - -> SimpleCall - -LocalFunction - <- Sp ({} LOCAL Function) - -> LocalFunction -]] - -grammar 'Lua' [[ -Lua <- Head? - ({} {| Action* |} {}) -> Lua - Sp -Head <- '#' (!%nl .)* -]] - -return function (lua, mode) - local gram = compiled[mode] or compiled['Lua'] - local r, _, pos = gram:match(lua) - if not r then - local err = errorpos(pos) - return nil, err - end - if type(r) ~= 'table' then - return nil - end - - return r -end diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 06169b09..e4faf47f 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -4,15 +4,18 @@ local type = type ---@class parser.object ---@field bindDocs parser.object[] ---@field bindGroup parser.object[] ----@field bindSources parser.object[] +---@field bindSource parser.object ---@field value parser.object ---@field parent parser.object ---@field type string ---@field special string ---@field tag string ----@field args parser.object[] +---@field args { [integer]: parser.object, start: integer, finish: integer } ---@field locals parser.object[] ----@field returns parser.object[] +---@field returns? parser.object[] +---@field breaks? parser.object[] +---@field exps parser.object[] +---@field keys parser.object ---@field uri uri ---@field start integer ---@field finish integer @@ -42,6 +45,7 @@ local type = type ---@field exp parser.object ---@field alias parser.object ---@field class parser.object +---@field enum parser.object ---@field vararg parser.object ---@field param parser.object ---@field overload parser.object @@ -49,6 +53,8 @@ local type = type ---@field upvalues table<string, string[]> ---@field ref parser.object[] ---@field returnIndex integer +---@field assignIndex integer +---@field docIndex integer ---@field docs parser.object[] ---@field state table ---@field comment table @@ -65,6 +71,8 @@ local type = type ---@field hasGoTo? true ---@field hasReturn? true ---@field hasBreak? true +---@field hasError? true +---@field [integer] parser.object|any ---@field _root parser.object ---@class guide @@ -134,6 +142,7 @@ local childMap = { ['doc.class'] = {'class', '#extends', '#signs', 'comment'}, ['doc.type'] = {'#types', 'name', 'comment'}, ['doc.alias'] = {'alias', 'extends', 'comment'}, + ['doc.enum'] = {'enum', 'extends', 'comment'}, ['doc.param'] = {'param', 'extends', 'comment'}, ['doc.return'] = {'#returns', 'comment'}, ['doc.field'] = {'field', 'extends', 'comment'}, @@ -154,6 +163,7 @@ local childMap = { ['doc.as'] = {'as'}, ['doc.cast'] = {'loc', '#casts'}, ['doc.cast.block'] = {'extends'}, + ['doc.operator'] = {'op', 'exp', 'extends'} } ---@type table<string, fun(obj: parser.object, list: parser.object[])> @@ -194,6 +204,43 @@ end return f end}) +local eachChildMap = setmetatable({}, {__index = function (self, name) + local defs = childMap[name] + if not defs then + self[name] = false + return false + end + local text = {} + text[#text+1] = 'local obj, callback = ...' + for _, def in ipairs(defs) do + if def == '#' then + text[#text+1] = [[ +for i = 1, #obj do + callback(obj[i]) +end +]] + elseif type(def) == 'string' and def:sub(1, 1) == '#' then + local key = def:sub(2) + text[#text+1] = ([[ +local childs = obj.%s +if childs then + for i = 1, #childs do + callback(childs[i]) + end +end +]]):format(key) + elseif type(def) == 'string' then + text[#text+1] = ('callback(obj.%s)'):format(def) + else + text[#text+1] = ('callback(obj[%q])'):format(def) + end + end + local buf = table.concat(text, '\n') + local f = load(buf, buf, 't') + self[name] = f + return f +end}) + m.actionMap = { ['main'] = {'#'}, ['repeat'] = {'#'}, @@ -209,34 +256,8 @@ m.actionMap = { ['funcargs'] = {'#'}, } -local inf = 1 / 0 -local nan = 0 / 0 - -local function isInteger(n) - if math.type then - return math.type(n) == 'integer' - else - return type(n) == 'number' and n % 1 == 0 - end -end - -local function formatNumber(n) - if n == inf - or n == -inf - or n == nan - or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则 - return ('%q'):format(n) - end - if isInteger(n) then - return tostring(n) - end - local str = ('%.10f'):format(n) - str = str:gsub('%.?0*$', '') - return str -end - --- 是否是字面量 ----@param obj parser.object +---@param obj table ---@return boolean function m.isLiteral(obj) local tp = obj.type @@ -252,6 +273,8 @@ function m.isLiteral(obj) or tp == 'doc.type.string' or tp == 'doc.type.integer' or tp == 'doc.type.boolean' + or tp == 'doc.type.code' + or tp == 'doc.type.array' end --- 获取字面量 @@ -273,7 +296,7 @@ end --- 寻找父函数 ---@param obj parser.object ----@return parser.object +---@return parser.object? function m.getParentFunction(obj) for _ = 1, 10000 do obj = obj.parent @@ -290,7 +313,7 @@ end --- 寻找所在区块 ---@param obj parser.object ----@return parser.object +---@return parser.object? function m.getBlock(obj) for _ = 1, 10000 do if not obj then @@ -319,7 +342,7 @@ end --- 寻找所在父区块 ---@param obj parser.object ----@return parser.object +---@return parser.object? function m.getParentBlock(obj) for _ = 1, 10000 do obj = obj.parent @@ -336,7 +359,7 @@ end --- 寻找所在可break的父区块 ---@param obj parser.object ----@return parser.object +---@return parser.object? function m.getBreakBlock(obj) for _ = 1, 10000 do obj = obj.parent @@ -373,7 +396,7 @@ end --- 寻找所在父类型 ---@param obj parser.object ----@return parser.object +---@return parser.object? function m.getParentType(obj, want) for _ = 1, 10000 do obj = obj.parent @@ -406,8 +429,7 @@ function m.getRoot(obj) end local parent = obj.parent if not parent then - log.error('Can not find out root:', obj.type) - return nil + error('Can not find out root:' .. tostring(obj.type)) end obj = parent end @@ -436,56 +458,54 @@ function m.getENV(source, start) or m.getLocal(source, '@fenv', start) end ---- 寻找函数的不定参数,返回不定参在第几个参数上,以及该参数对象。 ---- 如果函数是主函数,则返回`0, nil`。 ----@return table ----@return integer -function m.getFunctionVarArgs(func) - if func.type == 'main' then - return 0, nil - end - if func.type ~= 'function' then - return nil, nil - end - local args = func.args - if not args then - return nil, nil - end - for i = 1, #args do - local arg = args[i] - if arg.type == '...' then - return i, arg - end - end - return nil, nil -end - --- 获取指定区块中可见的局部变量 ---@param source parser.object ---@param name string # 变量名 ---@param pos integer # 可见位置 ---@return parser.object? function m.getLocal(source, name, pos) - local root = m.getRoot(source) - local res - m.eachSourceContain(root, pos, function (src) - local locals = src.locals - if not locals then - return + local block = source + -- find nearest source + for _ = 1, 10000 do + if not block then + return nil end - for i = 1, #locals do - local loc = locals[i] - if loc.effect > pos then - break - end - if loc[1] == name then - if not res or res.effect < loc.effect then - res = loc + if block.start <= pos + and block.finish >= pos + and blockTypes[block.type] then + break + end + block = block.parent + end + + m.eachSourceContain(block, pos, function (src) + if blockTypes[src.type] + and (src.finish - src.start) < (block.finish - src.start) then + block = src + end + end) + + for _ = 1, 10000 do + if not block then + break + end + local res + if block.locals then + for _, loc in ipairs(block.locals) do + if loc[1] == name + and loc.effect <= pos then + if not res or res.effect < loc.effect then + res = loc + end end end end - end) - return res + if res then + return res + end + block = block.parent + end + return nil end --- 获取指定区块中所有的可见局部变量名称 @@ -507,25 +527,25 @@ function m.getVisibleLocals(block, pos) end --- 获取指定区块中可见的标签 ----@param block table ----@param name string {comment = '标签名'} +---@param block parser.object +---@param name string function m.getLabel(block, name) - block = m.getBlock(block) + local current = m.getBlock(block) for _ = 1, 10000 do - if not block then + if not current then return nil end - local labels = block.labels + local labels = current.labels if labels then local label = labels[name] if label then return label end end - if block.type == 'function' then + if current.type == 'function' then return nil end - block = m.getParentBlock(block) + current = m.getParentBlock(current) end error('guide.getLocal overstack') end @@ -749,6 +769,16 @@ function m.eachSource(ast, callback) end end +---@param source parser.object +---@param callback fun(src: parser.object) +function m.eachChild(source, callback) + local f = eachChildMap[source.type] + if not f then + return + end + f(source, callback) +end + --- 获取指定的 special function m.eachSpecialOf(ast, name, callback) local root = m.getRoot(ast) @@ -797,6 +827,8 @@ function m.positionToOffset(state, position) return m.positionToOffsetByLines(state.lines, position) end +---@param lines integer[] +---@param offset integer function m.offsetToPositionByLines(lines, offset) local left = 0 local right = #lines @@ -854,11 +886,14 @@ local isSetMap = { ['label'] = true, ['doc.class'] = true, ['doc.alias'] = true, + ['doc.enum'] = true, ['doc.field'] = true, ['doc.class.name'] = true, ['doc.alias.name'] = true, + ['doc.enum.name'] = true, ['doc.field.name'] = true, ['doc.type.field'] = true, + ['doc.type.array'] = true, } function m.isSet(source) local tp = source.type @@ -972,6 +1007,8 @@ function m.getKeyName(obj) return obj.class[1] elseif tp == 'doc.alias' then return obj.alias[1] + elseif tp == 'doc.enum' then + return obj.enum[1] elseif tp == 'doc.field' then return obj.field[1] elseif tp == 'doc.field.name' then @@ -1035,6 +1072,8 @@ function m.getKeyType(obj) return 'string' elseif tp == 'doc.alias' then return 'string' + elseif tp == 'doc.enum' then + return 'string' elseif tp == 'doc.field' then return type(obj.field[1]) elseif tp == 'doc.type.field' then @@ -1058,9 +1097,9 @@ end --- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。 ---@param a table ---@param b table ----@return string|boolean mode ----@return table pathA? ----@return table pathB? +---@return string|false mode +---@return table? pathA +---@return table? pathB function m.getPath(a, b, sameFunction) --- 首先测试双方在同一个函数内 if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then @@ -1084,16 +1123,20 @@ function m.getPath(a, b, sameFunction) local pathB = {} for _ = 1, 1000 do objA = m.getParentBlock(objA) - pathA[#pathA+1] = objA - if (not sameFunction and objA.type == 'function') or objA.type == 'main' then - break + if objA then + pathA[#pathA+1] = objA + if (not sameFunction and objA.type == 'function') or objA.type == 'main' then + break + end end end for _ = 1, 1000 do objB = m.getParentBlock(objB) - pathB[#pathB+1] = objB - if (not sameFunction and objA.type == 'function') or objB.type == 'main' then - break + if objB then + pathB[#pathB+1] = objB + if (not sameFunction and objB.type == 'function') or objB.type == 'main' then + break + end end end -- pathA: {1, 2, 3, 4, 5} @@ -1108,7 +1151,7 @@ function m.getPath(a, b, sameFunction) end end if not start then - return nil + return false end -- pathA: { 1, 2, 3} -- pathB: {5, 6, 2, 3} @@ -1201,6 +1244,15 @@ function m.isInString(ast, position) end) end +function m.isInComment(ast, offset) + for _, com in ipairs(ast.state.comms) do + if offset >= com.start and offset <= com.finish then + return true + end + end + return false +end + function m.isOOP(source) if source.type == 'setmethod' or source.type == 'getmethod' then @@ -1221,6 +1273,7 @@ local basicTypeMap = { ['false'] = true, ['nil'] = true, ['boolean'] = true, + ['integer'] = true, ['number'] = true, ['string'] = true, ['table'] = true, @@ -1235,4 +1288,10 @@ function m.isBasicType(str) return basicTypeMap[str] == true end +---@param source parser.object +---@return boolean +function m.isBlockType(source) + return blockTypes[source.type] == true +end + return m diff --git a/script/parser/init.lua b/script/parser/init.lua index 219f8900..bc004f77 100644 --- a/script/parser/init.lua +++ b/script/parser/init.lua @@ -1,13 +1,8 @@ local api = { - grammar = require 'parser.grammar', - parse = require 'parser.parse', compile = require 'parser.compile', - split = require 'parser.split', - calcline = require 'parser.calcline', lines = require 'parser.lines', guide = require 'parser.guide', luadoc = require 'parser.luadoc', - tokens = require 'parser.tokens', } return api diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index d8e31950..51161565 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -1,22 +1,28 @@ local m = require 'lpeglabel' local re = require 'parser.relabel' local guide = require 'parser.guide' -local parser = require 'parser.newparser' +local compile = require 'parser.compile' local util = require 'utility' local TokenTypes, TokenStarts, TokenFinishs, TokenContents, TokenMarks -local Ci, Offset, pushWarning, NextComment, Lines +---@type integer +local Ci +---@type integer +local Offset +local pushWarning, NextComment, Lines local parseType, parseTypeUnit ---@type any local Parser = re.compile([[ Main <- (Token / Sp)* Sp <- %s+ X16 <- [a-fA-F0-9] -Token <- Integer / Name / String / Symbol +Token <- Integer / Name / String / Code / Symbol Name <- ({} {%name} {}) -> Name -Integer <- ({} {[0-9]+} !'.' {}) +Integer <- ({} {'-'? [0-9]+} !'.' {}) -> Integer +Code <- ({} '`' { (!'`' .)*} '`' {}) + -> Code String <- ({} StringDef {}) -> String StringDef <- {'"'} @@ -48,7 +54,7 @@ EChar <- 'a' -> ea / ([0-9] [0-9]? [0-9]?) -> Char10 / ('u{' {X16*} '}') -> CharUtf8 Symbol <- ({} { - [:|,<>()?+#`{}] + [:|,;<>()?+#{}] / '[]' / '...' / '[' @@ -67,6 +73,7 @@ Symbol <- ({} { ev = '\v', name = (m.R('az', 'AZ', '09', '\x80\xff') + m.S('_')) * (m.R('az', 'AZ', '__', '09', '\x80\xff') + m.S('_.*-'))^0, Char10 = function (char) + ---@type integer? char = tonumber(char) if not char or char < 0 or char > 255 then return '' @@ -114,6 +121,13 @@ Symbol <- ({} { TokenFinishs[Ci] = finish - 1 TokenContents[Ci] = math.tointeger(content) end, + Code = function (start, content, finish) + Ci = Ci + 1 + TokenTypes[Ci] = 'code' + TokenStarts[Ci] = start + TokenFinishs[Ci] = finish - 1 + TokenContents[Ci] = content + end, Symbol = function (start, content, finish) Ci = Ci + 1 TokenTypes[Ci] = 'symbol' @@ -128,10 +142,12 @@ Symbol <- ({} { ---@field signs parser.object[] ---@field originalComment parser.object ---@field as? parser.object - -local function trim(str) - return str:match '^%s*(%S+)%s*$' -end +---@field touch? integer +---@field module? string +---@field async? boolean +---@field versions? table[] +---@field names? parser.object[] +---@field path? string local function parseTokens(text, offset) Ci = 0 @@ -149,11 +165,13 @@ local function peekToken() return TokenTypes[Ci+1], TokenContents[Ci+1] end +---@return string? tokenType +---@return string? tokenContent local function nextToken() Ci = Ci + 1 if not TokenTypes[Ci] then Ci = Ci - 1 - return nil + return nil, nil end return TokenTypes[Ci], TokenContents[Ci] end @@ -171,6 +189,7 @@ local function getStart() return TokenStarts[Ci] + Offset end +---@return integer local function getFinish() if Ci == 0 then return Offset @@ -273,6 +292,11 @@ local function parseTable(parent) } do + local needCloseParen + if checkToken('symbol', '(', 1) then + nextToken() + needCloseParen = true + end field.name = parseName('doc.field.name', field) or parseIndexField('doc.field.name', field) if not field.name then @@ -299,10 +323,14 @@ local function parseTable(parent) break end field.finish = getFinish() + if needCloseParen then + nextSymbolOrError ')' + end end typeUnit.fields[#typeUnit.fields+1] = field - if checkToken('symbol', ',', 1) then + if checkToken('symbol', ',', 1) + or checkToken('symbol', ';', 1) then nextToken() else nextSymbolOrError('}') @@ -412,11 +440,35 @@ local function parseTypeUnitFunction(parent) end if checkToken('symbol', ':', 1) then nextToken() + local needCloseParen + if checkToken('symbol', '(', 1) then + nextToken() + needCloseParen = true + end while true do + local name + try(function () + local returnName = parseName('doc.return.name', typeUnit) + or parseDots('doc.return.name', typeUnit) + if not returnName then + return false + end + if checkToken('symbol', ':', 1) then + nextToken() + name = returnName + return true + end + if returnName[1] == '...' then + name = returnName + return false + end + return false + end) local rtn = parseType(typeUnit) if not rtn then break end + rtn.name = name if checkToken('symbol', '?', 1) then nextToken() rtn.optional = true @@ -428,6 +480,9 @@ local function parseTypeUnitFunction(parent) break end end + if needCloseParen then + nextSymbolOrError ')' + end end typeUnit.finish = getFinish() return typeUnit @@ -534,6 +589,22 @@ local function parseString(parent) return str end +local function parseCode(parent) + local tp, content = peekToken() + if not tp or tp ~= 'code' then + return nil + end + nextToken() + local code = { + type = 'doc.type.code', + start = getStart(), + finish = getFinish(), + parent = parent, + [1] = content, + } + return code +end + local function parseInteger(parent) local tp, content = peekToken() if not tp or tp ~= 'integer' then @@ -584,22 +655,18 @@ function parseTypeUnit(parent) local result = parseFunction(parent) or parseTable(parent) or parseString(parent) + or parseCode(parent) or parseInteger(parent) or parseBoolean(parent) - or parseDots('doc.type.name', parent) or parseParen(parent) if not result then - local literal = checkToken('symbol', '`', 1) - if literal then - nextToken() - end result = parseName('doc.type.name', parent) + or parseDots('doc.type.name', parent) if not result then return nil end - if literal then - result.literal = true - nextSymbolOrError '`' + if result[1] == '...' then + result[1] = 'unknown' end end while true do @@ -749,8 +816,9 @@ local docSwitch = util.switch() : case 'class' : call(function () local result = { - type = 'doc.class', - fields = {}, + type = 'doc.class', + fields = {}, + operators = {}, } result.class = parseName('doc.class.name', result) if not result.class then @@ -793,7 +861,20 @@ local docSwitch = util.switch() end) : case 'type' : call(function () - return parseType() + local first = parseType() + if not first then + return nil + end + local rests + while checkToken('symbol', ',', 1) do + nextToken() + local rest = parseType() + if not rests then + rests = {} + end + rests[#rests+1] = rest + end + return first, rests end) : case 'alias' : call(function () @@ -864,6 +945,10 @@ local docSwitch = util.switch() returns = {}, } while true do + local dots = parseDots('doc.return.name') + if dots then + Ci = Ci - 1 + end local docType = parseType(result) if not docType then break @@ -875,7 +960,13 @@ local docSwitch = util.switch() nextToken() docType.optional = true end - docType.name = parseName('doc.return.name', docType) + if dots then + docType.name = dots + dots.parent = docType + else + docType.name = parseName('doc.return.name', docType) + or parseDots('doc.return.name', docType) + end result.returns[#result.returns+1] = docType if not checkToken('symbol', ',', 1) then break @@ -1250,8 +1341,7 @@ local docSwitch = util.switch() if checkToken('symbol', '?', 1) then block.optional = true nextToken() - block.start = block.start or getStart() - block.finish = block.finish + block.finish = getFinish() else block.extends = parseType(block) if block.extends then @@ -1263,6 +1353,7 @@ local docSwitch = util.switch() if block.optional or block.extends then result.casts[#result.casts+1] = block end + result.finish = block.finish if checkToken('symbol', ',', 1) then nextToken() @@ -1273,8 +1364,87 @@ local docSwitch = util.switch() return result end) + : case 'operator' + : call(function () + local result = { + type = 'doc.operator', + start = getFinish(), + finish = getFinish(), + } + + local op = parseName('doc.operator.name', result) + if not op then + pushWarning { + type = 'LUADOC_MISS_OPERATOR_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.op = op + result.finish = op.finish + + if checkToken('symbol', '(', 1) then + nextToken() + local exp = parseType(result) + if exp then + result.exp = exp + result.finish = exp.finish + end + nextSymbolOrError ')' + end + + nextSymbolOrError ':' + + local ret = parseType(result) + if ret then + result.extends = ret + result.finish = ret.finish + end -local function convertTokens() + return result + end) + : case 'source' + : call(function (doc) + local fullSource = doc:sub(#'source' + 1) + if not fullSource or fullSource == '' then + return + end + fullSource = util.trim(fullSource) + if fullSource == '' then + return + end + local source, line, char = fullSource:match('^(.-):?(%d*):?(%d*)$') + source = source or fullSource + line = tonumber(line) or 1 + char = tonumber(char) or 0 + local result = { + type = 'doc.source', + start = getStart(), + finish = getFinish(), + path = source, + line = line, + char = char, + } + return result + end) + : case 'enum' + : call(function () + local name = parseName('doc.enum.name') + if not name then + return nil + end + local result = { + type = 'doc.enum', + start = name.start, + finish = name.finish, + enum = name, + } + name.parent = result + return result + end) + +local function convertTokens(doc) local tp, text = nextToken() if not tp then return @@ -1287,7 +1457,7 @@ local function convertTokens() } return nil end - return docSwitch(text) + return docSwitch(text, doc) end local function trimTailComment(text) @@ -1302,7 +1472,7 @@ local function trimTailComment(text) comment = text:sub(3) end if comment:find '^%s*[\'"[]' then - local state = parser(comment:gsub('^%s+', ''), 'String') + local state = compile(comment:gsub('^%s+', ''), 'String') if state and state.ast then comment = state.ast[1] end @@ -1327,10 +1497,17 @@ local function buildLuaDoc(comment) local doc = text:sub(startPos) parseTokens(doc, comment.start + startPos) - local result = convertTokens() + local result, rests = convertTokens(doc) if result then result.range = comment.finish - local cstart = text:find('%S', (result.firstFinish or result.finish) - comment.start) + local finish = result.firstFinish or result.finish + if rests then + for _, rest in ipairs(rests) do + rest.range = comment.finish + finish = rest.firstFinish or result.finish + end + end + local cstart = text:find('%S', finish - comment.start) if cstart and cstart < comment.finish then result.comment = { type = 'doc.tailcomment', @@ -1339,11 +1516,16 @@ local function buildLuaDoc(comment) parent = result, text = trimTailComment(text:sub(cstart)), } + if rests then + for _, rest in ipairs(rests) do + rest.comment = result.comment + end + end end end if result then - return result + return result, rests end return { @@ -1355,37 +1537,54 @@ local function buildLuaDoc(comment) } end -local function isTailComment(text, binded) - local lastDoc = binded[#binded] - local left = lastDoc.originalComment.start +local function isTailComment(text, doc) + if not doc then + return false + end + local left = doc.originalComment.start local row, col = guide.rowColOf(left) local lineStart = Lines[row] or 0 local hasCodeBefore = text:sub(lineStart, lineStart + col):find '[%w_]' return hasCodeBefore end -local function isNextLine(binded, doc) - if not binded then +local function isContinuedDoc(lastDoc, nextDoc) + if not nextDoc then return false end - local lastDoc = binded[#binded] + if nextDoc.type == 'doc.diagnostic' then + return true + end if lastDoc.type == 'doc.type' - or lastDoc.type == 'doc.module' then - return false + or lastDoc.type == 'doc.module' + or lastDoc.type == 'doc.enum' then + if nextDoc.type ~= 'doc.comment' then + return false + end end if lastDoc.type == 'doc.class' - or lastDoc.type == 'doc.field' then - if doc.type ~= 'doc.field' - and doc.type ~= 'doc.comment' - and doc.type ~= 'doc.overload' then + or lastDoc.type == 'doc.field' + or lastDoc.type == 'doc.operator' then + if nextDoc.type ~= 'doc.field' + and nextDoc.type ~= 'doc.operator' + and nextDoc.type ~= 'doc.comment' + and nextDoc.type ~= 'doc.overload' + and nextDoc.type ~= 'doc.source' then return false end end - if doc.type == 'doc.cast' then + if nextDoc.type == 'doc.cast' then + return false + end + return true +end + +local function isNextLine(lastDoc, nextDoc) + if not nextDoc then return false end local lastRow = guide.rowColOf(lastDoc.finish) - local newRow = guide.rowColOf(doc.start) + local newRow = guide.rowColOf(nextDoc.start) return newRow - lastRow == 1 end @@ -1408,6 +1607,7 @@ local function bindGeneric(binded) end end if doc.type == 'doc.param' + or doc.type == 'doc.vararg' or doc.type == 'doc.return' or doc.type == 'doc.type' or doc.type == 'doc.class' @@ -1418,11 +1618,115 @@ local function bindGeneric(binded) src.type = 'doc.generic.name' end end) + guide.eachSourceType(doc, 'doc.type.code', function (src) + local name = src[1] + if generics[name] then + src.type = 'doc.generic.name' + src.literal = true + end + end) + end + end +end + +local function bindDoc(source, binded) + local isParam = source.type == 'self' + or source.type == 'local' + and (source.parent.type == 'funcargs' + or ( source.parent.type == 'in' + and source.finish <= source.parent.keys.finish + ) + ) + local ok = false + for _, doc in ipairs(binded) do + if doc.bindSource then + goto CONTINUE + end + if doc.type == 'doc.class' + or doc.type == 'doc.deprecated' + or doc.type == 'doc.version' + or doc.type == 'doc.module' + or doc.type == 'doc.source' then + if source.type == 'function' + or isParam then + goto CONTINUE + end + elseif doc.type == 'doc.type' then + if source.type == 'function' + or isParam + or source._bindedDocType then + goto CONTINUE + end + source._bindedDocType = true + elseif doc.type == 'doc.overload' then + if not source.bindDocs then + source.bindDocs = {} + end + source.bindDocs[#source.bindDocs+1] = doc + if source.type ~= 'function' then + doc.bindSource = source + end + elseif doc.type == 'doc.param' then + local suc + if isParam + and doc.param[1] == source[1] then + suc = true + elseif source.type == '...' + and doc.param[1] == '...' then + suc = true + elseif source.type == 'self' + and doc.param[1] == 'self' then + suc = true + end + if source.type == 'function' then + if not source.bindDocs then + source.bindDocs = {} + end + source.bindDocs[#source.bindDocs+1] = doc + end + + if not suc then + goto CONTINUE + end + elseif doc.type == 'doc.vararg' then + if source.type ~= '...' then + goto CONTINUE + end + elseif doc.type == 'doc.return' + or doc.type == 'doc.generic' + or doc.type == 'doc.async' + or doc.type == 'doc.nodiscard' then + if source.type ~= 'function' then + goto CONTINUE + end + elseif doc.type == 'doc.enum' then + if source.type == 'table' then + goto OK + end + if source.value and source.value.type == 'table' then + if not source.value.bindDocs then + source.value.bindDocs = {} + end + source.value.bindDocs[#source.value.bindDocs+1] = doc + doc.bindSource = source.value + end + goto CONTINUE + elseif doc.type ~= 'doc.comment' then + goto CONTINUE end + ::OK:: + if not source.bindDocs then + source.bindDocs = {} + end + source.bindDocs[#source.bindDocs+1] = doc + doc.bindSource = source + ok = true + ::CONTINUE:: end + return ok end -local function bindDocsBetween(sources, binded, bindSources, start, finish) +local function bindDocsBetween(sources, binded, start, finish) -- 用二分法找到第一个 local max = #sources local index @@ -1445,6 +1749,7 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish) end end + local ok = false -- 从前往后进行绑定 for i = index, max do local src = sources[i] @@ -1452,12 +1757,6 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish) if src.start >= finish then break end - -- 遇到table后中断,处理以下情况: - -- ---@type AAA - -- local t = {x = 1, y = 2} - if src.type == 'table' then - break - end if src.start >= start then if src.type == 'local' or src.type == 'self' @@ -1468,13 +1767,18 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish) or src.type == 'setfield' or src.type == 'setindex' or src.type == 'setmethod' - or src.type == 'function' then - src.bindDocs = binded - bindSources[#bindSources+1] = src + or src.type == 'function' + or src.type == 'table' + or src.type == '...' then + if bindDoc(src, binded) then + ok = true + end end end end end + + return ok end local function bindReturnIndex(binded) @@ -1489,25 +1793,64 @@ local function bindReturnIndex(binded) end end -local function bindClassAndFields(binded) +local function bindCommentsToDoc(doc, comments) + doc.bindComments = comments + for _, comment in ipairs(comments) do + comment.bindSource = doc + end +end + +local function bindCommentsAndFields(binded) local class + local comments = {} + local source for _, doc in ipairs(binded) do if doc.type == 'doc.class' then -- 多个class连续写在一起,只有最后一个class可以绑定source if class then - class.bindSources = nil + class.bindSource = nil + end + if source then + doc.source = source + source.bindSource = doc end class = doc + bindCommentsToDoc(doc, comments) + comments = {} elseif doc.type == 'doc.field' then if class then class.fields[#class.fields+1] = doc doc.class = class end - end + if source then + doc.source = source + source.bindSource = doc + end + bindCommentsToDoc(doc, comments) + comments = {} + elseif doc.type == 'doc.operator' then + if class then + class.operators[#class.operators+1] = doc + doc.class = class + end + bindCommentsToDoc(doc, comments) + comments = {} + elseif doc.type == 'doc.alias' + or doc.type == 'doc.enum' then + bindCommentsToDoc(doc, comments) + comments = {} + elseif doc.type == 'doc.comment' then + comments[#comments+1] = doc + elseif doc.type == 'doc.source' then + source = doc + goto CONTINUE + end + source = nil + ::CONTINUE:: end end -local function bindDoc(sources, binded) +local function bindDocWithSources(sources, binded) if not binded then return end @@ -1515,19 +1858,17 @@ local function bindDoc(sources, binded) if not lastDoc then return end - local bindSources = {} for _, doc in ipairs(binded) do doc.bindGroup = binded - doc.bindSources = bindSources end bindGeneric(binded) + bindCommentsAndFields(binded) + bindReturnIndex(binded) local row = guide.rowColOf(lastDoc.finish) - bindDocsBetween(sources, binded, bindSources, guide.positionOf(row, 0), lastDoc.start) - if #bindSources == 0 then - bindDocsBetween(sources, binded, bindSources, guide.positionOf(row + 1, 0), guide.positionOf(row + 2, 0)) + local suc = bindDocsBetween(sources, binded, guide.positionOf(row, 0), lastDoc.start) + if not suc then + bindDocsBetween(sources, binded, guide.positionOf(row + 1, 0), guide.positionOf(row + 2, 0)) end - bindReturnIndex(binded) - bindClassAndFields(binded) end local bindDocAccept = { @@ -1547,19 +1888,44 @@ local function bindDocs(state) return a.start < b.start end) local binded - for _, doc in ipairs(state.ast.docs) do - if not isNextLine(binded, doc) then - bindDoc(sources, binded) + for i, doc in ipairs(state.ast.docs) do + if not binded then binded = {} state.ast.docs.groups[#state.ast.docs.groups+1] = binded end binded[#binded+1] = doc - if isTailComment(text, binded) then - bindDoc(sources, binded) + if isTailComment(text, doc) then + bindDocWithSources(sources, binded) binded = nil + else + local nextDoc = state.ast.docs[i+1] + if not isNextLine(doc, nextDoc) then + bindDocWithSources(sources, binded) + binded = nil + end + if not isContinuedDoc(doc, nextDoc) + and not isTailComment(text, nextDoc) then + bindDocWithSources(sources, binded) + binded = nil + end + end + end +end + +local function findTouch(state, doc) + local text = state.lua + local pos = guide.positionToOffset(state, doc.originalComment.start) + for i = pos - 2, 1, -1 do + local c = text:sub(i, i) + if c == '\r' + or c == '\n' then + break + elseif c ~= ' ' + and c ~= '\t' then + doc.touch = guide.offsetToPosition(state, i) + break end end - bindDoc(sources, binded) end return function (state) @@ -1589,25 +1955,40 @@ return function (state) return comment end + local function insertDoc(doc, comment) + ast.docs[#ast.docs+1] = doc + doc.parent = ast.docs + if ast.start > doc.start then + ast.start = doc.start + end + if ast.finish < doc.finish then + ast.finish = doc.finish + end + doc.originalComment = comment + if comment.type == 'comment.long' then + findTouch(state, doc) + end + end + while true do local comment = NextComment() if not comment then break end - local doc = buildLuaDoc(comment) + local doc, rests = buildLuaDoc(comment) if doc then - ast.docs[#ast.docs+1] = doc - doc.parent = ast.docs - if ast.start > doc.start then - ast.start = doc.start - end - if ast.finish < doc.finish then - ast.finish = doc.finish + insertDoc(doc, comment) + if rests then + for _, rest in ipairs(rests) do + insertDoc(rest, comment) + end end - doc.originalComment = comment end end + ast.docs.start = ast.start + ast.docs.finish = ast.finish + if #ast.docs == 0 then return end diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua deleted file mode 100644 index 630c12c2..00000000 --- a/script/parser/newparser.lua +++ /dev/null @@ -1,3855 +0,0 @@ -local tokens = require 'parser.tokens' -local guide = require 'parser.guide' - -local sbyte = string.byte -local sfind = string.find -local smatch = string.match -local sgsub = string.gsub -local ssub = string.sub -local schar = string.char -local supper = string.upper -local uchar = utf8.char -local tconcat = table.concat -local tinsert = table.insert -local tointeger = math.tointeger -local mtype = math.type -local tonumber = tonumber -local maxinteger = math.maxinteger -local assert = assert -local next = next - -_ENV = nil - ----@alias parser.position integer - ----@param str string ----@return table<integer, boolean> -local function stringToCharMap(str) - local map = {} - local pos = 1 - while pos <= #str do - local byte = sbyte(str, pos, pos) - map[schar(byte)] = true - pos = pos + 1 - if ssub(str, pos, pos) == '-' - and pos < #str then - pos = pos + 1 - local byte2 = sbyte(str, pos, pos) - assert(byte < byte2) - for b = byte + 1, byte2 do - map[schar(b)] = true - end - pos = pos + 1 - end - end - return map -end - -local CharMapNumber = stringToCharMap '0-9' -local CharMapN16 = stringToCharMap 'xX' -local CharMapN2 = stringToCharMap 'bB' -local CharMapE10 = stringToCharMap 'eE' -local CharMapE16 = stringToCharMap 'pP' -local CharMapSign = stringToCharMap '+-' -local CharMapSB = stringToCharMap 'ao|~&=<>.*/%^+-' -local CharMapSU = stringToCharMap 'n#~!-' -local CharMapSimple = stringToCharMap '.:([\'"{' -local CharMapStrSH = stringToCharMap '\'"`' -local CharMapStrLH = stringToCharMap '[' -local CharMapTSep = stringToCharMap ',;' -local CharMapWord = stringToCharMap '_a-zA-Z\x80-\xff' - -local EscMap = { - ['a'] = '\a', - ['b'] = '\b', - ['f'] = '\f', - ['n'] = '\n', - ['r'] = '\r', - ['t'] = '\t', - ['v'] = '\v', - ['\\'] = '\\', - ['\''] = '\'', - ['\"'] = '\"', -} - -local NLMap = { - ['\n'] = true, - ['\r'] = true, - ['\r\n'] = true, -} - -local LineMulti = 10000 - --- goto 单独处理 -local KeyWord = { - ['and'] = true, - ['break'] = true, - ['do'] = true, - ['else'] = true, - ['elseif'] = true, - ['end'] = true, - ['false'] = true, - ['for'] = true, - ['function'] = true, - ['if'] = true, - ['in'] = true, - ['local'] = true, - ['nil'] = true, - ['not'] = true, - ['or'] = true, - ['repeat'] = true, - ['return'] = true, - ['then'] = true, - ['true'] = true, - ['until'] = true, - ['while'] = true, -} - -local Specials = { - ['_G'] = true, - ['rawset'] = true, - ['rawget'] = true, - ['setmetatable'] = true, - ['require'] = true, - ['dofile'] = true, - ['loadfile'] = true, - ['pcall'] = true, - ['xpcall'] = true, - ['pairs'] = true, - ['ipairs'] = true, - ['assert'] = true, -} - -local UnarySymbol = { - ['not'] = 11, - ['#'] = 11, - ['~'] = 11, - ['-'] = 11, -} - -local BinarySymbol = { - ['or'] = 1, - ['and'] = 2, - ['<='] = 3, - ['>='] = 3, - ['<'] = 3, - ['>'] = 3, - ['~='] = 3, - ['=='] = 3, - ['|'] = 4, - ['~'] = 5, - ['&'] = 6, - ['<<'] = 7, - ['>>'] = 7, - ['..'] = 8, - ['+'] = 9, - ['-'] = 9, - ['*'] = 10, - ['//'] = 10, - ['/'] = 10, - ['%'] = 10, - ['^'] = 12, -} - -local BinaryAlias = { - ['&&'] = 'and', - ['||'] = 'or', - ['!='] = '~=', -} - -local BinaryActionAlias = { - ['='] = '==', -} - -local UnaryAlias = { - ['!'] = 'not', -} - -local SymbolForward = { - [01] = true, - [02] = true, - [03] = true, - [04] = true, - [05] = true, - [06] = true, - [07] = true, - [08] = false, - [09] = true, - [10] = true, - [11] = true, - [12] = false, -} - -local GetToSetMap = { - ['getglobal'] = 'setglobal', - ['getlocal'] = 'setlocal', - ['getfield'] = 'setfield', - ['getindex'] = 'setindex', - ['getmethod'] = 'setmethod', -} - -local ChunkFinishMap = { - ['end'] = true, - ['else'] = true, - ['elseif'] = true, - ['in'] = true, - ['then'] = true, - ['until'] = true, - [';'] = true, - [']'] = true, - [')'] = true, - ['}'] = true, -} - -local ChunkStartMap = { - ['do'] = true, - ['else'] = true, - ['elseif'] = true, - ['for'] = true, - ['function'] = true, - ['if'] = true, - ['local'] = true, - ['repeat'] = true, - ['return'] = true, - ['then'] = true, - ['until'] = true, - ['while'] = true, -} - -local ListFinishMap = { - ['end'] = true, - ['else'] = true, - ['elseif'] = true, - ['in'] = true, - ['then'] = true, - ['do'] = true, - ['until'] = true, - ['for'] = true, - ['if'] = true, - ['local'] = true, - ['repeat'] = true, - ['return'] = true, - ['while'] = true, -} - -local State, Lua, Line, LineOffset, Chunk, Tokens, Index, LastTokenFinish, Mode, LocalCount - -local LocalLimit = 200 - -local parseExp, parseAction - -local pushError - -local function addSpecial(name, obj) - if not State.specials then - State.specials = {} - end - if not State.specials[name] then - State.specials[name] = {} - end - State.specials[name][#State.specials[name]+1] = obj - obj.special = name -end - ----@param offset integer ----@param leftOrRight '"left"'|'"right"' -local function getPosition(offset, leftOrRight) - if not offset or offset > #Lua then - return LineMulti * Line + #Lua - LineOffset + 1 - end - if leftOrRight == 'left' then - return LineMulti * Line + offset - LineOffset - else - return LineMulti * Line + offset - LineOffset + 1 - end -end - ----@return string word ----@return parser.position startPosition ----@return parser.position finishPosition ----@return integer newOffset -local function peekWord() - local word = Tokens[Index + 1] - if not word then - return nil - end - if not CharMapWord[ssub(word, 1, 1)] then - return nil - end - local startPos = getPosition(Tokens[Index] , 'left') - local finishPos = getPosition(Tokens[Index] + #word - 1, 'right') - return word, startPos, finishPos -end - -local function lastRightPosition() - if Index < 2 then - return 0 - end - local token = Tokens[Index - 1] - if NLMap[token] then - return LastTokenFinish - elseif token then - return getPosition(Tokens[Index - 2] + #token - 1, 'right') - else - return getPosition(#Lua, 'right') - end -end - -local function missSymbol(symbol, start, finish) - pushError { - type = 'MISS_SYMBOL', - start = start or lastRightPosition(), - finish = finish or start or lastRightPosition(), - info = { - symbol = symbol, - } - } -end - -local function missExp() - pushError { - type = 'MISS_EXP', - start = lastRightPosition(), - finish = lastRightPosition(), - } -end - -local function missName(pos) - pushError { - type = 'MISS_NAME', - start = pos or lastRightPosition(), - finish = pos or lastRightPosition(), - } -end - -local function missEnd(relatedStart, relatedFinish) - pushError { - type = 'MISS_SYMBOL', - start = lastRightPosition(), - finish = lastRightPosition(), - info = { - symbol = 'end', - related = { - { - start = relatedStart, - finish = relatedFinish, - } - } - } - } - pushError { - type = 'MISS_END', - start = relatedStart, - finish = relatedFinish, - } -end - -local function unknownSymbol(start, finish, word) - local token = word or Tokens[Index + 1] - if not token then - return false - end - pushError { - type = 'UNKNOWN_SYMBOL', - start = start or getPosition(Tokens[Index], 'left'), - finish = finish or getPosition(Tokens[Index] + #token - 1, 'right'), - info = { - symbol = token, - } - } - return true -end - -local function skipUnknownSymbol(stopSymbol) - if unknownSymbol() then - Index = Index + 2 - return true - end - return false -end - -local function skipNL() - local token = Tokens[Index + 1] - if NLMap[token] then - if Index >= 2 and not NLMap[Tokens[Index - 1]] then - LastTokenFinish = getPosition(Tokens[Index - 2] + #Tokens[Index - 1] - 1, 'right') - end - Line = Line + 1 - LineOffset = Tokens[Index] + #token - Index = Index + 2 - State.lines[Line] = LineOffset - return true - end - return false -end - -local function getSavePoint() - local index = Index - local line = Line - local lineOffset = LineOffset - local errs = State.errs - local errCount = #errs - return function () - Index = index - Line = line - LineOffset = lineOffset - for i = errCount + 1, #errs do - errs[i] = nil - end - end -end - -local function fastForwardToken(offset) - while true do - local myOffset = Tokens[Index] - if not myOffset - or myOffset >= offset then - break - end - local token = Tokens[Index + 1] - if NLMap[token] then - Line = Line + 1 - LineOffset = Tokens[Index] + #token - State.lines[Line] = LineOffset - end - Index = Index + 2 - end -end - -local function resolveLongString(finishMark) - skipNL() - local miss - local start = Tokens[Index] - local finishOffset = sfind(Lua, finishMark, start, true) - if not finishOffset then - finishOffset = #Lua + 1 - miss = true - end - local stringResult = start and ssub(Lua, start, finishOffset - 1) or '' - local lastLN = stringResult:find '[\r\n][^\r\n]*$' - if lastLN then - local result = stringResult - : gsub('\r\n?', '\n') - stringResult = result - end - fastForwardToken(finishOffset + #finishMark) - if miss then - local pos = getPosition(finishOffset - 1, 'right') - pushError { - type = 'MISS_SYMBOL', - start = pos, - finish = pos, - info = { - symbol = finishMark, - }, - fix = { - title = 'ADD_LSTRING_END', - { - start = pos, - finish = pos, - text = finishMark, - } - }, - } - end - return stringResult, getPosition(finishOffset + #finishMark - 1, 'right') -end - -local function parseLongString() - local start, finish, mark = sfind(Lua, '^(%[%=*%[)', Tokens[Index]) - if not mark then - return nil - end - fastForwardToken(finish + 1) - local startPos = getPosition(start, 'left') - local finishMark = sgsub(mark, '%[', ']') - local stringResult, finishPos = resolveLongString(finishMark) - return { - type = 'string', - start = startPos, - finish = finishPos, - [1] = stringResult, - [2] = mark, - } -end - -local function pushCommentHeadError(left) - if State.options.nonstandardSymbol['//'] then - return - end - pushError { - type = 'ERR_COMMENT_PREFIX', - start = left, - finish = left + 2, - fix = { - title = 'FIX_COMMENT_PREFIX', - { - start = left, - finish = left + 2, - text = '--', - }, - } - } -end - -local function pushLongCommentError(left, right) - if State.options.nonstandardSymbol['/**/'] then - return - end - pushError { - type = 'ERR_C_LONG_COMMENT', - start = left, - finish = right, - fix = { - title = 'FIX_C_LONG_COMMENT', - { - start = left, - finish = left + 2, - text = '--[[', - }, - { - start = right - 2, - finish = right, - text = '--]]' - }, - } - } -end - -local function skipComment(isAction) - local token = Tokens[Index + 1] - if token == '--' - or ( - token == '//' - and ( - isAction - or State.options.nonstandardSymbol['//'] - ) - ) then - local start = Tokens[Index] - local left = getPosition(start, 'left') - local chead = false - if token == '//' then - chead = true - pushCommentHeadError(left) - end - Index = Index + 2 - local longComment = start + 2 == Tokens[Index] and parseLongString() - if longComment then - longComment.type = 'comment.long' - longComment.text = longComment[1] - longComment.mark = longComment[2] - longComment[1] = nil - longComment[2] = nil - State.comms[#State.comms+1] = longComment - return true - end - while true do - local nl = Tokens[Index + 1] - if not nl or NLMap[nl] then - break - end - Index = Index + 2 - end - State.comms[#State.comms+1] = { - type = chead and 'comment.cshort' or 'comment.short', - start = left, - finish = lastRightPosition(), - text = ssub(Lua, start + 2, Tokens[Index] and (Tokens[Index] - 1) or #Lua), - } - return true - end - if token == '/*' then - local start = Tokens[Index] - local left = getPosition(start, 'left') - Index = Index + 2 - local result, right = resolveLongString '*/' - pushLongCommentError(left, right) - State.comms[#State.comms+1] = { - type = 'comment.long', - start = left, - finish = right, - text = result, - } - return true - end - return false -end - -local function skipSpace(isAction) - repeat until not skipNL() - and not skipComment(isAction) -end - -local function expectAssign(isAction) - local token = Tokens[Index + 1] - if token == '=' then - Index = Index + 2 - return true - end - if token == '==' then - local left = getPosition(Tokens[Index], 'left') - local right = getPosition(Tokens[Index] + #token - 1, 'right') - pushError { - type = 'ERR_ASSIGN_AS_EQ', - start = left, - finish = right, - fix = { - title = 'FIX_ASSIGN_AS_EQ', - { - start = left, - finish = right, - text = '=', - } - } - } - Index = Index + 2 - return true - end - if isAction then - if token == '+=' - or token == '-=' - or token == '*=' - or token == '/=' then - if not State.options.nonstandardSymbol[token] then - unknownSymbol() - end - Index = Index + 2 - return true - end - end - return false -end - -local function parseLocalAttrs() - local attrs - while true do - skipSpace() - local token = Tokens[Index + 1] - if token ~= '<' then - break - end - if not attrs then - attrs = { - type = 'localattrs', - } - end - local attr = { - type = 'localattr', - parent = attrs, - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index], 'right'), - } - attrs[#attrs+1] = attr - Index = Index + 2 - skipSpace() - local word, wstart, wfinish = peekWord() - if word then - attr[1] = word - attr.finish = wfinish - Index = Index + 2 - if word ~= 'const' - and word ~= 'close' then - pushError { - type = 'UNKNOWN_ATTRIBUTE', - start = wstart, - finish = wfinish, - } - end - else - missName() - end - attr.finish = lastRightPosition() - skipSpace() - if Tokens[Index + 1] == '>' then - attr.finish = getPosition(Tokens[Index], 'right') - Index = Index + 2 - elseif Tokens[Index + 1] == '>=' then - attr.finish = getPosition(Tokens[Index], 'right') - pushError { - type = 'MISS_SPACE_BETWEEN', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + 1, 'right'), - } - Index = Index + 2 - else - missSymbol '>' - end - if State.version ~= 'Lua 5.4' then - pushError { - type = 'UNSUPPORT_SYMBOL', - start = attr.start, - finish = attr.finish, - version = 'Lua 5.4', - info = { - version = State.version - } - } - end - end - return attrs -end - -local function createLocal(obj, attrs) - obj.type = 'local' - obj.effect = obj.finish - - if attrs then - obj.attrs = attrs - attrs.parent = obj - end - - local chunk = Chunk[#Chunk] - if chunk then - local locals = chunk.locals - if not locals then - locals = {} - chunk.locals = locals - end - locals[#locals+1] = obj - LocalCount = LocalCount + 1 - if LocalCount > LocalLimit then - pushError { - type = 'LOCAL_LIMIT', - start = obj.start, - finish = obj.finish, - } - end - end - return obj -end - -local function pushChunk(chunk) - Chunk[#Chunk+1] = chunk -end - -local function resolveLable(label, obj) - if not label.ref then - label.ref = {} - end - label.ref[#label.ref+1] = obj - obj.node = label - - -- 如果有局部变量在 goto 与 label 之间声明, - -- 并在 label 之后使用,则算作语法错误 - - -- 如果 label 在 goto 之前声明,那么不会有中间声明的局部变量 - if obj.start > label.start then - return - end - - local block = guide.getBlock(obj) - local locals = block and block.locals - if not locals then - return - end - - for i = 1, #locals do - local loc = locals[i] - -- 检查局部变量声明位置为 goto 与 label 之间 - if loc.start < obj.start or loc.finish > label.finish then - goto CONTINUE - end - -- 检查局部变量的使用位置在 label 之后 - local refs = loc.ref - if not refs then - goto CONTINUE - end - for j = 1, #refs do - local ref = refs[j] - if ref.finish > label.finish then - pushError { - type = 'JUMP_LOCAL_SCOPE', - start = obj.start, - finish = obj.finish, - info = { - loc = loc[1], - }, - relative = { - { - start = label.start, - finish = label.finish, - }, - { - start = loc.start, - finish = loc.finish, - } - }, - } - return - end - end - ::CONTINUE:: - end -end - -local function resolveGoTo(gotos) - for i = 1, #gotos do - local action = gotos[i] - local label = guide.getLabel(action, action[1]) - if label then - resolveLable(label, action) - else - pushError { - type = 'NO_VISIBLE_LABEL', - start = action.start, - finish = action.finish, - info = { - label = action[1], - } - } - end - end -end - -local function popChunk() - local chunk = Chunk[#Chunk] - if chunk.gotos then - resolveGoTo(chunk.gotos) - chunk.gotos = nil - end - local lastAction = chunk[#chunk] - if lastAction then - chunk.finish = lastAction.finish - end - Chunk[#Chunk] = nil -end - -local function parseNil() - if Tokens[Index + 1] ~= 'nil' then - return nil - end - local offset = Tokens[Index] - Index = Index + 2 - return { - type = 'nil', - start = getPosition(offset, 'left'), - finish = getPosition(offset + 2, 'right'), - } -end - -local function parseBoolean() - local word = Tokens[Index+1] - if word ~= 'true' - and word ~= 'false' then - return nil - end - local start = getPosition(Tokens[Index], 'left') - local finish = getPosition(Tokens[Index] + #word - 1, 'right') - Index = Index + 2 - return { - type = 'boolean', - start = start, - finish = finish, - [1] = word == 'true' and true or false, - } -end - -local function parseStringUnicode() - local offset = Tokens[Index] + 1 - if ssub(Lua, offset, offset) ~= '{' then - local pos = getPosition(offset, 'left') - missSymbol('{', pos) - return nil, offset - end - local leftPos = getPosition(offset, 'left') - local x16 = smatch(Lua, '^%w*', offset + 1) - local rightPos = getPosition(offset + #x16, 'right') - offset = offset + #x16 + 1 - if ssub(Lua, offset, offset) == '}' then - offset = offset + 1 - rightPos = rightPos + 1 - else - missSymbol('}', rightPos) - end - offset = offset + 1 - if #x16 == 0 then - pushError { - type = 'UTF8_SMALL', - start = leftPos, - finish = rightPos, - } - return '', offset - end - if State.version ~= 'Lua 5.3' - and State.version ~= 'Lua 5.4' - and State.version ~= 'LuaJIT' - then - pushError { - type = 'ERR_ESC', - start = leftPos - 2, - finish = rightPos, - version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, - info = { - version = State.version, - } - } - return nil, offset - end - local byte = tonumber(x16, 16) - if not byte then - for i = 1, #x16 do - if not tonumber(ssub(x16, i, i), 16) then - pushError { - type = 'MUST_X16', - start = leftPos + i, - finish = leftPos + i + 1, - } - end - end - return nil, offset - end - if State.version == 'Lua 5.4' then - if byte < 0 or byte > 0x7FFFFFFF then - pushError { - type = 'UTF8_MAX', - start = leftPos, - finish = rightPos, - info = { - min = '00000000', - max = '7FFFFFFF', - } - } - return nil, offset - end - else - if byte < 0 or byte > 0x10FFFF then - pushError { - type = 'UTF8_MAX', - start = leftPos, - finish = rightPos, - version = byte <= 0x7FFFFFFF and 'Lua 5.4' or nil, - info = { - min = '000000', - max = '10FFFF', - } - } - end - end - if byte >= 0 and byte <= 0x10FFFF then - return uchar(byte), offset - end - return '', offset -end - -local stringPool = {} -local function parseShortString() - local mark = Tokens[Index+1] - local startOffset = Tokens[Index] - local startPos = getPosition(startOffset, 'left') - Index = Index + 2 - local stringIndex = 0 - local currentOffset = startOffset + 1 - local escs = {} - while true do - local token = Tokens[Index + 1] - if token == mark then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1) - Index = Index + 2 - break - end - if NLMap[token] then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1) - missSymbol(mark) - break - end - if not token then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = ssub(Lua, currentOffset or -1) - missSymbol(mark) - break - end - if token == '\\' then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = ssub(Lua, currentOffset, Tokens[Index] - 1) - currentOffset = Tokens[Index] - Index = Index + 2 - if not Tokens[Index] then - goto CONTINUE - end - local escLeft = getPosition(currentOffset, 'left') - -- has space? - if Tokens[Index] - currentOffset > 1 then - local right = getPosition(currentOffset + 1, 'right') - pushError { - type = 'ERR_ESC', - start = escLeft, - finish = right, - } - escs[#escs+1] = escLeft - escs[#escs+1] = right - escs[#escs+1] = 'err' - goto CONTINUE - end - local nextToken = ssub(Tokens[Index + 1], 1, 1) - if EscMap[nextToken] then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = EscMap[nextToken] - currentOffset = Tokens[Index] + #nextToken - Index = Index + 2 - escs[#escs+1] = escLeft - escs[#escs+1] = escLeft + 2 - escs[#escs+1] = 'normal' - goto CONTINUE - end - if nextToken == mark then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = mark - currentOffset = Tokens[Index] + #nextToken - Index = Index + 2 - escs[#escs+1] = escLeft - escs[#escs+1] = escLeft + 2 - escs[#escs+1] = 'normal' - goto CONTINUE - end - if nextToken == 'z' then - Index = Index + 2 - repeat until not skipNL() - currentOffset = Tokens[Index] - escs[#escs+1] = escLeft - escs[#escs+1] = escLeft + 2 - escs[#escs+1] = 'normal' - goto CONTINUE - end - if CharMapNumber[nextToken] then - local numbers = smatch(Tokens[Index + 1], '^%d+') - if #numbers > 3 then - numbers = ssub(numbers, 1, 3) - end - currentOffset = Tokens[Index] + #numbers - fastForwardToken(currentOffset) - local right = getPosition(currentOffset - 1, 'right') - local byte = tointeger(numbers) - if byte <= 255 then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = schar(byte) - else - pushError { - type = 'ERR_ESC', - start = escLeft, - finish = right, - } - end - escs[#escs+1] = escLeft - escs[#escs+1] = right - escs[#escs+1] = 'byte' - goto CONTINUE - end - if nextToken == 'x' then - local left = getPosition(Tokens[Index] - 1, 'left') - local x16 = ssub(Tokens[Index + 1], 2, 3) - local byte = tonumber(x16, 16) - if byte then - currentOffset = Tokens[Index] + 3 - stringIndex = stringIndex + 1 - stringPool[stringIndex] = schar(byte) - else - currentOffset = Tokens[Index] + 1 - pushError { - type = 'MISS_ESC_X', - start = getPosition(currentOffset, 'left'), - finish = getPosition(currentOffset + 1, 'right'), - } - end - local right = getPosition(currentOffset + 1, 'right') - escs[#escs+1] = escLeft - escs[#escs+1] = right - escs[#escs+1] = 'byte' - if State.version == 'Lua 5.1' then - pushError { - type = 'ERR_ESC', - start = left, - finish = left + 4, - version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, - info = { - version = State.version, - } - } - end - Index = Index + 2 - goto CONTINUE - end - if nextToken == 'u' then - local str, newOffset = parseStringUnicode() - if str then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = str - end - currentOffset = newOffset - fastForwardToken(currentOffset - 1) - local right = getPosition(currentOffset + 1, 'right') - escs[#escs+1] = escLeft - escs[#escs+1] = right - escs[#escs+1] = 'unicode' - goto CONTINUE - end - if NLMap[nextToken] then - stringIndex = stringIndex + 1 - stringPool[stringIndex] = '\n' - currentOffset = Tokens[Index] + #nextToken - skipNL() - local right = getPosition(currentOffset + 1, 'right') - escs[#escs+1] = escLeft - escs[#escs+1] = escLeft + 1 - escs[#escs+1] = 'normal' - goto CONTINUE - end - local right = getPosition(currentOffset + 1, 'right') - pushError { - type = 'ERR_ESC', - start = escLeft, - finish = right, - } - escs[#escs+1] = escLeft - escs[#escs+1] = right - escs[#escs+1] = 'err' - end - Index = Index + 2 - ::CONTINUE:: - end - local stringResult = tconcat(stringPool, '', 1, stringIndex) - local str = { - type = 'string', - start = startPos, - finish = lastRightPosition(), - escs = #escs > 0 and escs or nil, - [1] = stringResult, - [2] = mark, - } - if mark == '`' then - if not State.options.nonstandardSymbol[mark] then - pushError { - type = 'ERR_NONSTANDARD_SYMBOL', - start = startPos, - finish = str.finish, - info = { - symbol = '"', - }, - fix = { - title = 'FIX_NONSTANDARD_SYMBOL', - symbol = '"', - { - start = startPos, - finish = startPos + 1, - text = '"', - }, - { - start = str.finish - 1, - finish = str.finish, - text = '"', - }, - } - } - end - end - return str -end - -local function parseString() - local c = Tokens[Index + 1] - if CharMapStrSH[c] then - return parseShortString() - end - if CharMapStrLH[c] then - return parseLongString() - end - return nil -end - -local function parseNumber10(start) - local integer = true - local integerPart = smatch(Lua, '^%d*', start) - local offset = start + #integerPart - -- float part - if ssub(Lua, offset, offset) == '.' then - local floatPart = smatch(Lua, '^%d*', offset + 1) - integer = false - offset = offset + #floatPart + 1 - end - -- exp part - local echar = ssub(Lua, offset, offset) - if CharMapE10[echar] then - integer = false - offset = offset + 1 - local nextChar = ssub(Lua, offset, offset) - if CharMapSign[nextChar] then - offset = offset + 1 - end - local exp = smatch(Lua, '^%d*', offset) - offset = offset + #exp - if #exp == 0 then - pushError { - type = 'MISS_EXPONENT', - start = getPosition(offset - 1, 'right'), - finish = getPosition(offset - 1, 'right'), - } - end - end - return tonumber(ssub(Lua, start, offset - 1)), offset, integer -end - -local function parseNumber16(start) - local integerPart = smatch(Lua, '^[%da-fA-F]*', start) - local offset = start + #integerPart - local integer = true - -- float part - if ssub(Lua, offset, offset) == '.' then - local floatPart = smatch(Lua, '^[%da-fA-F]*', offset + 1) - integer = false - offset = offset + #floatPart + 1 - if #integerPart == 0 and #floatPart == 0 then - pushError { - type = 'MUST_X16', - start = getPosition(offset - 1, 'right'), - finish = getPosition(offset - 1, 'right'), - } - end - else - if #integerPart == 0 then - pushError { - type = 'MUST_X16', - start = getPosition(offset - 1, 'right'), - finish = getPosition(offset - 1, 'right'), - } - return 0, offset - end - end - -- exp part - local echar = ssub(Lua, offset, offset) - if CharMapE16[echar] then - integer = false - offset = offset + 1 - local nextChar = ssub(Lua, offset, offset) - if CharMapSign[nextChar] then - offset = offset + 1 - end - local exp = smatch(Lua, '^%d*', offset) - offset = offset + #exp - end - local n = tonumber(ssub(Lua, start - 2, offset - 1)) - return n, offset, integer -end - -local function parseNumber2(start) - local bins = smatch(Lua, '^[01]*', start) - local offset = start + #bins - if State.version ~= 'LuaJIT' then - pushError { - type = 'UNSUPPORT_SYMBOL', - start = getPosition(start - 2, 'left'), - finish = getPosition(offset - 1, 'right'), - version = 'LuaJIT', - info = { - version = 'Lua 5.4', - } - } - end - return tonumber(bins, 2), offset -end - -local function dropNumberTail(offset, integer) - local _, finish, word = sfind(Lua, '^([%.%w_\x80-\xff]+)', offset) - if not finish then - return offset - end - if integer then - if supper(ssub(word, 1, 2)) == 'LL' then - if State.version ~= 'LuaJIT' then - pushError { - type = 'UNSUPPORT_SYMBOL', - start = getPosition(offset, 'left'), - finish = getPosition(offset + 1, 'right'), - version = 'LuaJIT', - info = { - version = State.version, - } - } - end - offset = offset + 2 - word = ssub(word, offset) - elseif supper(ssub(word, 1, 3)) == 'ULL' then - if State.version ~= 'LuaJIT' then - pushError { - type = 'UNSUPPORT_SYMBOL', - start = getPosition(offset, 'left'), - finish = getPosition(offset + 2, 'right'), - version = 'LuaJIT', - info = { - version = State.version, - } - } - end - offset = offset + 3 - word = ssub(word, offset) - end - end - if supper(ssub(word, 1, 1)) == 'I' then - if State.version ~= 'LuaJIT' then - pushError { - type = 'UNSUPPORT_SYMBOL', - start = getPosition(offset, 'left'), - finish = getPosition(offset, 'right'), - version = 'LuaJIT', - info = { - version = State.version, - } - } - end - offset = offset + 1 - word = ssub(word, offset) - end - if #word > 0 then - pushError { - type = 'MALFORMED_NUMBER', - start = getPosition(offset, 'left'), - finish = getPosition(finish, 'right'), - } - end - return finish + 1 -end - -local function parseNumber() - local offset = Tokens[Index] - if not offset then - return nil - end - local startPos = getPosition(offset, 'left') - local neg - if ssub(Lua, offset, offset) == '-' then - neg = true - offset = offset + 1 - end - local number, integer - local firstChar = ssub(Lua, offset, offset) - if firstChar == '.' then - number, offset = parseNumber10(offset) - integer = false - elseif firstChar == '0' then - local nextChar = ssub(Lua, offset + 1, offset + 1) - if CharMapN16[nextChar] then - number, offset, integer = parseNumber16(offset + 2) - elseif CharMapN2[nextChar] then - number, offset = parseNumber2(offset + 2) - integer = true - else - number, offset, integer = parseNumber10(offset) - end - elseif CharMapNumber[firstChar] then - number, offset, integer = parseNumber10(offset) - else - return nil - end - if not number then - number = 0 - end - if neg then - number = - number - end - local result = { - type = integer and 'integer' or 'number', - start = startPos, - finish = getPosition(offset - 1, 'right'), - [1] = number, - } - offset = dropNumberTail(offset, integer) - fastForwardToken(offset) - return result -end - -local function isKeyWord(word) - if KeyWord[word] then - return true - end - if word == 'goto' then - return State.version ~= 'Lua 5.1' - end - return false -end - -local function parseName(asAction) - local word = peekWord() - if not word then - return nil - end - if ChunkFinishMap[word] then - return nil - end - if asAction and ChunkStartMap[word] then - return nil - end - local startPos = getPosition(Tokens[Index], 'left') - local finishPos = getPosition(Tokens[Index] + #word - 1, 'right') - Index = Index + 2 - if not State.options.unicodeName and word:find '[\x80-\xff]' then - pushError { - type = 'UNICODE_NAME', - start = startPos, - finish = finishPos, - } - end - if isKeyWord(word) then - pushError { - type = 'KEYWORD', - start = startPos, - finish = finishPos, - } - end - return { - type = 'name', - start = startPos, - finish = finishPos, - [1] = word, - } -end - -local function parseNameOrList(parent) - local first = parseName() - if not first then - return nil - end - skipSpace() - local list - while true do - if Tokens[Index + 1] ~= ',' then - break - end - Index = Index + 2 - skipSpace() - local name = parseName(true) - if not name then - missName() - break - end - if not list then - list = { - type = 'list', - start = first.start, - finish = first.finish, - parent = parent, - [1] = first - } - end - list[#list+1] = name - list.finish = name.finish - end - return list or first -end - -local function dropTail() - local token = Tokens[Index + 1] - if token ~= '?' - and token ~= ':' then - return - end - local pl, pt, pp = 0, 0, 0 - while true do - local token = Tokens[Index + 1] - if not token then - break - end - if NLMap[token] then - break - end - if token == ',' then - if pl > 0 - or pt > 0 - or pp > 0 then - goto CONTINUE - else - break - end - end - if token == '<' then - pl = pl + 1 - goto CONTINUE - end - if token == '{' then - pt = pt + 1 - goto CONTINUE - end - if token == '(' then - pp = pp + 1 - goto CONTINUE - end - if token == '>' then - if pl <= 0 then - break - end - pl = pl - 1 - goto CONTINUE - end - if token == '}' then - if pt <= 0 then - break - end - pt = pt - 1 - goto CONTINUE - end - if token == ')' then - if pp <= 0 then - break - end - pp = pp - 1 - goto CONTINUE - end - ::CONTINUE:: - Index = Index + 2 - end -end - -local function parseExpList(mini) - local list - local wantSep = false - while true do - skipSpace() - local token = Tokens[Index + 1] - if not token then - break - end - if ListFinishMap[token] then - break - end - if token == ',' then - local sepPos = getPosition(Tokens[Index], 'right') - if not wantSep then - pushError { - type = 'UNEXPECT_SYMBOL', - start = getPosition(Tokens[Index], 'left'), - finish = sepPos, - info = { - symbol = ',', - } - } - end - wantSep = false - Index = Index + 2 - goto CONTINUE - else - if mini then - if wantSep then - break - end - local nextToken = peekWord() - if isKeyWord(nextToken) - and nextToken ~= 'function' - and nextToken ~= 'true' - and nextToken ~= 'false' - and nextToken ~= 'nil' - and nextToken ~= 'not' then - break - end - end - local exp = parseExp() - if not exp then - break - end - dropTail() - if wantSep then - missSymbol(',', list[#list].finish, exp.start) - end - wantSep = true - if not list then - list = { - type = 'list', - start = exp.start, - } - end - list[#list+1] = exp - list.finish = exp.finish - exp.parent = list - end - ::CONTINUE:: - end - if not list then - return nil - end - if not wantSep then - missExp() - end - return list -end - -local function parseIndex() - local start = getPosition(Tokens[Index], 'left') - Index = Index + 2 - skipSpace() - local exp = parseExp() - local index = { - type = 'index', - start = start, - finish = exp and exp.finish or (start + 1), - index = exp - } - if exp then - exp.parent = index - else - missExp() - end - skipSpace() - if Tokens[Index + 1] == ']' then - index.finish = getPosition(Tokens[Index], 'right') - Index = Index + 2 - else - missSymbol ']' - end - return index -end - -local function parseTable() - local tbl = { - type = 'table', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index], 'right'), - } - Index = Index + 2 - local index = 0 - local tindex = 0 - local wantSep = false - while true do - skipSpace(true) - local token = Tokens[Index + 1] - if token == '}' then - Index = Index + 2 - break - end - if CharMapTSep[token] then - if not wantSep then - missExp() - end - wantSep = false - Index = Index + 2 - goto CONTINUE - end - local lastRight = lastRightPosition() - - if peekWord() then - local savePoint = getSavePoint() - local name = parseName() - if name then - skipSpace() - if Tokens[Index + 1] == '=' then - Index = Index + 2 - if wantSep then - pushError { - type = 'MISS_SEP_IN_TABLE', - start = lastRight, - finish = getPosition(Tokens[Index], 'left'), - } - end - wantSep = true - local eqRight = lastRightPosition() - skipSpace() - local fvalue = parseExp() - local tfield = { - type = 'tablefield', - start = name.start, - finish = fvalue and fvalue.finish or eqRight, - parent = tbl, - field = name, - value = fvalue, - } - name.type = 'field' - name.parent = tfield - if fvalue then - fvalue.parent = tfield - else - missExp() - end - index = index + 1 - tbl[index] = tfield - goto CONTINUE - end - end - savePoint() - end - - local exp = parseExp(true) - if exp then - if wantSep then - pushError { - type = 'MISS_SEP_IN_TABLE', - start = lastRight, - finish = exp.start, - } - end - wantSep = true - if exp.type == 'varargs' then - index = index + 1 - tbl[index] = exp - exp.parent = tbl - goto CONTINUE - end - index = index + 1 - tindex = tindex + 1 - local texp = { - type = 'tableexp', - start = exp.start, - finish = exp.finish, - tindex = tindex, - parent = tbl, - value = exp, - } - exp.parent = texp - tbl[index] = texp - goto CONTINUE - end - - if token == '[' then - if wantSep then - pushError { - type = 'MISS_SEP_IN_TABLE', - start = lastRight, - finish = getPosition(Tokens[Index], 'left'), - } - end - wantSep = true - local tindex = parseIndex() - skipSpace() - tindex.type = 'tableindex' - tindex.parent = tbl - index = index + 1 - tbl[index] = tindex - if expectAssign() then - skipSpace() - local ivalue = parseExp() - if ivalue then - ivalue.parent = tindex - tindex.finish = ivalue.finish - tindex.value = ivalue - else - missExp() - end - else - missSymbol '=' - end - goto CONTINUE - end - - missSymbol '}' - break - ::CONTINUE:: - end - tbl.finish = lastRightPosition() - return tbl -end - -local function addDummySelf(node, call) - if node.type ~= 'getmethod' then - return - end - -- dummy param `self` - if not call.args then - call.args = { - type = 'callargs', - start = call.start, - finish = call.finish, - parent = call, - } - end - local self = { - type = 'self', - start = node.colon.start, - finish = node.colon.finish, - parent = call.args, - [1] = 'self', - } - tinsert(call.args, 1, self) -end - -local function parseSimple(node, funcName) - local lastMethod - while true do - if lastMethod and node.node == lastMethod then - if node.type ~= 'call' then - missSymbol('(', node.node.finish, node.node.finish) - end - lastMethod = nil - end - skipSpace() - local token = Tokens[Index + 1] - if token == '.' then - local dot = { - type = token, - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index], 'right'), - } - Index = Index + 2 - skipSpace() - local field = parseName(true) - local getfield = { - type = 'getfield', - start = node.start, - finish = lastRightPosition(), - node = node, - dot = dot, - field = field - } - if field then - field.parent = getfield - field.type = 'field' - else - pushError { - type = 'MISS_FIELD', - start = lastRightPosition(), - finish = lastRightPosition(), - } - end - node.parent = getfield - node.next = getfield - node = getfield - elseif token == ':' then - local colon = { - type = token, - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index], 'right'), - } - Index = Index + 2 - skipSpace() - local method = parseName(true) - local getmethod = { - type = 'getmethod', - start = node.start, - finish = lastRightPosition(), - node = node, - colon = colon, - method = method - } - if method then - method.parent = getmethod - method.type = 'method' - else - pushError { - type = 'MISS_METHOD', - start = lastRightPosition(), - finish = lastRightPosition(), - } - end - node.parent = getmethod - node.next = getmethod - node = getmethod - if lastMethod then - missSymbol('(', node.node.finish, node.node.finish) - end - lastMethod = getmethod - elseif token == '(' then - if funcName then - break - end - local startPos = getPosition(Tokens[Index], 'left') - local call = { - type = 'call', - start = node.start, - node = node, - } - Index = Index + 2 - local args = parseExpList() - if Tokens[Index + 1] == ')' then - call.finish = getPosition(Tokens[Index], 'right') - Index = Index + 2 - else - call.finish = lastRightPosition() - missSymbol ')' - end - if args then - args.type = 'callargs' - args.start = startPos - args.finish = call.finish - args.parent = call - call.args = args - end - addDummySelf(node, call) - node.parent = call - node = call - elseif token == '{' then - if funcName then - break - end - local tbl = parseTable() - local call = { - type = 'call', - start = node.start, - finish = tbl.finish, - node = node, - } - local args = { - type = 'callargs', - start = tbl.start, - finish = tbl.finish, - parent = call, - [1] = tbl, - } - call.args = args - addDummySelf(node, call) - tbl.parent = args - node.parent = call - node = call - elseif CharMapStrSH[token] then - if funcName then - break - end - local str = parseShortString() - local call = { - type = 'call', - start = node.start, - finish = str.finish, - node = node, - } - local args = { - type = 'callargs', - start = str.start, - finish = str.finish, - parent = call, - [1] = str, - } - call.args = args - addDummySelf(node, call) - str.parent = args - node.parent = call - node = call - elseif CharMapStrLH[token] then - local str = parseLongString() - if str then - if funcName then - break - end - local call = { - type = 'call', - start = node.start, - finish = str.finish, - node = node, - } - local args = { - type = 'callargs', - start = str.start, - finish = str.finish, - parent = call, - [1] = str, - } - call.args = args - addDummySelf(node, call) - str.parent = args - node.parent = call - node = call - else - local index = parseIndex() - local bstart = index.start - index.type = 'getindex' - index.start = node.start - index.node = node - node.next = index - node.parent = index - node = index - if funcName then - pushError { - type = 'INDEX_IN_FUNC_NAME', - start = bstart, - finish = index.finish, - } - end - end - else - break - end - end - if node.type == 'call' - and node.node == lastMethod then - lastMethod = nil - end - if node == lastMethod then - if funcName then - lastMethod = nil - end - end - if lastMethod then - missSymbol('(', lastMethod.finish) - end - return node -end - -local function parseVarargs() - local varargs = { - type = 'varargs', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + 2, 'right'), - } - Index = Index + 2 - for i = #Chunk, 1, -1 do - local chunk = Chunk[i] - if chunk.vararg then - if not chunk.vararg.ref then - chunk.vararg.ref = {} - end - chunk.vararg.ref[#chunk.vararg.ref+1] = varargs - varargs.node = chunk.vararg - break - end - if chunk.type == 'main' then - break - end - if chunk.type == 'function' then - pushError { - type = 'UNEXPECT_DOTS', - start = varargs.start, - finish = varargs.finish, - } - break - end - end - return varargs -end - -local function parseParen() - local pl = Tokens[Index] - local paren = { - type = 'paren', - start = getPosition(pl, 'left'), - finish = getPosition(pl, 'right') - } - Index = Index + 2 - skipSpace() - local exp = parseExp() - if exp then - paren.exp = exp - paren.finish = exp.finish - exp.parent = paren - else - missExp() - end - skipSpace() - if Tokens[Index + 1] == ')' then - paren.finish = getPosition(Tokens[Index], 'right') - Index = Index + 2 - else - missSymbol ')' - end - return paren -end - -local function getLocal(name, pos) - for i = #Chunk, 1, -1 do - local chunk = Chunk[i] - local locals = chunk.locals - if locals then - local res - for n = 1, #locals do - local loc = locals[n] - if loc.effect > pos then - break - end - if loc[1] == name then - if not res or res.effect < loc.effect then - res = loc - end - end - end - if res then - return res - end - end - end -end - -local function resolveName(node) - if not node then - return nil - end - local loc = getLocal(node[1], node.start) - if loc then - node.type = 'getlocal' - node.node = loc - if not loc.ref then - loc.ref = {} - end - loc.ref[#loc.ref+1] = node - if loc.special then - addSpecial(loc.special, node) - end - else - node.type = 'getglobal' - local env = getLocal(State.ENVMode, node.start) - if env then - node.node = env - if not env.ref then - env.ref = {} - end - env.ref[#env.ref+1] = node - end - end - local name = node[1] - if Specials[name] then - addSpecial(name, node) - else - local ospeicals = State.options.special - if ospeicals and ospeicals[name] then - addSpecial(ospeicals[name], node) - end - end - return node -end - -local function isChunkFinishToken(token) - local currentChunk = Chunk[#Chunk] - if not currentChunk then - return false - end - local tp = currentChunk.type - if tp == 'main' then - return false - end - if tp == 'for' - or tp == 'in' - or tp == 'loop' - or tp == 'function' then - return token == 'end' - end - if tp == 'if' - or tp == 'ifblock' - or tp == 'elseifblock' - or tp == 'elseblock' then - return token == 'then' - or token == 'end' - or token == 'else' - or token == 'elseif' - end - if tp == 'repeat' then - return token == 'until' - end - return true -end - -local function parseActions() - local rtn, last - while true do - skipSpace(true) - local token = Tokens[Index + 1] - if token == ';' then - Index = Index + 2 - goto CONTINUE - end - if ChunkFinishMap[token] - and isChunkFinishToken(token) then - break - end - local action, failed = parseAction() - if failed then - if not skipUnknownSymbol() then - break - end - end - if action then - if not rtn and action.type == 'return' then - rtn = action - end - last = action - end - ::CONTINUE:: - end - if rtn and rtn ~= last then - pushError { - type = 'ACTION_AFTER_RETURN', - start = rtn.start, - finish = rtn.finish, - } - end -end - -local function parseParams(params) - local lastSep - local hasDots - while true do - skipSpace() - local token = Tokens[Index + 1] - if not token or token == ')' then - if lastSep then - missName() - end - break - end - if token == ',' then - if lastSep or lastSep == nil then - missName() - else - lastSep = true - end - Index = Index + 2 - goto CONTINUE - end - if token == '...' then - if lastSep == false then - missSymbol ',' - end - lastSep = false - if not params then - params = {} - end - local vararg = { - type = '...', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + 2, 'right'), - parent = params, - [1] = '...', - } - local chunk = Chunk[#Chunk] - chunk.vararg = vararg - params[#params+1] = vararg - if hasDots then - pushError { - type = 'ARGS_AFTER_DOTS', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + 2, 'right'), - } - end - hasDots = true - Index = Index + 2 - goto CONTINUE - end - if CharMapWord[ssub(token, 1, 1)] then - if lastSep == false then - missSymbol ',' - end - lastSep = false - if not params then - params = {} - end - params[#params+1] = createLocal { - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + #token - 1, 'right'), - parent = params, - [1] = token, - } - if hasDots then - pushError { - type = 'ARGS_AFTER_DOTS', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + #token - 1, 'right'), - } - end - if isKeyWord(token) then - pushError { - type = 'KEYWORD', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + #token - 1, 'right'), - } - end - Index = Index + 2 - goto CONTINUE - end - skipUnknownSymbol '%,%)%.' - ::CONTINUE:: - end - return params -end - -local function parseFunction(isLocal, isAction) - local funcLeft = getPosition(Tokens[Index], 'left') - local funcRight = getPosition(Tokens[Index] + 7, 'right') - local func = { - type = 'function', - start = funcLeft, - finish = funcRight, - keyword = { - [1] = funcLeft, - [2] = funcRight, - }, - } - Index = Index + 2 - local LastLocalCount = LocalCount - LocalCount = 0 - skipSpace(true) - local hasLeftParen = Tokens[Index + 1] == '(' - if not hasLeftParen then - local name = parseName() - if name then - local simple = parseSimple(name, true) - if isLocal then - if simple == name then - createLocal(name) - else - resolveName(name) - pushError { - type = 'UNEXPECT_LFUNC_NAME', - start = simple.start, - finish = simple.finish, - } - end - else - resolveName(name) - end - func.name = simple - func.finish = simple.finish - if not isAction then - simple.parent = func - pushError { - type = 'UNEXPECT_EFUNC_NAME', - start = simple.start, - finish = simple.finish, - } - end - skipSpace(true) - hasLeftParen = Tokens[Index + 1] == '(' - end - end - pushChunk(func) - local params - if func.name and func.name.type == 'getmethod' then - if func.name.type == 'getmethod' then - params = {} - params[1] = createLocal { - start = funcRight, - finish = funcRight, - parent = params, - [1] = 'self', - } - params[1].type = 'self' - end - end - if hasLeftParen then - local parenLeft = getPosition(Tokens[Index], 'left') - Index = Index + 2 - params = parseParams(params) - if params then - params.type = 'funcargs' - params.start = parenLeft - params.finish = lastRightPosition() - params.parent = func - func.args = params - end - skipSpace(true) - if Tokens[Index + 1] == ')' then - local parenRight = getPosition(Tokens[Index], 'right') - func.finish = parenRight - if params then - params.finish = parenRight - end - Index = Index + 2 - skipSpace(true) - else - func.finish = lastRightPosition() - if params then - params.finish = func.finish - end - missSymbol ')' - end - else - missSymbol '(' - end - parseActions() - popChunk() - if Tokens[Index + 1] == 'end' then - local endLeft = getPosition(Tokens[Index], 'left') - local endRight = getPosition(Tokens[Index] + 2, 'right') - func.keyword[3] = endLeft - func.keyword[4] = endRight - func.finish = endRight - Index = Index + 2 - else - missEnd(funcLeft, funcRight) - end - LocalCount = LastLocalCount - return func -end - -local function parseExpUnit() - local token = Tokens[Index + 1] - if token == '(' then - local paren = parseParen() - return parseSimple(paren, false) - end - - if token == '...' then - local varargs = parseVarargs() - return varargs - end - - if token == '{' then - local table = parseTable() - return table - end - - if CharMapStrSH[token] then - local string = parseShortString() - return string - end - - if CharMapStrLH[token] then - local string = parseLongString() - return string - end - - local number = parseNumber() - if number then - return number - end - - if ChunkFinishMap[token] then - return nil - end - - if token == 'nil' then - return parseNil() - end - - if token == 'true' - or token == 'false' then - return parseBoolean() - end - - if token == 'function' then - return parseFunction() - end - - local node = parseName() - if node then - return parseSimple(resolveName(node), false) - end - - return nil -end - -local function parseUnaryOP() - local token = Tokens[Index + 1] - local symbol = UnarySymbol[token] and token or UnaryAlias[token] - if not symbol then - return nil - end - local myLevel = UnarySymbol[symbol] - local op = { - type = symbol, - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + #symbol - 1, 'right'), - } - Index = Index + 2 - return op, myLevel -end - ----@param level integer # op level must greater than this level -local function parseBinaryOP(asAction, level) - local token = Tokens[Index + 1] - local symbol = (BinarySymbol[token] and token) - or BinaryAlias[token] - or (not asAction and BinaryActionAlias[token]) - if not symbol then - return nil - end - if symbol == '//' and State.options.nonstandardSymbol['//'] then - return nil - end - local myLevel = BinarySymbol[symbol] - if level and myLevel < level then - return nil - end - local op = { - type = symbol, - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + #token - 1, 'right'), - } - if not asAction then - if token == '=' then - pushError { - type = 'ERR_EQ_AS_ASSIGN', - start = op.start, - finish = op.finish, - fix = { - title = 'FIX_EQ_AS_ASSIGN', - { - start = op.start, - finish = op.finish, - text = '==', - } - } - } - end - end - if BinaryAlias[token] then - if not State.options.nonstandardSymbol[token] then - pushError { - type = 'ERR_NONSTANDARD_SYMBOL', - start = op.start, - finish = op.finish, - info = { - symbol = symbol, - }, - fix = { - title = 'FIX_NONSTANDARD_SYMBOL', - symbol = symbol, - { - start = op.start, - finish = op.finish, - text = symbol, - }, - } - } - end - end - if token == '//' - or token == '<<' - or token == '>>' then - if State.version ~= 'Lua 5.3' - and State.version ~= 'Lua 5.4' then - pushError { - type = 'UNSUPPORT_SYMBOL', - version = {'Lua 5.3', 'Lua 5.4'}, - start = op.start, - finish = op.finish, - info = { - version = State.version, - } - } - end - end - Index = Index + 2 - return op, myLevel -end - -function parseExp(asAction, level) - local exp - local uop, uopLevel = parseUnaryOP() - if uop then - skipSpace() - local child = parseExp(asAction, uopLevel) - -- 预计算负数 - if uop.type == '-' - and child - and (child.type == 'number' or child.type == 'integer') then - child.start = uop.start - child[1] = - child[1] - exp = child - else - exp = { - type = 'unary', - op = uop, - start = uop.start, - finish = child and child.finish or uop.finish, - [1] = child, - } - if child then - child.parent = exp - else - missExp() - end - end - else - exp = parseExpUnit() - if not exp then - return nil - end - end - - while true do - skipSpace() - local bop, bopLevel = parseBinaryOP(asAction, level) - if not bop then - break - end - - ::AGAIN:: - skipSpace() - local isForward = SymbolForward[bopLevel] - local child = parseExp(asAction, isForward and (bopLevel + 0.5) or bopLevel) - if not child then - if skipUnknownSymbol() then - goto AGAIN - else - missExp() - end - end - local bin = { - type = 'binary', - start = exp.start, - finish = child and child.finish or bop.finish, - op = bop, - [1] = exp, - [2] = child - } - exp.parent = bin - if child then - child.parent = bin - end - exp = bin - end - - return exp -end - -local function skipSeps() - while true do - skipSpace() - if Tokens[Index + 1] == ',' then - missExp() - Index = Index + 2 - else - break - end - end -end - ----@return parser.object first ----@return parser.object second ----@return parser.object[] rest -local function parseSetValues() - skipSpace() - local first = parseExp() - if not first then - return nil - end - skipSpace() - if Tokens[Index + 1] ~= ',' then - return first - end - Index = Index + 2 - skipSeps() - local second = parseExp() - if not second then - missExp() - return first - end - skipSpace() - if Tokens[Index + 1] ~= ',' then - return first, second - end - Index = Index + 2 - skipSeps() - local third = parseExp() - if not third then - missExp() - return first, second - end - - local rest = { third } - while true do - skipSpace() - if Tokens[Index + 1] ~= ',' then - return first, second, rest - end - Index = Index + 2 - skipSeps() - local exp = parseExp() - if not exp then - missExp() - return first, second, rest - end - rest[#rest+1] = exp - end -end - -local function pushActionIntoCurrentChunk(action) - local chunk = Chunk[#Chunk] - if chunk then - chunk[#chunk+1] = action - action.parent = chunk - end -end - ----@return parser.object second ----@return parser.object[] rest -local function parseVarTails(parser, isLocal) - if Tokens[Index + 1] ~= ',' then - return - end - Index = Index + 2 - skipSpace() - local second = parser(true) - if not second then - missName() - return - end - if isLocal then - createLocal(second, parseLocalAttrs()) - second.effect = maxinteger - end - skipSpace() - if Tokens[Index + 1] ~= ',' then - return second - end - Index = Index + 2 - skipSeps() - local third = parser(true) - if not third then - missName() - return second - end - if isLocal then - createLocal(third, parseLocalAttrs()) - third.effect = maxinteger - end - local rest = { third } - while true do - skipSpace() - if Tokens[Index + 1] ~= ',' then - return second, rest - end - Index = Index + 2 - skipSeps() - local name = parser(true) - if not name then - missName() - return second, rest - end - if isLocal then - createLocal(name, parseLocalAttrs()) - name.effect = maxinteger - end - rest[#rest+1] = name - end -end - -local function bindValue(n, v, index, lastValue, isLocal, isSet) - if isLocal then - n.effect = lastRightPosition() - if v and v.special then - addSpecial(v.special, n) - end - elseif isSet then - n.type = GetToSetMap[n.type] or n.type - if n.type == 'setlocal' then - local loc = n.node - if loc.attrs then - pushError { - type = 'SET_CONST', - start = n.start, - finish = n.finish, - } - end - end - end - if not v and lastValue then - if lastValue.type == 'call' - or lastValue.type == 'varargs' then - v = lastValue - if not v.extParent then - v.extParent = {} - end - end - end - if v then - if v.type == 'call' - or v.type == 'varargs' then - local select = { - type = 'select', - sindex = index, - start = v.start, - finish = v.finish, - vararg = v - } - if v.parent then - v.extParent[#v.extParent+1] = select - else - v.parent = select - end - v = select - end - n.value = v - n.range = v.finish - v.parent = n - if isLocal then - n.effect = lastRightPosition() - end - end -end - -local function parseMultiVars(n1, parser, isLocal) - local n2, nrest = parseVarTails(parser, isLocal) - skipSpace() - local v1, v2, vrest - local isSet - local max = 1 - if expectAssign(not isLocal) then - v1, v2, vrest = parseSetValues() - isSet = true - if not v1 then - missExp() - end - end - bindValue(n1, v1, 1, nil, isLocal, isSet) - local lastValue = v1 - if n2 then - max = 2 - bindValue(n2, v2, 2, lastValue, isLocal, isSet) - lastValue = v2 or lastValue - pushActionIntoCurrentChunk(n2) - end - if nrest then - for i = 1, #nrest do - local n = nrest[i] - local v = vrest and vrest[i] - max = i + 2 - bindValue(n, v, max, lastValue, isLocal, isSet) - lastValue = v or lastValue - pushActionIntoCurrentChunk(n) - end - end - - if v2 and not n2 then - v2.redundant = { - max = max, - passed = 2, - } - pushActionIntoCurrentChunk(v2) - end - if vrest then - for i = 1, #vrest do - local v = vrest[i] - if not nrest or not nrest[i] then - v.redundant = { - max = max, - passed = i + 2, - } - pushActionIntoCurrentChunk(v) - end - end - end - - return n1, isSet -end - -local function compileExpAsAction(exp) - pushActionIntoCurrentChunk(exp) - if GetToSetMap[exp.type] then - skipSpace() - local action, isSet = parseMultiVars(exp, parseExp) - if isSet - or action.type == 'getmethod' then - return action - end - end - - if exp.type == 'call' then - return exp - end - - if exp.type == 'binary' then - if GetToSetMap[exp[1].type] then - local op = exp.op - if op.type == '==' then - pushError { - type = 'ERR_ASSIGN_AS_EQ', - start = op.start, - finish = op.finish, - fix = { - title = 'FIX_ASSIGN_AS_EQ', - { - start = op.start, - finish = op.finish, - text = '=', - } - } - } - return - end - end - end - - pushError { - type = 'EXP_IN_ACTION', - start = exp.start, - finish = exp.finish, - } - - return exp -end - -local function parseLocal() - local locPos = getPosition(Tokens[Index], 'left') - Index = Index + 2 - skipSpace() - local word = peekWord() - if not word then - missName() - return nil - end - - if word == 'function' then - local func = parseFunction(true, true) - local name = func.name - if name then - func.name = nil - name.value = func - name.vstart = func.start - name.range = func.finish - name.locPos = locPos - func.parent = name - pushActionIntoCurrentChunk(name) - return name - else - missName(func.keyword[2]) - pushActionIntoCurrentChunk(func) - return func - end - end - - local name = parseName(true) - if not name then - missName() - return nil - end - local loc = createLocal(name, parseLocalAttrs()) - loc.locPos = locPos - loc.effect = maxinteger - pushActionIntoCurrentChunk(loc) - skipSpace() - parseMultiVars(loc, parseName, true) - if loc.value then - loc.effect = loc.value.finish - else - loc.effect = loc.finish - end - - return loc -end - -local function parseDo() - local doLeft = getPosition(Tokens[Index], 'left') - local doRight = getPosition(Tokens[Index] + 1, 'right') - local obj = { - type = 'do', - start = doLeft, - finish = doRight, - keyword = { - [1] = doLeft, - [2] = doRight, - }, - } - Index = Index + 2 - pushActionIntoCurrentChunk(obj) - pushChunk(obj) - parseActions() - popChunk() - if Tokens[Index + 1] == 'end' then - obj.finish = getPosition(Tokens[Index] + 2, 'right') - obj.keyword[3] = getPosition(Tokens[Index], 'left') - obj.keyword[4] = getPosition(Tokens[Index] + 2, 'right') - Index = Index + 2 - else - missEnd(doLeft, doRight) - end - if obj.locals then - LocalCount = LocalCount - #obj.locals - end - - return obj -end - -local function parseReturn() - local returnLeft = getPosition(Tokens[Index], 'left') - local returnRight = getPosition(Tokens[Index] + 5, 'right') - Index = Index + 2 - skipSpace() - local rtn = parseExpList(true) - if rtn then - rtn.type = 'return' - rtn.start = returnLeft - else - rtn = { - type = 'return', - start = returnLeft, - finish = returnRight, - } - end - pushActionIntoCurrentChunk(rtn) - for i = #Chunk, 1, -1 do - local block = Chunk[i] - if block.type == 'function' - or block.type == 'main' then - if not block.returns then - block.returns = {} - end - block.returns[#block.returns+1] = rtn - break - end - end - for i = #Chunk, 1, -1 do - local block = Chunk[i] - if block.type == 'ifblock' - or block.type == 'elseifblock' - or block.type == 'else' then - block.hasReturn = true - break - end - end - - return rtn -end - -local function parseLabel() - local left = getPosition(Tokens[Index], 'left') - Index = Index + 2 - skipSpace() - local label = parseName() - skipSpace() - - if not label then - missName() - end - - if Tokens[Index + 1] == '::' then - Index = Index + 2 - else - if label then - missSymbol '::' - end - end - - if not label then - return nil - end - - label.type = 'label' - pushActionIntoCurrentChunk(label) - - local block = guide.getBlock(label) - if block then - if not block.labels then - block.labels = {} - end - local name = label[1] - local olabel = guide.getLabel(block, name) - if olabel then - if State.version == 'Lua 5.4' - or block == guide.getBlock(olabel) then - pushError { - type = 'REDEFINED_LABEL', - start = label.start, - finish = label.finish, - relative = { - { - olabel.start, - olabel.finish, - } - } - } - end - end - block.labels[name] = label - end - - if State.version == 'Lua 5.1' then - pushError { - type = 'UNSUPPORT_SYMBOL', - start = left, - finish = lastRightPosition(), - version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, - info = { - version = State.version, - } - } - return - end - return label -end - -local function parseGoTo() - local start = getPosition(Tokens[Index], 'left') - Index = Index + 2 - skipSpace() - - local action = parseName() - if not action then - missName() - return nil - end - - action.type = 'goto' - action.keyStart = start - - for i = #Chunk, 1, -1 do - local chunk = Chunk[i] - if chunk.type == 'function' - or chunk.type == 'main' then - if not chunk.gotos then - chunk.gotos = {} - end - chunk.gotos[#chunk.gotos+1] = action - break - end - end - for i = #Chunk, 1, -1 do - local chunk = Chunk[i] - if chunk.type == 'ifblock' - or chunk.type == 'elseifblock' - or chunk.type == 'elseblock' then - chunk.hasGoTo = true - break - end - end - - pushActionIntoCurrentChunk(action) - return action -end - -local function parseIfBlock(parent) - local ifLeft = getPosition(Tokens[Index], 'left') - local ifRight = getPosition(Tokens[Index] + 1, 'right') - Index = Index + 2 - local ifblock = { - type = 'ifblock', - parent = parent, - start = ifLeft, - finish = ifRight, - keyword = { - [1] = ifLeft, - [2] = ifRight, - } - } - skipSpace() - local filter = parseExp() - if filter then - ifblock.filter = filter - ifblock.finish = filter.finish - filter.parent = ifblock - else - missExp() - end - skipSpace() - local thenToken = Tokens[Index + 1] - if thenToken == 'then' - or thenToken == 'do' then - ifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right') - ifblock.keyword[3] = getPosition(Tokens[Index], 'left') - ifblock.keyword[4] = ifblock.finish - if thenToken == 'do' then - pushError { - type = 'ERR_THEN_AS_DO', - start = ifblock.keyword[3], - finish = ifblock.keyword[4], - fix = { - title = 'FIX_THEN_AS_DO', - { - start = ifblock.keyword[3], - finish = ifblock.keyword[4], - text = 'then', - } - } - } - end - Index = Index + 2 - else - missSymbol 'then' - end - pushChunk(ifblock) - parseActions() - popChunk() - ifblock.finish = lastRightPosition() - if ifblock.locals then - LocalCount = LocalCount - #ifblock.locals - end - return ifblock -end - -local function parseElseIfBlock(parent) - local ifLeft = getPosition(Tokens[Index], 'left') - local ifRight = getPosition(Tokens[Index] + 5, 'right') - local elseifblock = { - type = 'elseifblock', - parent = parent, - start = ifLeft, - finish = ifRight, - keyword = { - [1] = ifLeft, - [2] = ifRight, - } - } - Index = Index + 2 - skipSpace() - local filter = parseExp() - if filter then - elseifblock.filter = filter - elseifblock.finish = filter.finish - filter.parent = elseifblock - else - missExp() - end - skipSpace() - local thenToken = Tokens[Index + 1] - if thenToken == 'then' - or thenToken == 'do' then - elseifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right') - elseifblock.keyword[3] = getPosition(Tokens[Index], 'left') - elseifblock.keyword[4] = elseifblock.finish - if thenToken == 'do' then - pushError { - type = 'ERR_THEN_AS_DO', - start = elseifblock.keyword[3], - finish = elseifblock.keyword[4], - fix = { - title = 'FIX_THEN_AS_DO', - { - start = elseifblock.keyword[3], - finish = elseifblock.keyword[4], - text = 'then', - } - } - } - end - Index = Index + 2 - else - missSymbol 'then' - end - pushChunk(elseifblock) - parseActions() - popChunk() - elseifblock.finish = lastRightPosition() - if elseifblock.locals then - LocalCount = LocalCount - #elseifblock.locals - end - return elseifblock -end - -local function parseElseBlock(parent) - local ifLeft = getPosition(Tokens[Index], 'left') - local ifRight = getPosition(Tokens[Index] + 3, 'right') - local elseblock = { - type = 'elseblock', - parent = parent, - start = ifLeft, - finish = ifRight, - keyword = { - [1] = ifLeft, - [2] = ifRight, - } - } - Index = Index + 2 - skipSpace() - pushChunk(elseblock) - parseActions() - popChunk() - elseblock.finish = lastRightPosition() - if elseblock.locals then - LocalCount = LocalCount - #elseblock.locals - end - return elseblock -end - -local function parseIf() - local token = Tokens[Index + 1] - local left = getPosition(Tokens[Index], 'left') - local action = { - type = 'if', - start = left, - finish = getPosition(Tokens[Index] + #token - 1, 'right'), - } - pushActionIntoCurrentChunk(action) - if token ~= 'if' then - missSymbol('if', left, left) - end - local hasElse - while true do - local word = Tokens[Index + 1] - local child - if word == 'if' then - child = parseIfBlock(action) - elseif word == 'elseif' then - child = parseElseIfBlock(action) - elseif word == 'else' then - child = parseElseBlock(action) - end - if not child then - break - end - if hasElse then - pushError { - type = 'BLOCK_AFTER_ELSE', - start = child.start, - finish = child.finish, - } - end - if word == 'else' then - hasElse = true - end - action[#action+1] = child - action.finish = child.finish - skipSpace() - end - - if Tokens[Index + 1] == 'end' then - action.finish = getPosition(Tokens[Index] + 2, 'right') - Index = Index + 2 - else - missEnd(action[1].keyword[1], action[1].keyword[2]) - end - - return action -end - -local function parseFor() - local action = { - type = 'for', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + 2, 'right'), - keyword = {}, - } - action.keyword[1] = action.start - action.keyword[2] = action.finish - Index = Index + 2 - pushActionIntoCurrentChunk(action) - pushChunk(action) - skipSpace() - local nameOrList = parseNameOrList(action) - if not nameOrList then - missName() - end - skipSpace() - -- for i = - if expectAssign() then - action.type = 'loop' - - skipSpace() - local expList = parseExpList() - local name - if nameOrList then - if nameOrList.type == 'name' then - name = nameOrList - else - name = nameOrList[1] - end - end - if name then - local loc = createLocal(name) - loc.parent = action - action.finish = name.finish - action.loc = loc - end - if expList then - expList.parent = action - local value = expList[1] - if value then - value.parent = expList - action.init = value - action.finish = expList[#expList].finish - end - local max = expList[2] - if max then - max.parent = expList - action.max = max - action.finish = max.finish - else - pushError { - type = 'MISS_LOOP_MAX', - start = lastRightPosition(), - finish = lastRightPosition(), - } - end - local step = expList[3] - if step then - step.parent = expList - action.step = step - action.finish = step.finish - end - else - pushError { - type = 'MISS_LOOP_MIN', - start = lastRightPosition(), - finish = lastRightPosition(), - } - end - - if action.loc then - action.loc.effect = action.finish - end - elseif Tokens[Index + 1] == 'in' then - action.type = 'in' - local inLeft = getPosition(Tokens[Index], 'left') - local inRight = getPosition(Tokens[Index] + 1, 'right') - Index = Index + 2 - skipSpace() - - local exps = parseExpList() - - action.finish = inRight - action.keyword[3] = inLeft - action.keyword[4] = inRight - - local list - if nameOrList and nameOrList.type == 'name' then - list = { - type = 'list', - start = nameOrList.start, - finish = nameOrList.finish, - parent = action, - [1] = nameOrList, - } - else - list = nameOrList - end - - if exps then - local lastExp = exps[#exps] - if lastExp then - action.finish = lastExp.finish - end - - action.exps = exps - exps.parent = action - for i = 1, #exps do - local exp = exps[i] - exp.parent = exps - end - else - missExp() - end - - if list then - local lastName = list[#list] - list.range = lastName and lastName.range or inRight - action.keys = list - for i = 1, #list do - local loc = createLocal(list[i]) - loc.parent = action - loc.effect = action.finish - end - end - else - missSymbol 'in' - end - - skipSpace() - local doToken = Tokens[Index + 1] - if doToken == 'do' - or doToken == 'then' then - local left = getPosition(Tokens[Index], 'left') - local right = getPosition(Tokens[Index] + #doToken - 1, 'right') - action.finish = left - action.keyword[#action.keyword+1] = left - action.keyword[#action.keyword+1] = right - if doToken == 'then' then - pushError { - type = 'ERR_DO_AS_THEN', - start = left, - finish = right, - fix = { - title = 'FIX_DO_AS_THEN', - { - start = left, - finish = right, - text = 'do', - } - } - } - end - Index = Index + 2 - else - missSymbol 'do' - end - - skipSpace() - parseActions() - popChunk() - - skipSpace() - if Tokens[Index + 1] == 'end' then - action.finish = getPosition(Tokens[Index] + 2, 'right') - action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left') - action.keyword[#action.keyword+1] = action.finish - Index = Index + 2 - else - missEnd(action.keyword[1], action.keyword[2]) - end - - if action.locals then - LocalCount = LocalCount - #action.locals - end - - return action -end - -local function parseWhile() - local action = { - type = 'while', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + 4, 'right'), - keyword = {}, - } - action.keyword[1] = action.start - action.keyword[2] = action.finish - Index = Index + 2 - - skipSpace() - local nextToken = Tokens[Index + 1] - local filter = nextToken ~= 'do' - and nextToken ~= 'then' - and parseExp() - if filter then - action.filter = filter - action.finish = filter.finish - filter.parent = action - else - missExp() - end - - skipSpace() - local doToken = Tokens[Index + 1] - if doToken == 'do' - or doToken == 'then' then - local left = getPosition(Tokens[Index], 'left') - local right = getPosition(Tokens[Index] + #doToken - 1, 'right') - action.finish = left - action.keyword[#action.keyword+1] = left - action.keyword[#action.keyword+1] = right - if doToken == 'then' then - pushError { - type = 'ERR_DO_AS_THEN', - start = left, - finish = right, - fix = { - title = 'FIX_DO_AS_THEN', - { - start = left, - finish = right, - text = 'do', - } - } - } - end - Index = Index + 2 - else - missSymbol 'do' - end - - pushActionIntoCurrentChunk(action) - pushChunk(action) - skipSpace() - parseActions() - popChunk() - - skipSpace() - if Tokens[Index + 1] == 'end' then - action.finish = getPosition(Tokens[Index] + 2, 'right') - action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left') - action.keyword[#action.keyword+1] = action.finish - Index = Index + 2 - else - missEnd(action.keyword[1], action.keyword[2]) - end - - if action.locals then - LocalCount = LocalCount - #action.locals - end - - return action -end - -local function parseRepeat() - local action = { - type = 'repeat', - start = getPosition(Tokens[Index], 'left'), - finish = getPosition(Tokens[Index] + 5, 'right'), - keyword = {}, - } - action.keyword[1] = action.start - action.keyword[2] = action.finish - Index = Index + 2 - - pushActionIntoCurrentChunk(action) - pushChunk(action) - skipSpace() - parseActions() - - skipSpace() - if Tokens[Index + 1] == 'until' then - action.finish = getPosition(Tokens[Index] + 4, 'right') - action.keyword[#action.keyword+1] = getPosition(Tokens[Index], 'left') - action.keyword[#action.keyword+1] = action.finish - Index = Index + 2 - - skipSpace() - local filter = parseExp() - if filter then - action.filter = filter - filter.parent = action - else - missExp() - end - - else - missSymbol 'until' - end - - popChunk() - if action.filter then - action.finish = action.filter.finish - end - - if action.locals then - LocalCount = LocalCount - #action.locals - end - - return action -end - -local function parseBreak() - local returnLeft = getPosition(Tokens[Index], 'left') - local returnRight = getPosition(Tokens[Index] + #Tokens[Index + 1] - 1, 'right') - Index = Index + 2 - skipSpace() - local action = { - type = 'break', - start = returnLeft, - finish = returnRight, - } - - local ok - for i = #Chunk, 1, -1 do - local chunk = Chunk[i] - if chunk.type == 'function' then - break - end - if chunk.type == 'while' - or chunk.type == 'in' - or chunk.type == 'loop' - or chunk.type == 'repeat' - or chunk.type == 'for' then - if not chunk.breaks then - chunk.breaks = {} - end - chunk.breaks[#chunk.breaks+1] = action - ok = true - break - end - end - for i = #Chunk, 1, -1 do - local chunk = Chunk[i] - if chunk.type == 'ifblock' - or chunk.type == 'elseifblock' - or chunk.type == 'elseblock' then - chunk.hasBreak = true - break - end - end - if not ok and Mode == 'Lua' then - pushError { - type = 'BREAK_OUTSIDE', - start = action.start, - finish = action.finish, - } - end - - pushActionIntoCurrentChunk(action) - return action -end - -function parseAction() - local token = Tokens[Index + 1] - - if token == '::' then - return parseLabel() - end - - if token == 'local' then - return parseLocal() - end - - if token == 'if' - or token == 'elseif' - or token == 'else' then - return parseIf() - end - - if token == 'for' then - return parseFor() - end - - if token == 'do' then - return parseDo() - end - - if token == 'return' then - return parseReturn() - end - - if token == 'break' then - return parseBreak() - end - - if token == 'continue' and State.options.nonstandardSymbol['continue'] then - return parseBreak() - end - - if token == 'while' then - return parseWhile() - end - - if token == 'repeat' then - return parseRepeat() - end - - if token == 'goto' and isKeyWord 'goto' then - return parseGoTo() - end - - if token == 'function' then - local exp = parseFunction(false, true) - local name = exp.name - if name then - exp.name = nil - name.type = GetToSetMap[name.type] - name.value = exp - name.vstart = exp.start - name.range = exp.finish - exp.parent = name - if name.type == 'setlocal' then - local loc = name.node - if loc.attrs then - pushError { - type = 'SET_CONST', - start = name.start, - finish = name.finish, - } - end - end - pushActionIntoCurrentChunk(name) - return name - else - pushActionIntoCurrentChunk(exp) - missName(exp.keyword[2]) - return exp - end - end - - local exp = parseExp(true) - if exp then - local action = compileExpAsAction(exp) - if action then - return action - end - end - return nil, true -end - -local function skipFirstComment() - if Tokens[Index + 1] ~= '#' then - return - end - while true do - Index = Index + 2 - local token = Tokens[Index + 1] - if not token then - break - end - if NLMap[token] then - skipNL() - break - end - end -end - -local function parseLua() - local main = { - type = 'main', - start = 0, - finish = 0, - } - pushChunk(main) - createLocal{ - type = 'local', - start = -1, - finish = -1, - effect = -1, - parent = main, - tag = '_ENV', - special= '_G', - [1] = State.ENVMode, - } - LocalCount = 0 - skipFirstComment() - while true do - parseActions() - if Index <= #Tokens then - unknownSymbol() - Index = Index + 2 - else - break - end - end - popChunk() - main.finish = getPosition(#Lua, 'right') - - return main -end - -local function initState(lua, version, options) - Lua = lua - Line = 0 - LineOffset = 1 - LastTokenFinish = 0 - LocalCount = 0 - Chunk = {} - Tokens = tokens(lua) - Index = 1 - local state = { - version = version, - lua = lua, - ast = {}, - errs = {}, - comms = {}, - lines = { - [0] = 1, - }, - options = options or {}, - } - if not state.options.nonstandardSymbol then - state.options.nonstandardSymbol = {} - end - State = state - if version == 'Lua 5.1' or version == 'LuaJIT' then - state.ENVMode = '@fenv' - else - state.ENVMode = '_ENV' - end - - pushError = function (err) - local errs = state.errs - if err.finish < err.start then - err.finish = err.start - end - local last = errs[#errs] - if last then - if last.start <= err.start and last.finish >= err.finish then - return - end - end - err.level = err.level or 'Error' - errs[#errs+1] = err - return err - end - - state.pushError = pushError -end - -return function (lua, mode, version, options) - Mode = mode - initState(lua, version, options) - skipSpace() - if mode == 'Lua' then - State.ast = parseLua() - elseif mode == 'Nil' then - State.ast = parseNil() - elseif mode == 'Boolean' then - State.ast = parseBoolean() - elseif mode == 'String' then - State.ast = parseString() - elseif mode == 'Number' then - State.ast = parseNumber() - elseif mode == 'Name' then - State.ast = parseName() - elseif mode == 'Exp' then - State.ast = parseExp() - elseif mode == 'Action' then - State.ast = parseAction() - end - - if State.ast then - State.ast.state = State - end - - while true do - if Index <= #Tokens then - unknownSymbol() - Index = Index + 2 - else - break - end - end - - return State -end diff --git a/script/parser/parse.lua b/script/parser/parse.lua deleted file mode 100644 index e7c7d177..00000000 --- a/script/parser/parse.lua +++ /dev/null @@ -1,63 +0,0 @@ -local ast = require 'parser.ast' -local grammar = require 'parser.grammar' - -local function buildState(lua, version, options) - local errs = {} - local diags = {} - local comms = {} - local state = { - version = version, - lua = lua, - root = {}, - errs = errs, - diags = diags, - comms = comms, - options = options or {}, - pushError = function (err) - if err.finish < err.start then - err.finish = err.start - end - local last = errs[#errs] - if last then - if last.start <= err.start and last.finish >= err.finish then - return - end - end - err.level = err.level or 'error' - errs[#errs+1] = err - return err - end, - pushDiag = function (code, info) - if not diags[code] then - diags[code] = {} - end - diags[code][#diags[code]+1] = info - end, - pushComment = function (comment) - comms[#comms+1] = comment - end - } - if version == 'Lua 5.1' or version == 'LuaJIT' then - state.ENVMode = '@fenv' - else - state.ENVMode = '_ENV' - end - return state -end - -return function (lua, mode, version, options) - local state = buildState(lua, version, options) - local clock = os.clock() - ast.init(state) - local suc, res, err = xpcall(grammar, debug.traceback, lua, mode) - ast.close() - if not suc then - return nil, res - end - if not res and err then - state.pushError(err) - end - state.ast = res - state.parseClock = os.clock() - clock - return state -end diff --git a/script/parser/split.lua b/script/parser/split.lua deleted file mode 100644 index 6ce4a4e7..00000000 --- a/script/parser/split.lua +++ /dev/null @@ -1,9 +0,0 @@ -local m = require 'lpeglabel' - -local NL = m.P'\r\n' + m.S'\r\n' -local LINE = m.C(1 - NL) - -return function (str) - local MATCH = m.Ct((LINE * NL)^0 * LINE) - return MATCH:match(str) -end diff --git a/script/parser/tokens.lua b/script/parser/tokens.lua index 958f292e..a4de7f88 100644 --- a/script/parser/tokens.lua +++ b/script/parser/tokens.lua @@ -7,6 +7,11 @@ local Word = m.R('AZ', 'az', '__', '\x80\xff') * m.R('AZ', 'az', '09', '__', ' local Symbol = m.P'==' + m.P'~=' + m.P'--' + -- non-standard: + + m.P'<<=' + + m.P'>>=' + + m.P'//=' + -- end non-standard + m.P'<<' + m.P'>>' + m.P'<=' @@ -15,7 +20,7 @@ local Symbol = m.P'==' + m.P'...' + m.P'..' + m.P'::' - -- incorrect + -- non-standard: + m.P'!=' + m.P'&&' + m.P'||' @@ -24,7 +29,12 @@ local Symbol = m.P'==' + m.P'+=' + m.P'-=' + m.P'*=' + + m.P'%=' + + m.P'&=' + + m.P'|=' + + m.P'^=' + m.P'/=' + -- end non-standard -- singles + m.S'+-*/!#%^&()={}[]|\\\'":;<>,.?~`' local Unknown = (1 - Number - Word - Symbol - Sp - Nl)^1 diff --git a/script/plugin.lua b/script/plugin.lua index 145abe74..870b68b6 100644 --- a/script/plugin.lua +++ b/script/plugin.lua @@ -4,6 +4,8 @@ local client = require 'client' local lang = require 'language' local await = require 'await' local scope = require 'workspace.scope' +local ws = require 'workspace' +local fs = require 'bee.filesystem' ---@class plugin local m = {} @@ -69,20 +71,35 @@ local function checkTrustLoad(scp) return true end ----@param scp scope -function m.init(scp) +---@param uri uri +local function initPlugin(uri) await.call(function () ---@async - local ws = require 'workspace' + local scp = scope.getScope(uri) local interface = {} scp:set('pluginInterface', interface) + if not scp.uri then + return + end + local pluginPath = ws.getAbsolutePath(scp.uri, config.get(scp.uri, 'Lua.runtime.plugin')) log.info('plugin path:', pluginPath) if not pluginPath then return end + + --Adding the plugins path to package.path allows for requires in files + --to find files relative to itself. + local oldPath = package.path + local path = fs.path(pluginPath):parent_path() / '?.lua' + if not package.path:find(path:string(), 1, true) then + package.path = package.path .. ';' .. path:string() + end + local pluginLua = util.loadFile(pluginPath) if not pluginLua then + log.warn('plugin not found:', pluginPath) + package.path = oldPath return end @@ -98,7 +115,8 @@ function m.init(scp) if not client.isVSCode() and not checkTrustLoad(scp) then return end - local suc, err = xpcall(f, log.error, f) + local pluginArgs = config.get(scp.uri, 'Lua.runtime.pluginArgs') + local suc, err = xpcall(f, log.error, f, uri, pluginArgs) if not suc then m.showError(scp, err) return @@ -108,4 +126,10 @@ function m.init(scp) end) end +ws.watch(function (ev, uri) + if ev == 'startReload' then + initPlugin(uri) + end +end) + return m diff --git a/script/progress.lua b/script/progress.lua index b43ed05b..f1f371f5 100644 --- a/script/progress.lua +++ b/script/progress.lua @@ -11,10 +11,10 @@ local m = {} m.map = {} ---@class progress ----@field _uri uri +---@field _uri uri +---@field _token integer local mt = {} mt.__index = mt -mt._token = nil mt._title = nil mt._message = nil mt._removed = false diff --git a/script/proto/converter.lua b/script/proto/converter.lua index 9c75f056..3f5ddebc 100644 --- a/script/proto/converter.lua +++ b/script/proto/converter.lua @@ -13,7 +13,7 @@ local function rawPackPosition(uri, pos) if col > 0 then local state = files.getState(uri) local text = files.getText(uri) - if text then + if state and text then local lineOffset = state.lines[row] if lineOffset then local start = lineOffset diff --git a/script/proto/define.lua b/script/proto/define.lua index fb60c56c..ecdaf306 100644 --- a/script/proto/define.lua +++ b/script/proto/define.lua @@ -1,3 +1,5 @@ +local diag = require 'proto.diagnostic' + local m = {} --- 诊断等级 @@ -8,122 +10,21 @@ m.DiagnosticSeverity = { Hint = 4, } ----@alias DiagnosticDefaultSeverity ----| 'Hint' ----| 'Information' ----| 'Warning' ----| 'Error' - ---- 诊断类型与默认等级 ----@type table<string, DiagnosticDefaultSeverity> -m.DiagnosticDefaultSeverity = { - ['unused-local'] = 'Hint', - ['unused-function'] = 'Hint', - ['undefined-global'] = 'Warning', - ['undefined-field'] = 'Warning', - ['global-in-nil-env'] = 'Warning', - ['unused-label'] = 'Hint', - ['unused-vararg'] = 'Hint', - ['trailing-space'] = 'Hint', - ['redefined-local'] = 'Hint', - ['newline-call'] = 'Information', - ['newfield-call'] = 'Warning', - ['redundant-parameter'] = 'Warning', - ['missing-parameter'] = 'Warning', - ['redundant-return'] = 'Warning', - ['ambiguity-1'] = 'Warning', - ['lowercase-global'] = 'Information', - ['undefined-env-child'] = 'Information', - ['duplicate-index'] = 'Warning', - ['duplicate-set-field'] = 'Warning', - ['empty-block'] = 'Hint', - ['redundant-value'] = 'Warning', - ['code-after-break'] = 'Hint', - ['unbalanced-assignments'] = 'Warning', - ['close-non-object'] = 'Warning', - ['count-down-loop'] = 'Warning', - ['no-unknown'] = 'Information', - ['deprecated'] = 'Warning', - ['different-requires'] = 'Warning', - ['await-in-sync'] = 'Warning', - ['not-yieldable'] = 'Warning', - ['discard-returns'] = 'Warning', - ['need-check-nil'] = 'Warning', - ['type-check'] = 'Warning', - - ['duplicate-doc-alias'] = 'Warning', - ['undefined-doc-class'] = 'Warning', - ['undefined-doc-name'] = 'Warning', - ['circle-doc-class'] = 'Warning', - ['undefined-doc-param'] = 'Warning', - ['duplicate-doc-param'] = 'Warning', - ['doc-field-no-class'] = 'Warning', - ['duplicate-doc-field'] = 'Warning', - ['unknown-diag-code'] = 'Warning', - - ['codestyle-check'] = "Warning", +m.DiagnosticFileStatus = { + Any = 1, + Opened = 2, + None = 3, } ----@alias DiagnosticDefaultNeededFileStatus ----| 'Any' ----| 'Opened' ----| 'None' - --- 文件状态 -m.FileStatus = { - Any = 1, - Opened = 2, -} +--- 诊断类型与默认等级 +m.DiagnosticDefaultSeverity = diag.getDefaultSeverity() --- 诊断类型与需要的文件状态(可以控制只分析打开的文件、还是所有文件) ----@type table<string, DiagnosticDefaultNeededFileStatus> -m.DiagnosticDefaultNeededFileStatus = { - ['unused-local'] = 'Opened', - ['unused-function'] = 'Opened', - ['undefined-global'] = 'Any', - ['undefined-field'] = 'Opened', - ['global-in-nil-env'] = 'Any', - ['unused-label'] = 'Opened', - ['unused-vararg'] = 'Opened', - ['trailing-space'] = 'Opened', - ['redefined-local'] = 'Opened', - ['newline-call'] = 'Any', - ['newfield-call'] = 'Any', - ['redundant-parameter'] = 'Opened', - ['missing-parameter'] = 'Opened', - ['redundant-return'] = 'Opened', - ['ambiguity-1'] = 'Any', - ['lowercase-global'] = 'Any', - ['undefined-env-child'] = 'Any', - ['duplicate-index'] = 'Any', - ['duplicate-set-field'] = 'Any', - ['empty-block'] = 'Opened', - ['redundant-value'] = 'Opened', - ['code-after-break'] = 'Opened', - ['unbalanced-assignments'] = 'Any', - ['close-non-object'] = 'Any', - ['count-down-loop'] = 'Any', - ['no-unknown'] = 'None', - ['deprecated'] = 'Opened', - ['different-requires'] = 'Any', - ['await-in-sync'] = 'None', - ['not-yieldable'] = 'None', - ['discard-returns'] = 'Opened', - ['need-check-nil'] = 'Opened', - ['type-check'] = 'None', +m.DiagnosticDefaultNeededFileStatus = diag.getDefaultStatus() - ['duplicate-doc-alias'] = 'Any', - ['undefined-doc-class'] = 'Any', - ['undefined-doc-name'] = 'Any', - ['circle-doc-class'] = 'Any', - ['undefined-doc-param'] = 'Any', - ['duplicate-doc-param'] = 'Any', - ['doc-field-no-class'] = 'Any', - ['duplicate-doc-field'] = 'Any', - ['unknown-diag-code'] = 'Any', +m.DiagnosticDefaultGroupSeverity = diag.getGroupSeverity() - ['codestyle-check'] = 'None', -} +m.DiagnosticDefaultGroupFileStatus = diag.getGroupStatus() --- 诊断报告标签 m.DiagnosticTag = { @@ -260,24 +161,27 @@ m.TokenTypes = { ["number"] = 19, ["regexp"] = 20, ["operator"] = 21, + ["decorator"] = 22, } m.BuiltIn = { - ['basic'] = 'default', - ['bit'] = 'default', - ['bit32'] = 'default', - ['builtin'] = 'default', - ['coroutine'] = 'default', - ['debug'] = 'default', - ['ffi'] = 'default', - ['io'] = 'default', - ['jit'] = 'default', - ['math'] = 'default', - ['os'] = 'default', - ['package'] = 'default', - ['string'] = 'default', - ['table'] = 'default', - ['utf8'] = 'default', + ['basic'] = 'default', + ['bit'] = 'default', + ['bit32'] = 'default', + ['builtin'] = 'default', + ['coroutine'] = 'default', + ['debug'] = 'default', + ['ffi'] = 'default', + ['io'] = 'default', + ['jit'] = 'default', + ['math'] = 'default', + ['os'] = 'default', + ['package'] = 'default', + ['string'] = 'default', + ['table'] = 'default', + ['table.new'] = 'default', + ['table.clear'] = 'default', + ['utf8'] = 'default', } m.InlayHintKind = { diff --git a/script/proto/diagnostic.lua b/script/proto/diagnostic.lua new file mode 100644 index 00000000..9b0303cc --- /dev/null +++ b/script/proto/diagnostic.lua @@ -0,0 +1,267 @@ +local util = require 'utility' + +---@class proto.diagnostic +local m = {} + +---@alias DiagnosticSeverity +---| 'Hint' +---| 'Information' +---| 'Warning' +---| 'Error' + +---@alias DiagnosticNeededFileStatus +---| 'Any' +---| 'Opened' +---| 'None' + +---@class proto.diagnostic.info +---@field severity DiagnosticSeverity +---@field status DiagnosticNeededFileStatus +---@field group string + +m.diagnosticDatas = {} +m.diagnosticGroups = {} + +function m.register(names) + ---@param info proto.diagnostic.info + return function (info) + for _, name in ipairs(names) do + m.diagnosticDatas[name] = { + severity = info.severity, + status = info.status, + } + if not m.diagnosticGroups[info.group] then + m.diagnosticGroups[info.group] = {} + end + m.diagnosticGroups[info.group][name] = true + end + end +end + +m.register { + 'unused-local', + 'unused-function', + 'unused-label', + 'unused-vararg', + 'trailing-space', + 'redundant-return', + 'empty-block', + 'code-after-break', + 'unreachable-code', +} { + group = 'unused', + severity = 'Hint', + status = 'Opened', +} + +m.register { + 'redundant-value', + 'unbalanced-assignments', + 'redundant-parameter', + 'missing-parameter', + 'missing-return-value', + 'redundant-return-value', + 'missing-return', +} { + group = 'unbalanced', + severity = 'Warning', + status = 'Any', +} + +m.register { + 'need-check-nil', + 'undefined-field', + 'cast-local-type', + 'assign-type-mismatch', + 'param-type-mismatch', + 'cast-type-mismatch', + 'return-type-mismatch', +} { + group = 'type-check', + severity = 'Warning', + status = 'Opened', +} + +m.register { + 'duplicate-doc-alias', + 'undefined-doc-class', + 'undefined-doc-name', + 'circle-doc-class', + 'undefined-doc-param', + 'duplicate-doc-param', + 'doc-field-no-class', + 'duplicate-doc-field', + 'unknown-diag-code', + 'unknown-cast-variable', + 'unknown-operator', +} { + group = 'luadoc', + severity = 'Warning', + status = 'Any', +} + +m.register { + 'codestyle-check' +} { + group = 'codestyle', + severity = 'Warning', + status = 'None', +} + +m.register { + 'spell-check' +} { + group = 'codestyle', + severity = 'Information', + status = 'None', +} + +m.register { + 'newline-call', + 'newfield-call', + 'ambiguity-1', + 'count-down-loop', + 'different-requires', +} { + group = 'ambiguity', + severity = 'Warning', + status = 'Any', +} + +m.register { + 'await-in-sync', + 'not-yieldable', +} { + group = 'await', + severity = 'Warning', + status = 'None', +} + +m.register { + 'no-unknown', +} { + group = 'strong', + severity = 'Warning', + status = 'None', +} + +m.register { + 'redefined-local', +} { + group = 'redefined', + severity = 'Hint', + status = 'Opened', +} + +m.register { + 'undefined-global', + 'global-in-nil-env', +} { + group = 'global', + severity = 'Warning', + status = 'Any', +} + +m.register { + 'lowercase-global', + 'undefined-env-child', +} { + group = 'global', + severity = 'Information', + status = 'Any', +} + +m.register { + 'duplicate-index', + 'duplicate-set-field', +} { + group = 'duplicate', + severity = 'Warning', + status = 'Any', +} + +m.register { + 'close-non-object', + 'deprecated', + 'discard-returns', +} { + group = 'strict', + severity = 'Warning', + status = 'Any', +} + +---@return table<string, DiagnosticSeverity> +function m.getDefaultSeverity() + local severity = {} + for name, info in pairs(m.diagnosticDatas) do + severity[name] = info.severity + end + return severity +end + +---@return table<string, DiagnosticNeededFileStatus> +function m.getDefaultStatus() + local status = {} + for name, info in pairs(m.diagnosticDatas) do + status[name] = info.status + end + return status +end + +function m.getGroupSeverity() + local group = {} + for name in pairs(m.diagnosticGroups) do + group[name] = 'Fallback' + end + return group +end + +function m.getGroupStatus() + local group = {} + for name in pairs(m.diagnosticGroups) do + group[name] = 'Fallback' + end + return group +end + +---@param name string +---@return string[] +m.getGroups = util.cacheReturn(function (name) + local groups = {} + for groupName, nameMap in pairs(m.diagnosticGroups) do + if nameMap[name] then + groups[#groups+1] = groupName + end + end + table.sort(groups) + return groups +end) + +---@return table<string, true> +function m.getDiagAndErrNameMap() + if not m._diagAndErrNames then + local names = {} + for name in pairs(m.getDefaultSeverity()) do + names[name] = true + end + local path = package.searchpath('parser.compile', package.path) + if path then + local f = io.open(path) + if f then + for line in f:lines() do + local name = line:match([=[type%s*=%s*['"](%u[%u_]+%u)['"]]=]) + if name then + local id = name:lower():gsub('_', '-') + names[id] = true + end + end + f:close() + end + end + table.sort(names) + m._diagAndErrNames = names + end + return m._diagAndErrNames +end + +return m diff --git a/script/provider/build-meta.lua b/script/provider/build-meta.lua new file mode 100644 index 00000000..baabe39c --- /dev/null +++ b/script/provider/build-meta.lua @@ -0,0 +1,155 @@ +local fs = require 'bee.filesystem' +local config = require 'config' +local util = require 'utility' +local await = require 'await' +local progress = require 'progress' +local lang = require 'language' + +local m = {} + +---@class meta +---@field root string +---@field classes meta.class[] + +---@class meta.class +---@field name string +---@field comment string +---@field location string +---@field namespace string +---@field baseClass string +---@field attribute string +---@field integerface string[] +---@field fields meta.field[] +---@field methods meta.method[] + +---@class meta.field +---@field name string +---@field typeName string +---@field comment string +---@field location string + +---@class meta.method +---@field name string +---@field comment string +---@field location string +---@field isStatic boolean +---@field returnTypeName string +---@field params {name: string, typeName: string}[] + +---@param ... string +---@return string +local function mergeString(...) + local buf = {} + for i = 1, select('#', ...) do + local str = select(i, ...) + if str ~= '' then + buf[#buf+1] = str + end + end + return table.concat(buf, '.') +end + +local function addComments(lines, comment) + if comment == '' then + return + end + lines[#lines+1] = '--' + lines[#lines+1] = '--' .. comment:gsub('[\r\n]+$', ''):gsub('\n', '\n--') + lines[#lines+1] = '--' +end + +---@param lines string[] +---@param name string +---@param method meta.method +local function addMethod(lines, name, method) + if not method.name:match '^[%a_][%w_]*$' then + return + end + addComments(lines, method.comment) + lines[#lines+1] = ('---@source %s'):format(method.location:gsub('#', ':')) + local params = {} + for _, param in ipairs(method.params) do + lines[#lines+1] = ('---@param %s %s'):format(param.name, param.typeName) + params[#params+1] = param.name + end + if method.returnTypeName ~= '' + and method.returnTypeName ~= 'Void' then + lines[#lines+1] = ('---@return %s'):format(method.returnTypeName) + end + lines[#lines+1] = ('function %s%s%s(%s) end'):format( + name, + method.isStatic and ':' or '.', + method.name, + table.concat(params, ', ') + ) + lines[#lines+1] = '' +end + +---@param root string +---@param class meta.class +---@return string +local function buildText(root, class) + local lines = {} + + addComments(lines, class.comment) + lines[#lines+1] = ('---@source %s'):format(class.location:gsub('#', ':')) + if class.baseClass == '' then + lines[#lines+1] = ('---@class %s'):format(mergeString(class.namespace, class.name)) + else + lines[#lines+1] = ('---@class %s: %s'):format(mergeString(class.namespace, class.name), class.baseClass) + end + + for _, field in ipairs(class.fields) do + addComments(lines, field.comment) + lines[#lines+1] = ('---@source %s'):format(field.location:gsub('#', ':')) + lines[#lines+1] = ('---@field %s %s'):format(field.name, field.typeName) + end + + lines[#lines+1] = ('---@source %s'):format(class.location:gsub('#', ':')) + local name = mergeString(root, class.namespace, class.name) + lines[#lines+1] = ('%s = {}'):format(name) + lines[#lines+1] = '' + + for _, method in ipairs(class.methods) do + addMethod(lines, name, method) + end + + return table.concat(lines, '\n') +end + +local function buildRootText(api) + local lines = {} + + lines[#lines+1] = ('---@class %s'):format(api.root) + lines[#lines+1] = ('%s = {}'):format(api.root) + lines[#lines+1] = '' + return table.concat(lines, '\n') +end + +---@async +---@param path string +---@param api meta +function m.build(path, api) + + local files = util.multiTable(2, function () + return { '---@meta' } + end) + + files[api.root][#files[api.root]+1] = buildRootText(api) + + local proc <close> = progress.create(nil, lang.script.WINDOW_PROCESSING_BUILD_META, 0.5) + for i, class in ipairs(api.classes) do + local space = class.namespace ~= '' and class.namespace or api.root + proc:setMessage(space) + proc:setPercentage(i / #api.classes * 100) + local text = buildText(api.root, class) + files[space][#files[space]+1] = text + await.delay() + end + + for space, texts in pairs(files) do + util.saveFile(path .. '/' .. space .. '.lua', table.concat(texts, '\n\n')) + end +end + +return m diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua index 15b08d49..46ea600f 100644 --- a/script/provider/diagnostic.lua +++ b/script/provider/diagnostic.lua @@ -14,6 +14,10 @@ local loading = require 'workspace.loading' local scope = require 'workspace.scope' local time = require 'bee.time' local ltable = require 'linked-table' +local furi = require 'file-uri' +local json = require 'json' +local fw = require 'filewatch' +local vm = require 'vm.vm' ---@class diagnosticProvider local m = {} @@ -29,6 +33,9 @@ end local function buildSyntaxError(uri, err) local text = files.getText(uri) + if not text then + return + end local message = lang.script('PARSER_' .. err.type, err.info) if err.version then @@ -80,10 +87,14 @@ local function buildDiagnostic(uri, diag) relatedInformation = {} for _, rel in ipairs(diag.related) do local rtext = files.getText(rel.uri) + if not rtext then + goto CONTINUE + end relatedInformation[#relatedInformation+1] = { message = rel.message or rtext:sub(rel.start, rel.finish), location = converter.location(rel.uri, converter.packRange(rel.uri, rel.start, rel.finish)) } + ::CONTINUE:: end end @@ -136,9 +147,9 @@ local function mergeDiags(a, b, c) end -- enable `push`, disable `clear` -function m.clear(uri) +function m.clear(uri, force) await.close('diag:' .. uri) - if m.cache[uri] == nil then + if m.cache[uri] == nil and not force then return end m.cache[uri] = nil @@ -149,14 +160,27 @@ function m.clear(uri) log.info('clearDiagnostics', uri) end --- enable `push` and `send` -function m.clearCache(uri) - m.cache[uri] = false +function m.clearCacheExcept(uris) + local excepts = {} + for _, uri in ipairs(uris) do + excepts[uri] = true + end + for uri in pairs(m.cache) do + if not excepts[uri] then + m.cache[uri] = false + end + end end -function m.clearAll() - for luri in pairs(m.cache) do - m.clear(luri) +function m.clearAll(force) + if force then + for luri in files.eachFile() do + m.clear(luri, force) + end + else + for luri in pairs(m.cache) do + m.clear(luri) + end end end @@ -168,9 +192,11 @@ function m.syntaxErrors(uri, ast) local results = {} pcall(function () - local disables = config.get(uri, 'Lua.diagnostics.disable') + local disables = util.arrayToHash(config.get(uri, 'Lua.diagnostics.disable')) for _, err in ipairs(ast.errs) do - if not disables[err.type:lower():gsub('_', '-')] then + local id = err.type:lower():gsub('_', '-') + if not disables[id] + and not vm.isDiagDisabledAt(uri, err.start, id, true) then results[#results+1] = buildSyntaxError(uri, err) end end @@ -193,30 +219,45 @@ local function copyDiagsWithoutSyntax(diags) end ---@async -function m.doDiagnostic(uri, isScopeDiag) +---@param uri uri +---@return boolean +local function isValid(uri) if not config.get(uri, 'Lua.diagnostics.enable') then - return + return false end if files.isLibrary(uri, true) then local status = config.get(uri, 'Lua.diagnostics.libraryFiles') if status == 'Disable' then - return + return false elseif status == 'Opened' then if not files.isOpen(uri) then - return + return false end end end if ws.isIgnored(uri) then local status = config.get(uri, 'Lua.diagnostics.ignoredFiles') if status == 'Disable' then - return + return false elseif status == 'Opened' then if not files.isOpen(uri) then - return + return false end end end + local scheme = furi.split(uri) + local disableScheme = config.get(uri, 'Lua.diagnostics.disableScheme') + if util.arrayHas(disableScheme, scheme) then + return false + end + return true +end + +---@async +function m.doDiagnostic(uri, isScopeDiag) + if not isValid(uri) then + return + end await.delay() @@ -267,7 +308,7 @@ function m.doDiagnostic(uri, isScopeDiag) xpcall(core, log.error, uri, isScopeDiag, function (result) diags[#diags+1] = buildDiagnostic(uri, result) - if not isScopeDiag and time.time() - lastPushClock >= 200 then + if not isScopeDiag and time.time() - lastPushClock >= 500 then lastPushClock = time.time() pushResult() end @@ -287,20 +328,78 @@ function m.doDiagnostic(uri, isScopeDiag) pushResult() end +---@param uri uri +function m.resendDiagnostic(uri) + local full = m.cache[uri] + if not full then + return + end + + local version = files.getVersion(uri) + + proto.notify('textDocument/publishDiagnostics', { + uri = uri, + version = version, + diagnostics = full, + }) + log.debug('publishDiagnostics', uri, #full) +end + +---@async +---@return table|nil result +---@return boolean? unchanged +function m.pullDiagnostic(uri, isScopeDiag) + if not isValid(uri) then + return nil, util.equal(m.cache[uri], nil) + end + + await.delay() + + local state = files.getState(uri) + if not state then + return nil, util.equal(m.cache[uri], nil) + end + + local prog <close> = progress.create(uri, lang.script.WINDOW_DIAGNOSING, 0.5) + prog:setMessage(ws.getRelativePath(uri)) + + local syntax = m.syntaxErrors(uri, state) + local diags = {} + + xpcall(core, log.error, uri, isScopeDiag, function (result) + diags[#diags+1] = buildDiagnostic(uri, result) + end) + + local full = mergeDiags(syntax, diags) + + if util.equal(m.cache[uri], full) then + return full, true + end + + m.cache[uri] = full + + return full +end + +---@param uri uri function m.refresh(uri) if not ws.isReady(uri) then return end + + await.close('diag:' .. uri) + ---@async + await.call(function () + await.setID('diag:' .. uri) + await.sleep(0.1) + xpcall(m.doDiagnostic, log.error, uri) + end) + local scp = scope.getScope(uri) local scopeID = 'diagnosticsScope:' .. scp:getName() - await.close('diag:' .. uri) await.close(scopeID) - await.call(function () ---@async - if uri then - await.setID('diag:' .. uri) - await.sleep(0.1) - xpcall(m.doDiagnostic, log.error, uri) - end + ---@async + await.call(function () local delay = config.get(uri, 'Lua.diagnostics.workspaceDelay') / 1000 if delay < 0 then return @@ -359,7 +458,7 @@ local function askForDisable(uri) end ---@async -function m.awaitDiagnosticsScope(suri) +function m.awaitDiagnosticsScope(suri, callback) local scp = scope.getScope(suri) while loading.count() > 0 do await.sleep(1.0) @@ -393,7 +492,7 @@ function m.awaitDiagnosticsScope(suri) i = i + 1 bar:setMessage(('%d/%d'):format(i, #uris)) bar:setPercentage(i / #uris * 100) - xpcall(m.doDiagnostic, log.error, uri, true) + callback(uri) await.delay() if cancelled then log.info('Break workspace diagnostics') @@ -416,13 +515,66 @@ function m.diagnosticsScope(uri, force) local id = 'diagnosticsScope:' .. scp:getName() await.close(id) await.call(function () ---@async - m.awaitDiagnosticsScope(uri) + await.sleep(0.0) + m.awaitDiagnosticsScope(uri, function (fileUri) + xpcall(m.doDiagnostic, log.error, fileUri, true) + end) end, id) end +---@async +function m.pullDiagnosticScope(callback) + local processing = 0 + + for _, scp in ipairs(scope.folders) do + if ws.isReady(scp.uri) + and config.get(scp.uri, 'Lua.diagnostics.enable') then + local id = 'diagnosticsScope:' .. scp:getName() + await.close(id) + await.call(function () ---@async + processing = processing + 1 + local _ <close> = util.defer(function () + processing = processing - 1 + end) + + local delay = config.get(scp.uri, 'Lua.diagnostics.workspaceDelay') / 1000 + if delay < 0 then + return + end + print(delay) + await.sleep(math.max(delay, 0.2)) + print('start') + + m.awaitDiagnosticsScope(scp.uri, function (fileUri) + local suc, result, unchanged = xpcall(m.pullDiagnostic, log.error, fileUri, true) + if suc then + callback { + uri = fileUri, + result = result, + unchanged = unchanged, + version = files.getVersion(fileUri), + } + end + end) + end, id) + end + end + + -- sleep for ever + while true do + await.sleep(1.0) + end +end + +function m.refreshClient() + log.debug('Refresh client diagnostics') + proto.request('workspace/diagnostic/refresh', json.null) +end + ws.watch(function (ev, uri) if ev == 'reload' then m.diagnosticsScope(uri) + m.refreshClient() end end) @@ -434,7 +586,7 @@ files.watch(function (ev, uri) ---@async m.refresh(uri) elseif ev == 'open' then if ws.isReady(uri) then - m.clearCache(uri) + m.resendDiagnostic(uri) xpcall(m.doDiagnostic, log.error, uri) end elseif ev == 'close' then @@ -446,9 +598,20 @@ files.watch(function (ev, uri) ---@async end) config.watch(function (uri, key, value, oldValue) - if key:find 'Lua.diagnostics' then + if util.stringStartWith(key, 'Lua.diagnostics') + or util.stringStartWith(key, 'Lua.spell') then if value ~= oldValue then m.diagnosticsScope(uri) + m.refreshClient() + end + end +end) + +fw.event(function (ev, path) + if util.stringEndWith(path, '.editorconfig') then + for _, scp in ipairs(ws.folders) do + m.diagnosticsScope(scp.uri) + m.refreshClient() end end end) diff --git a/script/provider/formatting.lua b/script/provider/formatting.lua index 73b6608d..4ec5545a 100644 --- a/script/provider/formatting.lua +++ b/script/provider/formatting.lua @@ -19,7 +19,7 @@ local updateType = { Deleted = 3, } -fw.event(function (ev, path) +fw.event(function(ev, path) if util.stringEndWith(path, '.editorconfig') then for uri, fsPath in pairs(loadedUris) do loadedUris[uri] = nil @@ -30,17 +30,9 @@ fw.event(function (ev, path) end end end - for _, scp in ipairs(ws.folders) do - diagnostics.diagnosticsScope(scp.uri) - end end end) -config.watch(function (uri, key, value) - if key == "Lua.format.defaultConfig" then - codeFormat.set_default_config(value) - end -end) local m = {} @@ -51,6 +43,7 @@ function m.updateConfig(uri) if not m.loadedDefaultConfig then m.loadedDefaultConfig = true codeFormat.set_default_config(config.get(uri, 'Lua.format.defaultConfig')) + m.updateNonStandardSymbols(config.get(nil, 'Lua.runtime.nonstandardSymbol')) end local currentUri = uri @@ -64,7 +57,7 @@ function m.updateConfig(uri) local currentPath = furi.decode(currentUri) local editorConfigFSPath = fs.path(currentPath) / '.editorconfig' if fs.exists(editorConfigFSPath) then - loadedUris[uri] = editorConfigFSPath + loadedUris[currentUri] = editorConfigFSPath local status, err = codeFormat.update_config(updateType.Created, currentUri, editorConfigFSPath:string()) if not status and err then log.error(err) @@ -83,4 +76,30 @@ function m.updateConfig(uri) end end +---@param symbols? string[] +function m.updateNonStandardSymbols(symbols) + if symbols == nil then + return + end + + local eqTokens = {} + for _, token in ipairs(symbols) do + if token:find("=") and token ~= "!=" then + table.insert(eqTokens, token) + end + end + + if #eqTokens ~= 0 then + codeFormat.set_nonstandard_symbol('=', eqTokens) + end +end + +config.watch(function(uri, key, value) + if key == "Lua.format.defaultConfig" then + codeFormat.set_default_config(value) + elseif key == "Lua.runtime.nonstandardSymbol" then + m.updateNonStandardSymbols(value) + end +end) + return m diff --git a/script/provider/markdown.lua b/script/provider/markdown.lua index 6b7d24c8..50716073 100644 --- a/script/provider/markdown.lua +++ b/script/provider/markdown.lua @@ -10,7 +10,7 @@ function mt:__tostring() end ---@param language string ----@param text string|markdown +---@param text? string|markdown function mt:add(language, text) if not text then return self @@ -40,7 +40,16 @@ function mt:splitLine() return self end -function mt:string() +function mt:emptyLine() + self._cacheResult = nil + self[#self+1] = { + type = 'emptyline', + } + return self +end + +---@return string +function mt:string(nl) if self._cacheResult then return self._cacheResult end @@ -59,6 +68,11 @@ function mt:string() lines[#lines+1] = '' lines[#lines+1] = '---' end + elseif obj.type == 'emptyline' then + if #lines > 0 + and lines[#lines] ~= '' then + lines[#lines+1] = '' + end elseif obj.type == 'markdown' then concat(obj.markdown) else @@ -80,6 +94,10 @@ function mt:string() if lines[#lines] ~= '' then lines[#lines+1] = '' end + elseif last == '---' then + if lines[#lines] ~= '' then + lines[#lines+1] = '' + end end end lines[#lines+1] = obj.text @@ -101,7 +119,7 @@ function mt:string() end end - local result = table.concat(lines, '\n') + local result = table.concat(lines, nl or '\n') self._cacheResult = result return result end diff --git a/script/provider/provider.lua b/script/provider/provider.lua index 3d012757..018db0c3 100644 --- a/script/provider/provider.lua +++ b/script/provider/provider.lua @@ -21,6 +21,8 @@ local furi = require 'file-uri' local inspect = require 'inspect' local markdown = require 'provider.markdown' local guide = require 'parser.guide' +local fs = require 'bee.filesystem' +local jumpSource = require 'core.jump-source' ---@async local function updateConfig(uri) @@ -28,7 +30,7 @@ local function updateConfig(uri) local specified = cfgLoader.loadLocalConfig(uri, CONFIGPATH) if specified then log.info('Load config from specified', CONFIGPATH) - log.debug(inspect(specified)) + log.info(inspect(specified)) -- watch directory filewatch.watch(workspace.getAbsolutePath(uri, CONFIGPATH):gsub('[^/\\]+$', '')) config.update(scope.override, specified) @@ -38,14 +40,14 @@ local function updateConfig(uri) local clientConfig = cfgLoader.loadClientConfig(folder.uri) if clientConfig then log.info('Load config from client', folder.uri) - log.debug(inspect(clientConfig)) + log.info(inspect(clientConfig)) end local rc = cfgLoader.loadRCConfig(folder.uri, '.luarc.json') or cfgLoader.loadRCConfig(folder.uri, '.luarc.jsonc') if rc then log.info('Load config from .luarc.json/.luarc.jsonc', folder.uri) - log.debug(inspect(rc)) + log.info(inspect(rc)) end config.update(folder, clientConfig, rc) @@ -53,7 +55,7 @@ local function updateConfig(uri) local global = cfgLoader.loadClientConfig() log.info('Load config from client', 'fallback') - log.debug(inspect(global)) + log.info(inspect(global)) config.update(scope.fallback, global) end @@ -149,7 +151,6 @@ m.register 'initialized'{ }) end client.setReady() - library.init() workspace.init() return true end @@ -234,11 +235,35 @@ m.register 'workspace/didRenameFiles' { end } +m.register 'workspace/didChangeWorkspaceFolders' { + capability = { + workspace = { + workspaceFolders = { + supported = true, + changeNotifications = true, + }, + }, + }, + ---@async + function (params) + log.debug('workspace/didChangeWorkspaceFolders', inspect(params)) + for _, folder in ipairs(params.event.added) do + workspace.create(folder.uri) + updateConfig() + workspace.reload(scope.getScope(folder.uri)) + end + for _, folder in ipairs(params.event.removed) do + workspace.remove(folder.uri) + end + end +} + m.register 'textDocument/didOpen' { function (params) - local doc = params.textDocument - local scheme = furi.split(doc.uri) - if scheme ~= 'file' then + local doc = params.textDocument + local scheme = furi.split(doc.uri) + local supports = config.get(doc.uri, 'Lua.workspace.supportScheme') + if not util.arrayHas(supports, scheme) then return end local uri = files.getRealUri(doc.uri) @@ -264,15 +289,21 @@ m.register 'textDocument/didClose' { } m.register 'textDocument/didChange' { + ---@async function (params) - local doc = params.textDocument - local scheme = furi.split(doc.uri) - if scheme ~= 'file' then + local doc = params.textDocument + local scheme = furi.split(doc.uri) + local supports = config.get(doc.uri, 'Lua.workspace.supportScheme') + if not util.arrayHas(supports, scheme) then return end local changes = params.contentChanges local uri = files.getRealUri(doc.uri) - local text = files.getOriginText(uri) or '' + local text = files.getOriginText(uri) + if not text then + files.setText(uri, pub.awaitTask('loadFile', furi.decode(uri)), false) + return + end local rows = files.getCachedRows(uri) text, rows = tm(text, rows, changes) files.setText(uri, text, true, function (file) @@ -310,7 +341,7 @@ m.register 'textDocument/hover' { end local pos = converter.unpackPosition(uri, params.position) local hover, source = core.byUri(uri, pos) - if not hover then + if not hover or not source then return nil end return { @@ -346,18 +377,16 @@ m.register 'textDocument/definition' { for i, info in ipairs(result) do local targetUri = info.uri if targetUri then - if files.exists(targetUri) then - if client.getAbility 'textDocument.definition.linkSupport' then - response[i] = converter.locationLink(targetUri - , converter.packRange(targetUri, info.target.start, info.target.finish) - , converter.packRange(targetUri, info.target.start, info.target.finish) - , converter.packRange(uri, info.source.start, info.source.finish) - ) - else - response[i] = converter.location(targetUri - , converter.packRange(targetUri, info.target.start, info.target.finish) - ) - end + if client.getAbility 'textDocument.definition.linkSupport' then + response[i] = converter.locationLink(targetUri + , converter.packRange(targetUri, info.target.start, info.target.finish) + , converter.packRange(targetUri, info.target.start, info.target.finish) + , converter.packRange(uri, info.source.start, info.source.finish) + ) + else + response[i] = converter.location(targetUri + , converter.packRange(targetUri, info.target.start, info.target.finish) + ) end end end @@ -388,18 +417,16 @@ m.register 'textDocument/typeDefinition' { for i, info in ipairs(result) do local targetUri = info.uri if targetUri then - if files.exists(targetUri) then - if client.getAbility 'textDocument.typeDefinition.linkSupport' then - response[i] = converter.locationLink(targetUri - , converter.packRange(targetUri, info.target.start, info.target.finish) - , converter.packRange(targetUri, info.target.start, info.target.finish) - , converter.packRange(uri, info.source.start, info.source.finish) - ) - else - response[i] = converter.location(targetUri - , converter.packRange(targetUri, info.target.start, info.target.finish) - ) - end + if client.getAbility 'textDocument.typeDefinition.linkSupport' then + response[i] = converter.locationLink(targetUri + , converter.packRange(targetUri, info.target.start, info.target.finish) + , converter.packRange(targetUri, info.target.start, info.target.finish) + , converter.packRange(uri, info.source.start, info.source.finish) + ) + else + response[i] = converter.location(targetUri + , converter.packRange(targetUri, info.target.start, info.target.finish) + ) end end end @@ -902,29 +929,35 @@ local function toArray(map) return array end -m.register 'textDocument/semanticTokens/full' { - capability = { - semanticTokensProvider = { - legend = { - tokenTypes = toArray(define.TokenTypes), - tokenModifiers = toArray(define.TokenModifiers), - }, - full = true, - }, - }, - ---@async - function (params) - log.debug('textDocument/semanticTokens/full') - local uri = files.getRealUri(params.textDocument.uri) - workspace.awaitReady(uri) - local _ <close> = progress.create(uri, lang.script.WINDOW_PROCESSING_SEMANTIC_FULL, 0.5) - local core = require 'core.semantic-tokens' - local results = core(uri, 0, math.huge) - return { - data = results - } +client.event(function (ev) + if ev == 'init' then + if not client.isVSCode() then + m.register 'textDocument/semanticTokens/full' { + capability = { + semanticTokensProvider = { + legend = { + tokenTypes = toArray(define.TokenTypes), + tokenModifiers = toArray(define.TokenModifiers), + }, + full = true, + }, + }, + ---@async + function (params) + log.debug('textDocument/semanticTokens/full') + local uri = files.getRealUri(params.textDocument.uri) + workspace.awaitReady(uri) + local _ <close> = progress.create(uri, lang.script.WINDOW_PROCESSING_SEMANTIC_FULL, 0.5) + local core = require 'core.semantic-tokens' + local results = core(uri, 0, math.huge) + return { + data = results + } + end + } + end end -} +end) m.register 'textDocument/semanticTokens/range' { capability = { @@ -988,6 +1021,40 @@ m.register 'textDocument/foldingRange' { end } +m.register 'textDocument/documentColor' { + capability = { + colorProvider = true + }, + ---@async + function (params) + local color = require 'core.color' + local uri = files.getRealUri(params.textDocument.uri) + workspace.awaitReady(uri) + if not files.exists(uri) then + return nil + end + local colors = color.colors(uri) + if not colors then + return nil + end + local results = {} + for _, colorValue in ipairs(colors) do + results[#results+1] = { + range = converter.packRange(uri, colorValue.start, colorValue.finish), + color = colorValue.color + } + end + return results + end +} + +m.register 'textDocument/colorPresentation' { + function (params) + local color = (require 'core.color').colorToText(params.color) + return {{label = color}} + end +} + m.register 'window/workDoneProgress/cancel' { function (params) log.debug('close proto(cancel):', params.token) @@ -1001,6 +1068,7 @@ m.register '$/status/click' { local titleDiagnostic = lang.script.WINDOW_LUA_STATUS_DIAGNOSIS_TITLE local result = client.awaitRequestMessage('Info', lang.script.WINDOW_LUA_STATUS_DIAGNOSIS_MSG, { titleDiagnostic, + DEVELOP and 'Restart Server', }) if not result then return @@ -1010,6 +1078,10 @@ m.register '$/status/click' { for _, scp in ipairs(workspace.folders) do diagnostic.diagnosticsScope(scp.uri, true) end + elseif result == 'Restart Server' then + local diag = require 'provider.diagnostic' + diag.clearAll(true) + os.exit(0, true) end end } @@ -1210,6 +1282,123 @@ m.register 'inlayHint/resolve' { end } +m.register 'textDocument/diagnostic' { + preview = true, + capability = { + diagnosticProvider = { + identifier = 'identifier', + interFileDependencies = true, + workspaceDiagnostics = false, + } + }, + ---@async + function (params) + local uri = files.getRealUri(params.textDocument.uri) + workspace.awaitReady(uri) + local core = require 'provider.diagnostic' + -- TODO: do some trick + core.doDiagnostic(uri) + + return { + kind = 'unchanged', + resultId = uri, + } + + --if not params.previousResultId then + -- core.clearCache(uri) + --end + --local results, unchanged = core.pullDiagnostic(uri, false) + --if unchanged then + -- return { + -- kind = 'unchanged', + -- resultId = uri, + -- } + --else + -- return { + -- kind = 'full', + -- resultId = uri, + -- items = results or {}, + -- } + --end + end +} + +m.register 'workspace/diagnostic' { + --preview = true, + --capability = { + -- diagnosticProvider = { + -- workspaceDiagnostics = false, + -- } + --}, + ---@async + function (params) + local core = require 'provider.diagnostic' + local excepts = {} + for _, id in ipairs(params.previousResultIds) do + excepts[#excepts+1] = id.value + end + core.clearCacheExcept(excepts) + local function convertItem(result) + if result.unchanged then + return { + kind = 'unchanged', + resultId = result.uri, + uri = result.uri, + version = result.version, + } + else + return { + kind = 'full', + resultId = result.uri, + items = result.result or {}, + uri = result.uri, + version = result.version, + } + end + end + core.pullDiagnosticScope(function (result) + proto.notify('$/progress', { + token = params.partialResultToken, + value = { + items = { + convertItem(result) + } + } + }) + end) + return { items = {} } + end +} + +m.register '$/api/report' { + ---@async + function (params) + local buildMeta = require 'provider.build-meta' + local SDBMHash = require 'SDBMHash' + await.close 'api/report' + await.setID 'api/report' + local name = params.name or 'default' + local uri = workspace.getFirstScope().uri + local hash = uri and ('%08x'):format(SDBMHash():hash(uri)) + local encoding = config.get(nil, 'Lua.runtime.fileEncoding') + local nameBuf = {} + nameBuf[#nameBuf+1] = name + nameBuf[#nameBuf+1] = hash + nameBuf[#nameBuf+1] = encoding + local fileDir = METAPATH .. '/' .. table.concat(nameBuf, ' ') + fs.create_directories(fs.path(fileDir)) + buildMeta.build(fileDir, params) + client.setConfig { + { + key = 'Lua.workspace.library', + action = 'add', + value = fileDir, + uri = uri, + } + } + end +} + local function refreshStatusBar() local valid = config.get(nil, 'Lua.window.statusBar') for _, scp in ipairs(workspace.folders) do diff --git a/script/provider/spell.lua b/script/provider/spell.lua new file mode 100644 index 00000000..6647bbad --- /dev/null +++ b/script/provider/spell.lua @@ -0,0 +1,53 @@ +local suc, codeFormat = pcall(require, 'code_format') +if not suc then + return +end + +local fs = require 'bee.filesystem' +local config = require 'config' +local diagnostics = require 'provider.diagnostic' +local pformatting = require 'provider.formatting' +local util = require 'utility' + +local m = {} + +function m.loadDictionaryFromFile(filePath) + return codeFormat.spell_load_dictionary_from_path(filePath) +end + +function m.loadDictionaryFromBuffer(buffer) + return codeFormat.spell_load_dictionary_from_buffer(buffer) +end + +function m.addWord(word) + return codeFormat.spell_load_dictionary_from_buffer(word) +end + +function m.spellCheck(uri, text) + if not m._dictionaryLoaded then + m.initDictionary() + m._dictionaryLoaded = true + end + + local tempDict = config.get(uri, 'Lua.spell.dict') + + return codeFormat.spell_analysis(uri, text, tempDict) +end + +function m.getSpellSuggest(word) + local status, result = codeFormat.spell_suggest(word) + if status then + return result + end +end + +function m.initDictionary() + local basicDictionary = fs.path(METAPATH) / "spell/dictionary.txt" + local luaDictionary = fs.path(METAPATH) / "spell/lua_dict.txt" + + m.loadDictionaryFromFile(basicDictionary:string()) + m.loadDictionaryFromFile(luaDictionary:string()) + pformatting.updateNonStandardSymbols(config.get(nil, "Lua.runtime.nonstandardSymbol")) +end + +return m diff --git a/script/pub/pub.lua b/script/pub/pub.lua index 47591ee6..1e9b6c8f 100644 --- a/script/pub/pub.lua +++ b/script/pub/pub.lua @@ -136,8 +136,6 @@ function m.task(name, params, callback) end --- 接收反馈 ---- 返回接收到的反馈数量 ----@return integer function m.recieve(block) if block then local id, name, result = waiter:bpop() diff --git a/script/service/service.lua b/script/service/service.lua index 26790c63..07612d7b 100644 --- a/script/service/service.lua +++ b/script/service/service.lua @@ -235,7 +235,7 @@ end function m.testVersion() local stack = debug.setcstacklimit(200) debug.setcstacklimit(stack + 1) - if debug.setcstacklimit(stack) == stack + 1 then + if type(stack) == 'number' and debug.setcstacklimit(stack) == stack + 1 then proto.notify('window/showMessage', { type = 2, message = 'It seems to be running in Lua 5.4.0 or Lua 5.4.1 . Please upgrade to Lua 5.4.2 or above. Otherwise, it may encounter weird "C stack overflow", resulting in failure to work properly', diff --git a/script/service/telemetry.lua b/script/service/telemetry.lua index 2e52def2..211dff0e 100644 --- a/script/service/telemetry.lua +++ b/script/service/telemetry.lua @@ -90,6 +90,7 @@ local function pushErrorLog(link) )) end +---@type boolean? local isValid = false timer.wait(5, function () diff --git a/script/utility.lua b/script/utility.lua index 47b0c8d8..7f9d559b 100644 --- a/script/utility.lua +++ b/script/utility.lua @@ -8,6 +8,7 @@ local ipairs = ipairs local next = next local rawset = rawset local move = table.move +local tableRemove = table.remove local setmetatable = debug.setmetatable local mathType = math.type local mathCeil = math.ceil @@ -157,8 +158,10 @@ function m.dump(tbl, option) local tp = type(value) local format = option['format'] and option['format'][key] if format then - lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, format(value, unpack, deep+1, stack)) - elseif tp == 'table' then + value = format(value, unpack, deep+1, stack) + tp = type(value) + end + if tp == 'table' then if mark[value] and mark[value] > 0 then lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, option['loop'] or '"<Loop>"') elseif deep >= (option['deep'] or mathHuge) then @@ -272,7 +275,7 @@ local function sortTable(tbl) end --- 创建一个有序表 ----@param tbl table {optional = 'self'} +---@param tbl? table ---@return table function m.container(tbl) return sortTable(tbl) @@ -287,6 +290,9 @@ function m.loadFile(path, keepBom) end local text = f:read 'a' f:close() + if not text then + return nil + end if not keepBom then if text:sub(1, 3) == '\xEF\xBB\xBF' then return text:sub(4) @@ -521,6 +527,14 @@ function m.revertTable(t) return t end +function m.revertMap(t) + local nt = {} + for k, v in pairs(t) do + nt[v] = k + end + return nt +end + function m.randomSortTable(t, max) local len = #t if len <= 1 then @@ -567,7 +581,7 @@ end ---遍历文本的每一行 ---@param text string ---@param keepNL? boolean # 保留换行符 ----@return fun(text:string):string, integer +---@return fun():string, integer function m.eachLine(text, keepNL) local offset = 1 local lineCount = 0 @@ -649,9 +663,6 @@ function m.trim(str, mode) end function m.expandPath(path) - if type(path) ~= 'string' then - return nil - end if path:sub(1, 1) == '~' then local home = getenv('HOME') if not home then -- has to be Windows @@ -791,8 +802,60 @@ function m.multiTable(count, default) return current end +---@param t table +---@param sorter boolean|function +function m.getTableKeys(t, sorter) + local keys = {} + for k in pairs(t) do + keys[#keys+1] = k + end + if sorter == true then + tableSort(keys) + elseif type(sorter) == 'function' then + tableSort(keys, sorter) + end + return keys +end + +function m.arrayHas(array, value) + for i = 1, #array do + if array[i] == value then + return true + end + end + return false +end + +function m.arrayInsert(array, value) + if not m.arrayHas(array, value) then + array[#array+1] = value + end +end + +function m.arrayRemove(array, value) + for i = 1, #array do + if array[i] == value then + tableRemove(array, i) + return + end + end +end + m.MODE_K = { __mode = 'k' } m.MODE_V = { __mode = 'v' } m.MODE_KV = { __mode = 'kv' } +---@generic T: fun(param: any):any +---@param func T +---@return T +function m.cacheReturn(func) + local cache = {} + return function (param) + if cache[param] == nil then + cache[param] = func(param) + end + return cache[param] + end +end + return m diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 75620d19..60c7243e 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -6,10 +6,149 @@ local files = require 'files' ---@class vm local vm = require 'vm.vm' +local LOCK = {} + ---@class parser.object ---@field _compiledNodes boolean ---@field _node vm.node ---@field _globalBase table +---@field cindex integer +---@field func parser.object +---@field operators? parser.object[] + +-- 该函数有副作用,会给source绑定node! +---@param source parser.object +---@return boolean +local function bindDocs(source) + local docs = source.bindDocs + if not docs then + return false + end + for i = #docs, 1, -1 do + local doc = docs[i] + if doc.type == 'doc.type' then + vm.setNode(source, vm.compileNode(doc)) + return true + end + if doc.type == 'doc.class' then + vm.setNode(source, vm.compileNode(doc)) + return true + end + if doc.type == 'doc.param' then + local node = vm.compileNode(doc) + if doc.optional then + node:addOptional() + end + vm.setNode(source, node) + return true + end + if doc.type == 'doc.module' then + local name = doc.module + if not name then + return true + end + local uri = rpath.findUrisByRequireName(guide.getUri(source), name)[1] + if not uri then + return true + end + local state = files.getState(uri) + local ast = state and state.ast + if not ast then + return true + end + vm.setNode(source, vm.compileNode(ast)) + return true + end + if doc.type == 'doc.overload' then + vm.setNode(source, vm.compileNode(doc)) + end + end + return false +end + +---@param source parser.object +---@param key any +---@param ref boolean +---@param pushResult fun(res: parser.object, markDoc?: boolean) +local function searchFieldByLocalID(source, key, ref, pushResult) + local fields + if key then + fields = vm.getLocalSourcesSets(source, key) + if ref then + local gets = vm.getLocalSourcesGets(source, key) + if gets then + fields = fields or {} + for _, src in ipairs(gets) do + fields[#fields+1] = src + end + end + end + else + fields = vm.getLocalFields(source, false) + end + if not fields then + return + end + local hasMarkDoc = {} + for _, src in ipairs(fields) do + if src.bindDocs then + if bindDocs(src) then + local skey = guide.getKeyName(src) + if skey then + hasMarkDoc[skey] = true + end + pushResult(src, true) + end + end + end + for _, src in ipairs(fields) do + local skey = guide.getKeyName(src) + if not hasMarkDoc[skey] then + pushResult(src) + end + end +end + +---@param suri uri +---@param source parser.object +---@param key any +---@param ref boolean +---@param pushResult fun(res: parser.object, markDoc?: boolean) +local function searchFieldByGlobalID(suri, source, key, ref, pushResult) + local node = source._globalNode + if not node then + return + end + if node.cate == 'variable' then + if key then + if type(key) ~= 'string' then + return + end + local global = vm.getGlobal('variable', node.name, key) + if global then + for _, set in ipairs(global:getSets(suri)) do + pushResult(set) + end + for _, get in ipairs(global:getGets(suri)) do + pushResult(get) + end + end + else + local globals = vm.getGlobalFields('variable', node.name) + for _, global in ipairs(globals) do + for _, set in ipairs(global:getSets(suri)) do + pushResult(set) + end + for _, get in ipairs(global:getGets(suri)) do + pushResult(get) + end + end + end + end + if node.cate == 'type' then + vm.getClassFields(suri, node, key, ref, pushResult) + end +end local searchFieldSwitch = util.switch() : case 'table' @@ -47,6 +186,7 @@ local searchFieldSwitch = util.switch() end end) : case 'string' + : case 'doc.type.string' : call(function (suri, source, key, ref, pushResult) -- change to `string: stringlib` ? local stringlib = vm.getGlobal('type', 'stringlib') @@ -54,23 +194,6 @@ local searchFieldSwitch = util.switch() vm.getClassFields(suri, stringlib, key, ref, pushResult) end end) - : case 'local' - : case 'self' - : call(function (suri, node, key, ref, pushResult) - local fields - if key then - fields = vm.getLocalSources(node, key) - else - fields = vm.getLocalFields(node) - end - if fields then - for _, src in ipairs(fields) do - if ref or guide.isSet(src) then - pushResult(src) - end - end - end - end) : case 'doc.type.array' : call(function (suri, source, key, ref, pushResult) if type(key) == 'number' then @@ -78,8 +201,13 @@ local searchFieldSwitch = util.switch() or not math.tointeger(key) then return end + pushResult(source.node) + end + if type(key) == 'table' then + if vm.isSubType(suri, key, 'integer') then + pushResult(source.node) + end end - pushResult(source.node) end) : case 'doc.type.table' : call(function (suri, source, key, ref, pushResult) @@ -144,42 +272,15 @@ local searchFieldSwitch = util.switch() end end) : default(function (suri, source, key, ref, pushResult) - local node = source._globalNode - if not node then - return - end - if node.cate == 'variable' then - if key then - if type(key) ~= 'string' then - return - end - local global = vm.getGlobal('variable', node.name, key) - if global then - for _, set in ipairs(global:getSets(suri)) do - pushResult(set) - end - for _, get in ipairs(global:getGets(suri)) do - pushResult(get) - end - end - else - local globals = vm.getGlobalFields('variable', node.name) - for _, global in ipairs(globals) do - for _, set in ipairs(global:getSets(suri)) do - pushResult(set) - end - for _, get in ipairs(global:getGets(suri)) do - pushResult(get) - end - end - end - end - if node.cate == 'type' then - vm.getClassFields(suri, node, key, ref, pushResult) - end + searchFieldByLocalID(source, key, ref, pushResult) + searchFieldByGlobalID(suri, source, key, ref, pushResult) end) - +---@param suri uri +---@param object vm.global +---@param key string|vm.global +---@param ref boolean +---@param pushResult fun(field: vm.object, isMark?: boolean) function vm.getClassFields(suri, object, key, ref, pushResult) local mark = {} @@ -194,6 +295,14 @@ function vm.getClassFields(suri, object, key, ref, pushResult) if set.type == 'doc.class' then -- check ---@field local hasFounded = {} + + local function copyToSearched() + for fieldKey in pairs(hasFounded) do + searchedFields[fieldKey] = true + hasFounded[fieldKey] = nil + end + end + for _, field in ipairs(set.fields) do local fieldKey = guide.getKeyName(field) if fieldKey then @@ -201,12 +310,11 @@ function vm.getClassFields(suri, object, key, ref, pushResult) if key == nil or fieldKey == key then if not searchedFields[fieldKey] then - pushResult(field) + pushResult(field, true) hasFounded[fieldKey] = true end end - end - if not hasFounded[fieldKey] then + elseif key and not hasFounded[key] then local keyType = type(key) if keyType == 'table' then -- ---@field [integer] boolean -> class[integer] @@ -214,7 +322,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult) if vm.isSubType(suri, key.name, fieldNode) then local nkey = '|' .. key.name if not searchedFields[nkey] then - pushResult(field) + pushResult(field, true) hasFounded[nkey] = true end end @@ -230,13 +338,13 @@ function vm.getClassFields(suri, object, key, ref, pushResult) or keyType == 'string' then typeName = keyType end - if typeName then + if typeName and field.field.type ~= 'doc.field.name' then -- ---@field [integer] boolean -> class[1] local fieldNode = vm.compileNode(field.field) if vm.isSubType(suri, typeName, fieldNode) then local nkey = '|' .. typeName if not searchedFields[nkey] then - pushResult(field) + pushResult(field, true) hasFounded[nkey] = true end end @@ -244,38 +352,37 @@ function vm.getClassFields(suri, object, key, ref, pushResult) end end end + copyToSearched() -- check local field and global field - if set.bindSources then - for _, src in ipairs(set.bindSources) do - searchFieldSwitch(src.type, suri, src, key, ref, function (field) + if not searchedFields[key] and set.bindSource then + local src = set.bindSource + if src.value and src.value.type == 'table' then + searchFieldSwitch('table', suri, src.value, key, ref, function (field) local fieldKey = guide.getKeyName(field) if fieldKey then if not searchedFields[fieldKey] and guide.isSet(field) then hasFounded[fieldKey] = true - pushResult(field) + pushResult(field, true) end end end) - if src.value and src.value.type == 'table' then - searchFieldSwitch('table', suri, src.value, key, ref, function (field) - local fieldKey = guide.getKeyName(field) - if fieldKey then - if not searchedFields[fieldKey] - and guide.isSet(field) then - hasFounded[fieldKey] = true - pushResult(field) - end - end - end) - end end + copyToSearched() + searchFieldSwitch(src.type, suri, src, key, ref, function (field) + local fieldKey = guide.getKeyName(field) + if fieldKey and not searchedFields[fieldKey] then + if not searchedFields[fieldKey] + and guide.isSet(field) then + hasFounded[fieldKey] = true + pushResult(field, true) + end + end + end) + copyToSearched() end -- look into extends(if field not found) - if not hasFounded[key] and set.extends then - for fieldKey in pairs(hasFounded) do - searchedFields[fieldKey] = true - end + if not searchedFields[key] and set.extends then for _, extend in ipairs(set.extends) do if extend.type == 'doc.extends.name' then local extendType = vm.getGlobal('type', extend[1]) @@ -284,6 +391,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult) end end end + copyToSearched() end end end @@ -296,7 +404,7 @@ function vm.getClassFields(suri, object, key, ref, pushResult) for _, set in ipairs(sets) do pushResult(set) end - else + elseif type(key) == 'string' then local global = vm.getGlobal('variable', key) if global then for _, set in ipairs(global:getSets(suri)) do @@ -312,10 +420,10 @@ function vm.getClassFields(suri, object, key, ref, pushResult) end ---@class parser.object ----@field _sign? vm.sign +---@field _sign vm.sign|false ---@param source parser.object ----@return vm.sign? +---@return vm.sign|false local function getObjectSign(source) if source._sign ~= nil then return source._sign @@ -376,7 +484,7 @@ end ---@param func parser.object ---@param index integer ----@return vm.object? +---@return (parser.object|vm.generic)? function vm.getReturnOfFunction(func, index) if func.type == 'function' then if not func._returns then @@ -384,9 +492,9 @@ function vm.getReturnOfFunction(func, index) end if not func._returns[index] then func._returns[index] = { - type = 'function.return', - parent = func, - index = index, + type = 'function.return', + parent = func, + returnIndex = index, } end return func._returns[index] @@ -394,7 +502,12 @@ function vm.getReturnOfFunction(func, index) if func.type == 'doc.type.function' then local rtn = func.returns[index] if not rtn then - return nil + local lastReturn = func.returns[#func.returns] + if lastReturn and lastReturn.name and lastReturn.name[1] == '...' then + rtn = lastReturn + else + return nil + end end local sign = getObjectSign(func) if not sign then @@ -402,8 +515,10 @@ function vm.getReturnOfFunction(func, index) end return vm.createGeneric(rtn, sign) end + return nil end +---@param args parser.object[] ---@return vm.node local function getReturnOfSetMetaTable(args) local tbl = args[1] @@ -427,88 +542,60 @@ local function getReturnOfSetMetaTable(args) return node end ----@return vm.node? -local function getReturn(func, index, args) - if func.special == 'setmetatable' then - if not args then - return nil - end - return getReturnOfSetMetaTable(args) - end - if func.special == 'pcall' and index > 1 then - if not args then - return nil - end - local newArgs = {} - for i = 2, #args do - newArgs[#newArgs+1] = args[i] - end - return getReturn(args[1], index - 1, newArgs) - end - if func.special == 'xpcall' and index > 1 then - if not args then - return nil - end - local newArgs = {} - for i = 3, #args do - newArgs[#newArgs+1] = args[i] - end - return getReturn(args[1], index - 1, newArgs) +---@param source parser.object +local function matchCall(source) + local call = source.parent + if not call + or call.type ~= 'call' + or call.node ~= source then + return end - if func.special == 'require' then - if not args then - return nil - end - local nameArg = args[1] - if not nameArg or nameArg.type ~= 'string' then - return nil - end - local name = nameArg[1] - if not name or type(name) ~= 'string' then - return nil - end - local uri = rpath.findUrisByRequirePath(guide.getUri(func), name)[1] - if not uri then - return nil - end - local state = files.getState(uri) - local ast = state and state.ast - if not ast then - return nil - end - return vm.compileNode(ast) + local funcs = vm.getMatchedFunctions(source, call.args) + local myNode = vm.getNode(source) + if not myNode then + return end - local node = vm.compileNode(func) - ---@type vm.node? - local result - for cnode in node:eachObject() do - if cnode.type == 'function' - or cnode.type == 'doc.type.function' then - local returnObject = vm.getReturnOfFunction(cnode, index) - if returnObject then - local returnNode = vm.compileNode(returnObject) - for rnode in returnNode:eachObject() do - if rnode.type == 'generic' then - returnNode = rnode:resolve(guide.getUri(func), args) - break - end - end - if returnNode then - for rnode in returnNode:eachObject() do - -- TODO: narrow type - if rnode.type ~= 'doc.generic.name' then - result = result or vm.createNode() - result:merge(rnode) - end - end - if result and returnNode:isOptional() then - result:addOptional() - end + local needRemove + for n in myNode:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + if not util.arrayHas(funcs, n) then + if not needRemove then + needRemove = vm.createNode() end + needRemove:merge(n) end end end - return result + if needRemove then + local newNode = myNode:copy() + newNode:removeNode(needRemove) + newNode:setData('originNode', myNode) + vm.setNode(source, newNode, true) + end +end + +---@param func parser.object +---@param index integer +---@param args parser.object[] +---@return vm.node +local function getReturn(func, index, args) + if not func._callReturns then + func._callReturns = {} + end + if not func._callReturns[index] then + local call = func.parent + func._callReturns[index] = { + type = 'call.return', + parent = call, + func = func, + cindex = index, + args = args, + start = call.start, + finish = call.finish, + } + end + return vm.compileNode(func._callReturns[index]) end ---@param source parser.object @@ -517,123 +604,136 @@ local function bindAs(source) local root = guide.getRoot(source) local docs = root.docs if not docs then - return + return false end - for _, doc in ipairs(docs) do - if doc.type == 'doc.as' and doc.originalComment.start == source.finish + 2 then - if doc.as then - vm.setNode(source, vm.compileNode(doc.as), true) + local ases = docs._asCache + if not ases then + ases = {} + docs._asCache = ases + for _, doc in ipairs(docs) do + if doc.type == 'doc.as' and doc.as and doc.touch then + ases[#ases+1] = doc end - return true end + table.sort(ases, function (a, b) + return a.touch < b.touch + end) end - return false -end -local function bindDocs(source) - local isParam = source.parent.type == 'funcargs' - or source.parent.type == 'in' - local docs = source.bindDocs - for i = #docs, 1, -1 do - local doc = docs[i] - if doc.type == 'doc.type' then - if not isParam then - vm.setNode(source, vm.compileNode(doc)) - return true - end - end - if doc.type == 'doc.class' then - if (source.type == 'local' and not isParam) - or (source._globalNode and guide.isSet(source)) - or source.type == 'tablefield' - or source.type == 'tableindex' then - vm.setNode(source, vm.compileNode(doc)) - return true - end - end - if doc.type == 'doc.param' then - if isParam and source[1] == doc.param[1] then - local node = vm.compileNode(doc) - if doc.optional then - node:addOptional() - end - vm.setNode(source, node) - return true - end - end - if doc.type == 'doc.module' then - local name = doc.module - local uri = rpath.findUrisByRequirePath(guide.getUri(source), name)[1] - if not uri then - return nil - end - local state = files.getState(uri) - local ast = state and state.ast - if not ast then - return nil - end - vm.setNode(source, vm.compileNode(ast)) - return true - end - if doc.type == 'doc.overload' then - if not isParam then - vm.setNode(source, vm.compileNode(doc)) - end - end + if #ases == 0 then + return false end - return false -end -local function compileByLocalID(source) - local sources = vm.getLocalSources(source) - if not sources then - return - end - local hasMarkDoc - for _, src in ipairs(sources) do - if src.bindDocs then - if bindDocs(src) then - hasMarkDoc = true - vm.setNode(source, vm.compileNode(src)) - end + local max = #ases + local index + local left = 1 + local right = max + for _ = 1, 1000 do + if left == right then + index = left + break + end + index = left + (right - left) // 2 + local doc = ases[index] + if doc.touch < source.finish then + left = index + 1 + else + right = index end end - for _, src in ipairs(sources) do - if src.value then - if not hasMarkDoc or guide.isLiteral(src.value) then - if src.value.type ~= 'nil' then - vm.setNode(source, vm.compileNode(src.value)) - end - end - end + + local doc = ases[index] + if doc and doc.touch == source.finish then + vm.setNode(source, vm.compileNode(doc.as), true) + return true end + + return false end ----@param source vm.node ----@param key? any +---@param source parser.object +---@param key? string|vm.global ---@param pushResult fun(source: parser.object) function vm.compileByParentNode(source, key, ref, pushResult) local parentNode = vm.compileNode(source) + local docedResults = {} + local commonResults = {} + local mark = {} local suri = guide.getUri(source) + local hasClass for node in parentNode:eachObject() do - searchFieldSwitch(node.type, suri, node, key, ref, pushResult) + if node.type == 'global' + and node.cate == 'type' + ---@cast node vm.global + and not guide.isBasicType(node.name) then + hasClass = true + break + end + end + for node in parentNode:eachObject() do + if not hasClass + or ( + node.type == 'global' + and node.cate == 'type' + ---@cast node vm.global + and not guide.isBasicType(node.name) + ) + or guide.isLiteral(node) then + searchFieldSwitch(node.type, suri, node, key, ref, function (res, markDoc) + if mark[res] then + return + end + mark[res] = true + if markDoc then + docedResults[#docedResults+1] = res + else + commonResults[#commonResults+1] = res + end + end) + end end -end ----@return vm.node? -local function selectNode(source, list, index) - if not list then - return nil + if not next(mark) then + searchFieldByLocalID(source, key, ref, function (res, markDoc) + if mark[res] then + return + end + mark[res] = true + if markDoc then + docedResults[#docedResults+1] = res + else + commonResults[#commonResults+1] = res + end + end) end + + if #docedResults > 0 then + for _, res in ipairs(docedResults) do + pushResult(res) + end + end + if #docedResults == 0 or key == nil then + for _, res in ipairs(commonResults) do + pushResult(res) + end + end +end + +---@param list parser.object[] +---@param index integer +---@return vm.node +---@return parser.object? +function vm.selectNode(list, index) local exp if list[index] then exp = list[index] + index = 1 else for i = index, 1, -1 do if list[i] then local last = list[i] if last.type == 'call' - or last.type == '...' then + or last.type == 'varargs' then index = index - i + 1 exp = last end @@ -642,40 +742,46 @@ local function selectNode(source, list, index) end end if not exp then - return nil + return vm.createNode(vm.declareGlobal('type', 'nil')), nil end + ---@type vm.node? local result if exp.type == 'call' then result = getReturn(exp.node, index, exp.args) - if not result then - vm.setNode(source, vm.declareGlobal('type', 'unknown')) - return vm.getNode(source) + if result:isEmpty() then + result:merge(vm.declareGlobal('type', 'unknown')) end else + ---@type vm.node result = vm.compileNode(exp) + if result:isEmpty() then + result:merge(vm.declareGlobal('type', 'unknown')) + end end + return result, exp +end + +---@param source parser.object +---@param list parser.object[] +---@param index integer +---@return vm.node +local function selectNode(source, list, index) + local result = vm.selectNode(list, index) if source.type == 'function.return' then -- remove any for returns local rtnNode = vm.createNode() - local hasKnownType for n in result:eachObject() do if guide.isLiteral(n) then - hasKnownType = true rtnNode:merge(n) end if n.type == 'global' and n.cate == 'type' then - if n.name ~= 'any' - and n.name ~= 'unknown' then - hasKnownType = true + if n.name ~= 'any' then rtnNode:merge(n) end else rtnNode:merge(n) end end - if not hasKnownType then - rtnNode:merge(vm.declareGlobal('type', 'unknown')) - end vm.setNode(source, rtnNode) return rtnNode end @@ -684,15 +790,19 @@ local function selectNode(source, list, index) end ---@param source parser.object ----@param node vm.object +---@param node vm.node.object ---@return boolean local function isValidCallArgNode(source, node) if source.type == 'function' then return node.type == 'doc.type.function' end if source.type == 'table' then - return node.type == 'doc.type.table' - or (node.type == 'global' and node.cate == 'type' and not guide.isBasicType(node.name)) + return node.type == 'doc.type.table' or node.type == 'doc.type.array' + or ( node.type == 'global' + and node.cate == 'type' + ---@cast node vm.global + and not guide.isBasicType(node.name) + ) end if source.type == 'dummyarg' then return true @@ -741,12 +851,14 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) for n in callNode:eachObject() do if n.type == 'function' then + ---@cast n parser.object local sign = getObjectSign(n) local farg = getFuncArg(n, myIndex) if farg then for fn in vm.compileNode(farg):eachObject() do if isValidCallArgNode(arg, fn) then if fn.type == 'doc.type.function' then + ---@cast fn parser.object if sign then local generic = vm.createGeneric(fn, sign) local args = {} @@ -762,6 +874,7 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) end end if n.type == 'doc.type.function' then + ---@cast n parser.object local myEvent if n.args[eventIndex] then local argNode = vm.compileNode(n.args[eventIndex]) @@ -788,6 +901,7 @@ end ---@param arg parser.object ---@param call parser.object ---@param index? integer +---@return vm.node? function vm.compileCallArg(arg, call, index) if not index then for i, carg in ipairs(call.args) do @@ -796,6 +910,9 @@ function vm.compileCallArg(arg, call, index) break end end + if not index then + return nil + end end local callNode = vm.compileNode(call.node) @@ -812,8 +929,44 @@ function vm.compileCallArg(arg, call, index) return vm.getNode(arg) end +---@class parser.object +---@field _iterator? table +---@field _iterArgs? table +---@field _iterVars? table<parser.object, vm.node> + +---@param source parser.object +local function compileForVars(source) + if source._iterator then + return + end + if not source.exps then + return + end + -- for k, v in pairs(t) do + --> for k, v in iterator, status, initValue do + --> local k, v = iterator(status, initValue) + source._iterator = { + type = 'dummyfunc', + parent = source, + } + source._iterArgs = {{},{}} + source._iterVars = {} + -- iterator + selectNode(source._iterator, source.exps, 1) + -- status + selectNode(source._iterArgs[1], source.exps, 2) + -- initValue + selectNode(source._iterArgs[2], source.exps, 3) + if source.keys then + for i, loc in ipairs(source.keys) do + local node = getReturn(source._iterator, i, source._iterArgs) + node:removeOptional() + source._iterVars[loc] = node + end + end +end + ---@param source parser.object ----@return vm.node local function compileLocal(source) vm.setNode(source, source) @@ -835,15 +988,25 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(source.parent.parent.parent.node)) end end + vm.getNode(source):remove 'function' end local hasMarkValue - if source.value then - if not hasMarkDoc or guide.isLiteral(source.value) then - hasMarkValue = true - if source.value.type == 'table' then - vm.setNode(source, source.value) - elseif source.value.type ~= 'nil' then - vm.setNode(source, vm.compileNode(source.value)) + if not hasMarkDoc and source.value then + hasMarkValue = true + if source.value.type == 'table' then + vm.setNode(source, source.value) + elseif source.value.type ~= 'nil' then + vm.setNode(source, vm.compileNode(source.value)) + end + end + if not hasMarkValue and not hasMarkValue then + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' + and ref.value + and ref.value.type == 'function' then + vm.setNode(source, vm.compileNode(ref.value)) + end end end end @@ -883,12 +1046,21 @@ local function compileLocal(source) end -- for x in ... do if source.parent.type == 'in' then - vm.compileNode(source.parent) + compileForVars(source.parent) + local keyNode = source.parent._iterVars and source.parent._iterVars[source] + if keyNode then + vm.setNode(source, keyNode) + end end -- for x = ... do if source.parent.type == 'loop' then - vm.compileNode(source.parent) + if source.parent.loc == source then + if bindDocs(source) then + return + end + vm.setNode(source, vm.declareGlobal('type', 'integer')) + end end vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) @@ -920,9 +1092,21 @@ local compilerSwitch = util.switch() or source.parent.type == 'setlocal' or source.parent.type == 'tablefield' or source.parent.type == 'tableindex' + or source.parent.type == 'tableexp' or source.parent.type == 'setfield' or source.parent.type == 'setindex' then - vm.setNode(source, vm.compileNode(source.parent)) + local parentNode = vm.compileNode(source.parent) + for _, pn in ipairs(parentNode) do + if pn.type == 'global' + and pn.cate == 'type' then + ---@cast pn vm.global + if not guide.isBasicType(pn.name) then + vm.setNode(source, pn) + end + elseif pn.type == 'doc.type.table' or pn.type == 'doc.type.array' then + vm.setNode(source, pn) + end + end end end) : case 'function' @@ -964,8 +1148,7 @@ local compilerSwitch = util.switch() local hasMark = vm.getNode(source):getData 'hasDefined' - local runner = vm.createRunner(source) - runner:launch(function (src, node) + vm.launchRunner(source, function (src, node) if src.type == 'setlocal' then if src.bindDocs then for _, doc in ipairs(src.bindDocs) do @@ -975,18 +1158,32 @@ local compilerSwitch = util.switch() end end end - if src.value and guide.isLiteral(src.value) then + if src.value then if src.value.type == 'table' then - vm.setNode(src, vm.createNode(src.value), true) + vm.setNode(src, vm.createNode(src.value)) + vm.setNode(src, node:copy():asTable()) else + local function clearLockedNode(child) + if not child then + return + end + if child.type == 'function' then + return + end + if child.type == 'setlocal' + or child.type == 'getlocal' then + if child.node == source then + return + end + end + if LOCK[child] then + vm.removeNode(child) + end + guide.eachChild(child, clearLockedNode) + end + clearLockedNode(src.value) vm.setNode(src, vm.compileNode(src.value), true) end - elseif src.value - and src.value.type == 'binary' - and src.value.op and src.value.op.type == 'or' - and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then - -- x = x or 1 - vm.setNode(src, vm.compileNode(src.value)) else vm.setNode(src, node, true) end @@ -996,6 +1193,7 @@ local compilerSwitch = util.switch() return end vm.setNode(src, node, true) + matchCall(src) end end) @@ -1004,7 +1202,10 @@ local compilerSwitch = util.switch() for _, ref in ipairs(source.ref) do if ref.type == 'setlocal' and guide.getParentFunction(ref) == parentFunc then - vm.setNode(source, vm.getNode(ref)) + local refNode = vm.getNode(ref) + if refNode then + vm.setNode(source, refNode) + end end end end @@ -1023,27 +1224,14 @@ local compilerSwitch = util.switch() : case 'setfield' : case 'setmethod' : case 'setindex' - : call(function (source) - compileByLocalID(source) - local key = guide.getKeyName(source) - if key == nil then - return - end - vm.compileByParentNode(source.node, key, false, function (src) - if src.type == 'doc.type.field' - or src.type == 'doc.field' then - vm.setNode(source, vm.compileNode(src)) - end - end) - end) : case 'getfield' : case 'getmethod' : case 'getindex' : call(function (source) - if bindAs(source) then + if guide.isGet(source) and bindAs(source) then return end - compileByLocalID(source) + ---@type (string|vm.node)? local key = guide.getKeyName(source) if key == nil and source.index then key = vm.compileNode(source.index) @@ -1052,14 +1240,39 @@ local compilerSwitch = util.switch() return end if type(key) == 'table' then + ---@cast key vm.node local uri = guide.getUri(source) local value = vm.getTableValue(uri, vm.compileNode(source.node), key) if value then vm.setNode(source, value) end + for k in key:eachObject() do + if k.type == 'global' and k.cate == 'type' then + ---@cast k vm.global + vm.compileByParentNode(source.node, k, false, function (src) + vm.setNode(source, vm.compileNode(src)) + if src.value then + vm.setNode(source, vm.compileNode(src.value)) + end + end) + end + end else + ---@cast key string vm.compileByParentNode(source.node, key, false, function (src) - vm.setNode(source, vm.compileNode(src)) + if src.value then + if bindDocs(src) then + vm.setNode(source, vm.compileNode(src)) + elseif src.value.type ~= 'nil' then + vm.setNode(source, vm.compileNode(src.value)) + local node = vm.getNode(src) + if node then + vm.setNode(source, node) + end + end + else + vm.setNode(source, vm.compileNode(src)) + end end) end end) @@ -1097,21 +1310,20 @@ local compilerSwitch = util.switch() hasMarkDoc = bindDocs(source) end - if source.value then - if not hasMarkDoc or guide.isLiteral(source.value) then - if source.value.type == 'table' then - vm.setNode(source, source.value) - elseif source.value.type ~= 'nil' then - vm.setNode(source, vm.compileNode(source.value)) + if not hasMarkDoc then + vm.compileByParentNode(source.node, guide.getKeyName(source), false, function (src) + if src.type == 'doc.field' + or src.type == 'doc.type.field' then + hasMarkDoc = true + vm.setNode(source, vm.compileNode(src)) end - end + end) end - if not hasMarkDoc then - vm.compileByParentNode(source.parent, guide.getKeyName(source), false, function (src) - vm.setNode(source, vm.compileNode(src)) - end) + if not hasMarkDoc and source.value then + vm.setNode(source, vm.compileNode(source.value)) end + end) : case 'field' : case 'method' @@ -1120,18 +1332,29 @@ local compilerSwitch = util.switch() end) : case 'tableexp' : call(function (source) + if (source.parent.type == 'table') then + local node = vm.compileNode(source.parent) + for n in node:eachObject() do + if n.type == 'doc.type.array' then + vm.setNode(source, vm.compileNode(n.node)) + end + end + end vm.setNode(source, vm.compileNode(source.value)) end) : case 'function.return' + ---@param source parser.object : call(function (source) local func = source.parent - local index = source.index + local index = source.returnIndex local hasMarkDoc if func.bindDocs then local sign = getObjectSign(func) + local lastReturn for _, doc in ipairs(func.bindDocs) do if doc.type == 'doc.return' then for _, rtn in ipairs(doc.returns) do + lastReturn = rtn if rtn.returnIndex == index then hasMarkDoc = true local hasGeneric @@ -1141,6 +1364,7 @@ local compilerSwitch = util.switch() end) end if hasGeneric then + ---@cast sign -false vm.setNode(source, vm.createGeneric(rtn, sign)) else vm.setNode(source, vm.compileNode(rtn)) @@ -1149,10 +1373,146 @@ local compilerSwitch = util.switch() end end end + if lastReturn + and not hasMarkDoc then + if lastReturn.name and lastReturn.name[1] == '...' then + hasMarkDoc = true + vm.setNode(source, vm.compileNode(lastReturn)) + end + end end + local hasReturn if func.returns and not hasMarkDoc then for _, rtn in ipairs(func.returns) do - selectNode(source, rtn, index) + if selectNode(source, rtn, index) then + hasReturn = true + end + end + if hasReturn then + local hasKnownType + local hasUnknownType + for n in vm.getNode(source):eachObject() do + if guide.isLiteral(n) then + if n.type ~= 'nil' then + hasKnownType = true + break + end + goto CONTINUE + end + if n.type == 'global' and n.cate == 'type' then + if n.name ~= 'nil' then + hasKnownType = true + break + end + goto CONTINUE + end + hasUnknownType = true + ::CONTINUE:: + end + if not hasKnownType and hasUnknownType then + vm.setNode(source, vm.declareGlobal('type', 'unknown')) + end + end + end + if not hasMarkDoc and not hasReturn then + vm.setNode(source, vm.declareGlobal('type', 'nil')) + end + end) + : case 'call.return' + ---@param source parser.object + : call(function (source) + if bindAs(source) then + return + end + local func = source.func + local args = source.args + local index = source.cindex + if func.special == 'setmetatable' then + if not args then + return + end + vm.setNode(source, getReturnOfSetMetaTable(args)) + return + end + if func.special == 'pcall' and index > 1 then + if not args then + return + end + local newArgs = {} + for i = 2, #args do + newArgs[#newArgs+1] = args[i] + end + local node = getReturn(args[1], index - 1, newArgs) + if node then + vm.setNode(source, node) + end + return + end + if func.special == 'xpcall' and index > 1 then + if not args then + return + end + local newArgs = {} + for i = 3, #args do + newArgs[#newArgs+1] = args[i] + end + local node = getReturn(args[1], index - 1, newArgs) + if node then + vm.setNode(source, node) + end + return + end + if func.special == 'require' then + if not args then + return + end + local nameArg = args[1] + if not nameArg or nameArg.type ~= 'string' then + return + end + local name = nameArg[1] + if not name or type(name) ~= 'string' then + return + end + local uri = rpath.findUrisByRequireName(guide.getUri(func), name)[1] + if not uri then + return + end + local state = files.getState(uri) + local ast = state and state.ast + if not ast then + return + end + vm.setNode(source, vm.compileNode(ast)) + return + end + local funcNode = vm.compileNode(func) + ---@type vm.node? + for mfunc in funcNode:eachObject() do + if mfunc.type == 'function' + or mfunc.type == 'doc.type.function' then + ---@cast mfunc parser.object + local returnObject = vm.getReturnOfFunction(mfunc, index) + if returnObject then + local returnNode = vm.compileNode(returnObject) + for rnode in returnNode:eachObject() do + if rnode.type == 'generic' then + returnNode = rnode:resolve(guide.getUri(func), args) + break + end + end + if returnNode then + for rnode in returnNode:eachObject() do + -- TODO: narrow type + if rnode.type ~= 'doc.generic.name' then + vm.setNode(source, rnode) + end + end + if returnNode:isOptional() then + vm.getNode(source):addOptional() + end + end + end end end end) @@ -1174,12 +1534,8 @@ local compilerSwitch = util.switch() if not node then return end - for n in node:eachObject() do - if n.type == 'global' - and n.cate == 'type' - and n.name == '...' then - return - end + if node:isEmpty() then + node = vm.runOperator('call', vararg.node) or node end vm.setNode(source, node) end @@ -1199,51 +1555,11 @@ local compilerSwitch = util.switch() if not node then return end - for n in node:eachObject() do - if n.type == 'global' - and n.cate == 'type' - and n.name == '...' then - return - end + if node:isEmpty() then + node = vm.runOperator('call', source.node) or node end vm.setNode(source, node) end) - : case 'in' - : call(function (source) - if not source._iterator then - -- for k, v in pairs(t) do - --> for k, v in iterator, status, initValue do - --> local k, v = iterator(status, initValue) - source._iterator = { - type = 'dummyfunc', - parent = source, - } - source._iterArgs = {{},{}} - end - -- iterator - selectNode(source._iterator, source.exps, 1) - -- status - selectNode(source._iterArgs[1], source.exps, 2) - -- initValue - selectNode(source._iterArgs[2], source.exps, 3) - if source.keys then - for i, loc in ipairs(source.keys) do - local node = getReturn(source._iterator, i, source._iterArgs) - if node then - if i == 1 then - node:removeOptional() - end - vm.setNode(loc, node) - end - end - end - end) - : case 'loop' - : call(function (source) - if source.loc then - vm.setNode(source.loc, vm.declareGlobal('type', 'integer')) - end - end) : case 'doc.type' : call(function (source) for _, typeUnit in ipairs(source.types) do @@ -1256,6 +1572,7 @@ local compilerSwitch = util.switch() : case 'doc.type.integer' : case 'doc.type.string' : case 'doc.type.boolean' + : case 'doc.type.code' : call(function (source) vm.setNode(source, source) end) @@ -1299,6 +1616,10 @@ local compilerSwitch = util.switch() : call(function (source) vm.setNode(source, vm.compileNode(source.parent)) end) + : case 'doc.enum.name' + : call(function (source) + vm.setNode(source, vm.compileNode(source.parent)) + end) : case 'doc.field' : call(function (source) if not source.extends then @@ -1337,18 +1658,14 @@ local compilerSwitch = util.switch() end) : case '...' : call(function (source) - local func = source.parent.parent - if func.type ~= 'function' then - return - end - if not func.bindDocs then + if not source.bindDocs then return end - for _, doc in ipairs(func.bindDocs) do + for _, doc in ipairs(source.bindDocs) do if doc.type == 'doc.vararg' then vm.setNode(source, vm.compileNode(doc)) end - if doc.type == 'doc.param' and doc.param[1] == '...' then + if doc.type == 'doc.param' then vm.setNode(source, vm.compileNode(doc)) end end @@ -1361,7 +1678,7 @@ local compilerSwitch = util.switch() : call(function (source) local type = vm.getGlobal('type', source[1]) if type then - vm.setNode(source, vm.compileNode(type)) + vm.setNode(source, type) end end) : case 'doc.type.arg' @@ -1375,10 +1692,6 @@ local compilerSwitch = util.switch() vm.getNode(source):addOptional() end end) - : case 'generic' - : call(function (source) - vm.setNode(source, source) - end) : case 'unary' : call(function (source) if bindAs(source) then @@ -1387,58 +1700,7 @@ local compilerSwitch = util.switch() if not source[1] then return end - if source.op.type == 'not' then - local result = vm.test(source[1]) - if result == nil then - vm.setNode(source, vm.declareGlobal('type', 'boolean')) - return - else - vm.setNode(source, { - type = 'boolean', - start = source.start, - finish = source.finish, - parent = source, - [1] = not result, - }) - return - end - end - if source.op.type == '#' then - vm.setNode(source, vm.declareGlobal('type', 'integer')) - return - end - if source.op.type == '-' then - local v = vm.getNumber(source[1]) - if v == nil then - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - else - vm.setNode(source, { - type = 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = -v, - }) - return - end - end - if source.op.type == '~' then - local v = vm.getInteger(source[1]) - if v == nil then - vm.setNode(source, vm.declareGlobal('type', 'integer')) - return - else - vm.setNode(source, { - type = 'integer', - start = source.start, - finish = source.finish, - parent = source, - [1] = ~v, - }) - return - end - end + vm.unarySwich(source.op.type, source) end) : case 'binary' : call(function (source) @@ -1448,323 +1710,21 @@ local compilerSwitch = util.switch() if not source[1] or not source[2] then return end - if source.op.type == 'and' then - local node1 = vm.compileNode(source[1]) - local node2 = vm.compileNode(source[2]) - local r1 = vm.test(source[1]) - if r1 == true then - vm.setNode(source, node2) - elseif r1 == false then - vm.setNode(source, node1) - else - vm.setNode(source, node2) - end - end - if source.op.type == 'or' then - local node1 = vm.compileNode(source[1]) - local node2 = vm.compileNode(source[2]) - local r1 = vm.test(source[1]) - if r1 == true then - vm.setNode(source, node1) - elseif r1 == false then - vm.setNode(source, node2) - else - vm.getNode(source):merge(node1) - vm.getNode(source):setTruthy() - vm.getNode(source):merge(node2) - end - end - if source.op.type == '==' then - local result = vm.equal(source[1], source[2]) - if result == nil then - vm.setNode(source, vm.declareGlobal('type', 'boolean')) - return - else - vm.setNode(source, { - type = 'boolean', - start = source.start, - finish = source.finish, - parent = source, - [1] = result, - }) - return - end - end - if source.op.type == '~=' then - local result = vm.equal(source[1], source[2]) - if result == nil then - vm.setNode(source, vm.declareGlobal('type', 'boolean')) - return - else - vm.setNode(source, { - type = 'boolean', - start = source.start, - finish = source.finish, - parent = source, - [1] = not result, - }) - return - end - end - if source.op.type == '<<' then - local a = vm.getInteger(source[1]) - local b = vm.getInteger(source[2]) - if a and b then - vm.setNode(source, { - type = 'integer', - start = source.start, - finish = source.finish, - parent = source, - [1] = a << b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'integer')) - return - end - end - if source.op.type == '>>' then - local a = vm.getInteger(source[1]) - local b = vm.getInteger(source[2]) - if a and b then - vm.setNode(source, { - type = 'integer', - start = source.start, - finish = source.finish, - parent = source, - [1] = a >> b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'integer')) - return - end - end - if source.op.type == '&' then - local a = vm.getInteger(source[1]) - local b = vm.getInteger(source[2]) - if a and b then - vm.setNode(source, { - type = 'integer', - start = source.start, - finish = source.finish, - parent = source, - [1] = a & b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'integer')) - return - end - end - if source.op.type == '|' then - local a = vm.getInteger(source[1]) - local b = vm.getInteger(source[2]) - if a and b then - vm.setNode(source, { - type = 'integer', - start = source.start, - finish = source.finish, - parent = source, - [1] = a | b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'integer')) - return - end - end - if source.op.type == '~' then - local a = vm.getInteger(source[1]) - local b = vm.getInteger(source[2]) - if a and b then - vm.setNode(source, { - type = 'integer', - start = source.start, - finish = source.finish, - parent = source, - [1] = a ~ b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'integer')) - return - end - end - if source.op.type == '+' then - local a = vm.getNumber(source[1]) - local b = vm.getNumber(source[2]) - if a and b then - local result = a + b - vm.setNode(source, { - type = math.type(result) == 'integer' and 'integer' or 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = result, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - end - end - if source.op.type == '-' then - local a = vm.getNumber(source[1]) - local b = vm.getNumber(source[2]) - if a and b then - local result = a - b - vm.setNode(source, { - type = math.type(result) == 'integer' and 'integer' or 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = result, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - end - end - if source.op.type == '*' then - local a = vm.getNumber(source[1]) - local b = vm.getNumber(source[2]) - if a and b then - local result = a * b - vm.setNode(source, { - type = math.type(result) == 'integer' and 'integer' or 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = result, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - end - end - if source.op.type == '/' then - local a = vm.getNumber(source[1]) - local b = vm.getNumber(source[2]) - if a and b then - vm.setNode(source, { - type = 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = a / b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - end - end - if source.op.type == '%' then - local a = vm.getNumber(source[1]) - local b = vm.getNumber(source[2]) - if a and b and b ~= 0 then - local result = a % b - vm.setNode(source, { - type = math.type(result) == 'integer' and 'integer' or 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = result, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - end - end - if source.op.type == '^' then - local a = vm.getNumber(source[1]) - local b = vm.getNumber(source[2]) - if a and b then - vm.setNode(source, { - type = 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = a ^ b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - end - end - if source.op.type == '//' then - local a = vm.getNumber(source[1]) - local b = vm.getNumber(source[2]) - if a and b and b ~= 0 then - local result = a // b - vm.setNode(source, { - type = math.type(result) == 'integer' and 'integer' or 'number', - start = source.start, - finish = source.finish, - parent = source, - [1] = result, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'number')) - return - end - end - if source.op.type == '..' then - local a = vm.getString(source[1]) - or vm.getNumber(source[1]) - local b = vm.getString(source[2]) - or vm.getNumber(source[2]) - if a and b then - if type(a) == 'number' or type(b) == 'number' then - local uri = guide.getUri(source) - local version = config.get(uri, 'Lua.runtime.version') - if math.tointeger(a) and math.type(a) == 'float' then - if version == 'Lua 5.3' or version == 'Lua 5.4' then - a = ('%.1f'):format(a) - else - a = ('%.0f'):format(a) - end - end - if math.tointeger(b) and math.type(b) == 'float' then - if version == 'Lua 5.3' or version == 'Lua 5.4' then - b = ('%.1f'):format(b) - else - b = ('%.0f'):format(b) - end - end - end - vm.setNode(source, { - type = 'string', - start = source.start, - finish = source.finish, - parent = source, - [1] = a .. b, - }) - return - else - vm.setNode(source, vm.declareGlobal('type', 'string')) - return - end - end + vm.binarySwitch(source.op.type, source) end) ----@param source vm.object +---@param source parser.object local function compileByNode(source) compilerSwitch(source.type, source) end ----@param source vm.object +---@param source parser.object local function compileByGlobal(source) local global = source._globalNode if not global then return end + ---@cast source parser.object local root = guide.getRoot(source) local uri = guide.getUri(source) if not root._globalBase then @@ -1800,22 +1760,30 @@ local function compileByGlobal(source) if global.cate == 'variable' then local hasMarkDoc for _, set in ipairs(global:getSets(uri)) do - if set.bindDocs then + if set.bindDocs and set.parent.type == 'main' then if bindDocs(set) then globalNode:merge(vm.compileNode(set)) hasMarkDoc = true end + if vm.getNode(set) then + globalNode:merge(vm.compileNode(set)) + end end end + -- Set all globals node first to avoid recursive for _, set in ipairs(global:getSets(uri)) do - if set.value then + vm.setNode(set, globalNode, true) + end + for _, set in ipairs(global:getSets(uri)) do + if set.value and set.value.type ~= 'nil' and set.parent.type == 'main' then if not hasMarkDoc or guide.isLiteral(set.value) then - if set.value.type ~= 'nil' then - globalNode:merge(vm.compileNode(set.value)) - end + globalNode:merge(vm.compileNode(set.value)) end end end + for _, set in ipairs(global:getSets(uri)) do + vm.setNode(set, globalNode, true) + end end if global.cate == 'type' then for _, set in ipairs(global:getSets(uri)) do @@ -1850,21 +1818,27 @@ function vm.compileNode(source) end end - if source.type == 'global' then - return source - end - local cache = vm.getNode(source) if cache ~= nil then return cache end - local node = vm.createNode() - vm.setNode(source, node, true) + if source.type == 'generic' then + vm.setNode(source, source) + local node = vm.getNode(source) + ---@cast node -? + return node + end + + ---@cast source parser.object + vm.setNode(source, vm.createNode(), true) + LOCK[source] = true compileByGlobal(source) compileByNode(source) + matchCall(source) + LOCK[source] = nil - node = vm.getNode(source) - + local node = vm.getNode(source) + ---@cast node -? return node end diff --git a/script/vm/def.lua b/script/vm/def.lua index 83e92686..f557f221 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -5,72 +5,7 @@ local guide = require 'parser.guide' local simpleSwitch -local function searchGetLocal(source, node, pushResult) - local key = guide.getKeyName(source) - for _, ref in ipairs(node.node.ref) do - if ref.type == 'getlocal' - and ref.next - and guide.isSet(ref.next) - and guide.getKeyName(ref.next) == key then - pushResult(ref.next) - end - end -end - simpleSwitch = util.switch() - : case 'local' - : call(function (source, pushResult) - pushResult(source) - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' then - pushResult(ref) - end - end - end - end) - : case 'sellf' - : call(function (source, pushResult) - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' then - pushResult(ref) - end - end - end - for _, res in ipairs(vm.getDefs(source.method.node)) do - pushResult(res) - end - end) - : case 'getlocal' - : case 'setlocal' - : call(function (source, pushResult) - simpleSwitch('local', source.node, pushResult) - end) - : case 'field' - : call(function (source, pushResult) - local parent = source.parent - if parent.type ~= 'tablefield' then - simpleSwitch(parent.type, parent, pushResult) - end - end) - : case 'setfield' - : case 'getfield' - : call(function (source, pushResult) - local node = source.node - if node.type == 'getlocal' then - searchGetLocal(source, node, pushResult) - return - end - end) - : case 'getindex' - : case 'setindex' - : call(function (source, pushResult) - local node = source.node - if node.type == 'getlocal' then - searchGetLocal(source, node, pushResult) - end - end) : case 'goto' : call(function (source, pushResult) if source.node then @@ -98,7 +33,7 @@ local searchFieldSwitch = util.switch() end end) : case 'global' - ---@param obj vm.object + ---@param obj vm.global ---@param key string : call(function (suri, obj, key, pushResult) if obj.cate == 'variable' then @@ -115,12 +50,10 @@ local searchFieldSwitch = util.switch() end) : case 'local' : call(function (suri, obj, key, pushResult) - local sources = vm.getLocalSources(obj, key) + local sources = vm.getLocalSourcesSets(obj, key) if sources then for _, src in ipairs(sources) do - if guide.isSet(src) then - pushResult(src) - end + pushResult(src) end end end) @@ -152,7 +85,7 @@ local nodeSwitch;nodeSwitch = util.switch() local parentNode = vm.compileNode(source.node) local uri = guide.getUri(source) local key = guide.getKeyName(source) - if not key then + if type(key) ~= 'string' then return end if lastKey then @@ -169,9 +102,15 @@ local nodeSwitch;nodeSwitch = util.switch() if lastKey then return end - local tbl = source.parent + local key = guide.getKeyName(source) + if type(key) ~= 'string' then + return + end local uri = guide.getUri(source) - searchFieldSwitch(tbl.type, uri, tbl, guide.getKeyName(source), pushResult) + local parentNode = vm.compileNode(source.node) + for pn in parentNode:eachObject() do + searchFieldSwitch(pn.type, uri, pn, key, pushResult) + end end) : case 'doc.see.field' : call(function (source, lastKey, pushResult) @@ -194,14 +133,12 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchByLocalID(source, pushResult) - local idSources = vm.getLocalSources(source) + local idSources = vm.getLocalSourcesSets(source) if not idSources then return end for _, src in ipairs(idSources) do - if guide.isSet(src) then - pushResult(src) - end + pushResult(src) end end diff --git a/script/vm/doc.lua b/script/vm/doc.lua index e2b383b6..293cf5c3 100644 --- a/script/vm/doc.lua +++ b/script/vm/doc.lua @@ -20,6 +20,8 @@ function vm.getDocSets(suri, name) end end +---@param uri uri +---@return boolean function vm.isMetaFile(uri) local status = files.getState(uri) if not status then @@ -45,6 +47,8 @@ function vm.isMetaFile(uri) return false end +---@param doc parser.object +---@return table<string, boolean>? function vm.getValidVersions(doc) if doc.type ~= 'doc.version' then return @@ -87,13 +91,14 @@ function vm.getValidVersions(doc) return valids end +---@param value parser.object ---@return parser.object? local function getDeprecated(value) if not value.bindDocs then - return false + return nil end if value._deprecated ~= nil then - return value._deprecated + return value._deprecated or nil end for _, doc in ipairs(value.bindDocs) do if doc.type == 'doc.deprecated' then @@ -101,24 +106,26 @@ local function getDeprecated(value) return doc elseif doc.type == 'doc.version' then local valids = vm.getValidVersions(doc) - if not valids[config.get(guide.getUri(value), 'Lua.runtime.version')] then + if valids and not valids[config.get(guide.getUri(value), 'Lua.runtime.version')] then value._deprecated = doc return doc end end end value._deprecated = false - return false + return nil end +---@param value parser.object +---@param deep boolean? ---@return parser.object? function vm.getDeprecated(value, deep) if deep then local defs = vm.getDefs(value) if #defs == 0 then - return false + return nil end - local deprecated = false + local deprecated for _, def in ipairs(defs) do if def.type == 'setglobal' or def.type == 'setfield' @@ -128,7 +135,7 @@ function vm.getDeprecated(value, deep) or def.type == 'tableindex' then deprecated = getDeprecated(def) if not deprecated then - return false + return nil end end end @@ -138,6 +145,8 @@ function vm.getDeprecated(value, deep) end end +---@param value parser.object +---@return boolean local function isAsync(value) if value.type == 'function' then if not value.bindDocs then @@ -155,9 +164,15 @@ local function isAsync(value) value._async = false return false end + if value.type == 'main' then + return true + end return value.async == true end +---@param value parser.object +---@param deep boolean? +---@return boolean function vm.isAsync(value, deep) if isAsync(value) then return true @@ -176,6 +191,8 @@ function vm.isAsync(value, deep) return false end +---@param value parser.object +---@return boolean local function isNoDiscard(value) if value.type == 'function' then if not value.bindDocs then @@ -196,6 +213,9 @@ local function isNoDiscard(value) return false end +---@param value parser.object +---@param deep boolean? +---@return boolean function vm.isNoDiscard(value, deep) if isNoDiscard(value) then return true @@ -214,6 +234,8 @@ function vm.isNoDiscard(value, deep) return false end +---@param param parser.object +---@return boolean local function isCalledInFunction(param) if not param.ref then return false @@ -238,6 +260,9 @@ local function isCalledInFunction(param) return false end +---@param node parser.object +---@param index integer +---@return boolean local function isLinkedCall(node, index) for _, def in ipairs(vm.getDefs(node)) do if def.type == 'function' then @@ -252,16 +277,21 @@ local function isLinkedCall(node, index) return false end +---@param node parser.object +---@param index integer +---@return boolean function vm.isLinkedCall(node, index) return isLinkedCall(node, index) end +---@param call parser.object +---@return boolean function vm.isAsyncCall(call) if vm.isAsync(call.node, true) then return true end if not call.args then - return + return false end for i, arg in ipairs(call.args) do if vm.isAsync(arg, true) @@ -272,6 +302,9 @@ function vm.isAsyncCall(call) return false end +---@param uri uri +---@param doc parser.object +---@param results table[] local function makeDiagRange(uri, doc, results) local names if doc.names then @@ -325,7 +358,12 @@ local function makeDiagRange(uri, doc, results) end end -function vm.isDiagDisabledAt(uri, position, name) +---@param uri uri +---@param position integer +---@param name string +---@param err? boolean +---@return boolean +function vm.isDiagDisabledAt(uri, position, name, err) local status = files.getState(uri) if not status then return false @@ -355,7 +393,8 @@ function vm.isDiagDisabledAt(uri, position, name) local count = 0 for _, range in ipairs(cache.diagnosticRanges) do if range.row <= myRow then - if not range.names or range.names[name] then + if (range.names and range.names[name]) + or (not range.names and not err) then if range.mode == 'disable' then count = count + 1 elseif range.mode == 'enable' then diff --git a/script/vm/field.lua b/script/vm/field.lua index 5de838be..b92c3a7b 100644 --- a/script/vm/field.lua +++ b/script/vm/field.lua @@ -16,7 +16,7 @@ local searchByNodeSwitch = util.switch() end) local function searchByLocalID(source, pushResult) - local fields = vm.getLocalFields(source) + local fields = vm.getLocalFields(source, true) if fields then for _, field in ipairs(fields) do pushResult(field) diff --git a/script/vm/function.lua b/script/vm/function.lua new file mode 100644 index 00000000..7cde6298 --- /dev/null +++ b/script/vm/function.lua @@ -0,0 +1,245 @@ +---@class vm +local vm = require 'vm.vm' + +---@param arg parser.object +---@return parser.object? +local function getDocParam(arg) + if not arg.bindDocs then + return nil + end + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.param' + and doc.param[1] == arg[1] then + return doc + end + end + return nil +end + +---@param func parser.object +---@return integer min +---@return number max +---@return integer def +function vm.countParamsOfFunction(func) + local min = 0 + local max = 0 + local def = 0 + if func.type == 'function' then + if func.args then + max = #func.args + def = max + for i = #func.args, 1, -1 do + local arg = func.args[i] + if arg.type == '...' then + max = math.huge + elseif arg.type == 'self' + and i == 1 then + min = i + break + elseif getDocParam(arg) + and not vm.compileNode(arg):isNullable() then + min = i + break + end + end + end + end + if func.type == 'doc.type.function' then + if func.args then + max = #func.args + def = max + for i = #func.args, 1, -1 do + local arg = func.args[i] + if arg.name and arg.name[1] =='...' then + max = math.huge + elseif not vm.compileNode(arg):isNullable() then + min = i + break + end + end + end + end + return min, max, def +end + +---@param node vm.node +---@return integer min +---@return number max +---@return integer def +function vm.countParamsOfNode(node) + local min, max, def + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + ---@cast n parser.object + local fmin, fmax, fdef = vm.countParamsOfFunction(n) + if not min or fmin < min then + min = fmin + end + if not max or fmax > max then + max = fmax + end + if not def or fdef > def then + def = fdef + end + end + end + return min or 0, max or math.huge, def or 0 +end + +---@param func parser.object +---@param onlyDoc? boolean +---@param mark? table +---@return integer min +---@return number max +---@return integer def +function vm.countReturnsOfFunction(func, onlyDoc, mark) + if func.type == 'function' then + ---@type integer?, number?, integer? + local min, max, def + local hasDocReturn + if func.bindDocs then + local lastReturn + local n = 0 + ---@type integer?, number?, integer? + local dmin, dmax, ddef + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + hasDocReturn = true + for _, ret in ipairs(doc.returns) do + n = n + 1 + lastReturn = ret + dmax = n + ddef = n + if (not ret.name or ret.name[1] ~= '...') + and not vm.compileNode(ret):isNullable() then + dmin = n + end + end + end + end + if lastReturn then + if lastReturn.name and lastReturn.name[1] == '...' then + dmax = math.huge + end + end + if dmin and (not min or (dmin < min)) then + min = dmin + end + if dmax and (not max or (dmax > max)) then + max = dmax + end + if ddef and (not def or (ddef > def)) then + def = ddef + end + end + if not onlyDoc and not hasDocReturn and func.returns then + for _, ret in ipairs(func.returns) do + local rmin, rmax, ddef = vm.countList(ret, mark) + if not min or rmin < min then + min = rmin + end + if not max or rmax > max then + max = rmax + end + if not def or ddef > def then + def = ddef + end + end + end + return min or 0, max or math.huge, def or 0 + end + if func.type == 'doc.type.function' then + return vm.countList(func.returns) + end + error('not a function') +end + +---@param func parser.object +---@param mark? table +---@return integer min +---@return number max +---@return integer def +function vm.countReturnsOfCall(func, args, mark) + local funcs = vm.getMatchedFunctions(func, args, mark) + ---@type integer?, number?, integer? + local min, max, def + for _, f in ipairs(funcs) do + local rmin, rmax, rdef = vm.countReturnsOfFunction(f, false, mark) + if not min or rmin < min then + min = rmin + end + if not max or rmax > max then + max = rmax + end + if not def or rdef > def then + def = rdef + end + end + return min or 0, max or math.huge, def or 0 +end + +---@param list parser.object[]? +---@param mark? table +---@return integer min +---@return number max +---@return integer def +function vm.countList(list, mark) + if not list then + return 0, 0, 0 + end + local lastArg = list[#list] + if not lastArg then + return 0, 0, 0 + end + if lastArg.type == '...' + or lastArg.type == 'varargs' then + return #list - 1, math.huge, #list + end + if lastArg.type == 'call' then + if not mark then + mark = {} + end + if mark[lastArg] then + return #list - 1, math.huge, #list + end + mark[lastArg] = true + local rmin, rmax, rdef = vm.countReturnsOfCall(lastArg.node, lastArg.args, mark) + return #list - 1 + rmin, #list - 1 + rmax, #list - 1 + rdef + end + return #list, #list, #list +end + +---@param func parser.object +---@param args parser.object[]? +---@param mark? table +---@return parser.object[] +function vm.getMatchedFunctions(func, args, mark) + local funcs = {} + local node = vm.compileNode(func) + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + funcs[#funcs+1] = n + end + end + if #funcs <= 1 then + return funcs + end + + local amin, amax = vm.countList(args, mark) + + local matched = {} + for _, n in ipairs(funcs) do + local min, max = vm.countParamsOfFunction(n) + if amin >= min and amax <= max then + matched[#matched+1] = n + end + end + + if #matched == 0 then + return funcs + else + return matched + end +end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index 6462028e..544e11c9 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -11,11 +11,11 @@ local mt = {} mt.__index = mt mt.type = 'generic' ----@param source parser.object +---@param source vm.object? ---@param resolved? table<string, vm.node> ----@return parser.object | vm.node +---@return vm.object? local function cloneObject(source, resolved) - if not resolved then + if not resolved or not source then return source end if source.type == 'doc.generic.name' then @@ -121,8 +121,17 @@ function mt:resolve(uri, args) local protoNode = vm.compileNode(self.proto) local result = vm.createNode() for nd in protoNode:eachObject() do - local clonedNode = vm.compileNode(cloneObject(nd, resolved)) - result:merge(clonedNode) + if nd.type == 'global' then + ---@cast nd vm.global + result:merge(nd) + else + ---@cast nd -vm.global + local clonedObject = cloneObject(nd, resolved) + if clonedObject then + local clonedNode = vm.compileNode(clonedObject) + result:merge(clonedNode) + end + end end return result end diff --git a/script/vm/global.lua b/script/vm/global.lua index a54ab552..22235681 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -2,6 +2,7 @@ local util = require 'utility' local scope = require 'workspace.scope' local guide = require 'parser.guide' local files = require 'files' +local ws = require 'workspace' ---@class vm local vm = require 'vm.vm' @@ -11,8 +12,8 @@ local vm = require 'vm.vm' ---@class vm.global ---@field links table<uri, vm.global.link> ----@field setsCache table<uri, parser.object[]> ----@field getsCache table<uri, parser.object[]> +---@field setsCache? table<uri, parser.object[]> +---@field getsCache? table<uri, parser.object[]> ---@field cate vm.global.cate local mt = {} mt.__index = mt @@ -41,6 +42,7 @@ function mt:addGet(uri, source) self.getsCache = nil end +---@param suri uri ---@return parser.object[] function mt:getSets(suri) if not self.setsCache then @@ -127,7 +129,8 @@ local function createGlobal(name, cate) end ---@class parser.object ----@field _globalNode vm.global +---@field _globalNode vm.global|false +---@field _enums? (string|integer)[] ---@type table<string, vm.global> local allGlobals = {} @@ -161,6 +164,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local global = vm.declareGlobal('variable', name, uri) global:addSet(uri, source) source._globalNode = global @@ -169,6 +175,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local global = vm.declareGlobal('variable', name, uri) global:addGet(uri, source) source._globalNode = global @@ -271,6 +280,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local class = vm.declareGlobal('type', name, uri) class:addSet(uri, source) source._globalNode = class @@ -293,6 +305,9 @@ local compilerGlobalSwitch = util.switch() : call(function (source) local uri = guide.getUri(source) local name = guide.getKeyName(source) + if not name then + return + end local alias = vm.declareGlobal('type', name, uri) alias:addSet(uri, source) source._globalNode = alias @@ -305,10 +320,51 @@ local compilerGlobalSwitch = util.switch() source.extends._generic = vm.createGeneric(source.extends, source._sign) end end) + : case 'doc.enum' + : call(function (source) + local uri = guide.getUri(source) + local name = guide.getKeyName(source) + if not name then + return + end + local enum = vm.declareGlobal('type', name, uri) + enum:addSet(uri, source) + source._globalNode = enum + + local tbl = source.bindSource + if not tbl then + return + end + source._enums = {} + for _, field in ipairs(tbl) do + if field.type == 'tablefield' + or field.type == 'tableindex' then + if not field.value then + goto CONTINUE + end + local key = guide.getKeyName(field) + if not key then + goto CONTINUE + end + if field.value.type == 'integer' + or field.value.type == 'string' then + source._enums[#source._enums+1] = field.value[1] + end + if field.value.type == 'binary' + or field.value.type == 'unary' then + source._enums[#source._enums+1] = vm.getNumber(field.value) + end + ::CONTINUE:: + end + end + end) : case 'doc.type.name' : call(function (source) local uri = guide.getUri(source) local name = source[1] + if name == '_' then + return + end local type = vm.declareGlobal('type', name, uri) type:addGet(uri, source) source._globalNode = type @@ -446,7 +502,7 @@ local function compileSelf(source) if not node then return end - local fields = vm.getLocalFields(source) + local fields = vm.getLocalFields(source, false) if not fields then return end @@ -491,6 +547,7 @@ local function compileAst(source) 'doc.alias', 'doc.type.name', 'doc.extends.name', + 'doc.enum', }, function (src) compileObject(src) end) @@ -530,9 +587,11 @@ for uri in files.eachFile() do end end +---@async files.watch(function (ev, uri) if ev == 'update' then dropUri(uri) + ws.awaitReady(uri) local state = files.getState(uri) if state then compileAst(state.ast) diff --git a/script/vm/infer.lua b/script/vm/infer.lua index fabc9828..263b2500 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -8,10 +8,9 @@ local vm = require 'vm.vm' ---@field views table<string, boolean> ---@field cachedView? string ---@field node? vm.node ----@field uri? uri +---@field _drop table local mt = {} mt.__index = mt -mt._hasNumber = false mt._hasTable = false mt._hasClass = false mt._hasFunctionDef = false @@ -21,6 +20,8 @@ mt._isLocal = false vm.NULL = setmetatable({}, mt) +local LOCK = {} + local inferSorted = { ['boolean'] = - 100, ['string'] = - 99, @@ -43,14 +44,13 @@ local viewNodeSwitch = util.switch() end) : case 'number' : call(function (source, infer) - infer._hasNumber = true return source.type end) : case 'table' - : call(function (source, infer) + : call(function (source, infer, uri) if source.type == 'table' then if #source == 1 and source[1].type == 'varargs' then - local node = vm.getInfer(source[1]):view() + local node = vm.getInfer(source[1]):view(uri) return ('%s[]'):format(node) end end @@ -76,19 +76,18 @@ local viewNodeSwitch = util.switch() : case 'global' : call(function (source, infer) if source.cate == 'type' then - infer._hasClass = true - if source.name == 'number' then - infer._hasNumber = true + if not guide.isBasicType(source.name) then + infer._hasClass = true end return source.name end end) : case 'doc.type.name' - : call(function (source, infer) + : call(function (source, infer, uri) if source.signs then local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = vm.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view(uri) end return ('%s<%s>'):format(source[1], table.concat(buf, ', ')) else @@ -96,34 +95,68 @@ local viewNodeSwitch = util.switch() end end) : case 'generic' - : call(function (source, infer) - return vm.getInfer(source.proto):view() + : call(function (source, infer, uri) + return vm.getInfer(source.proto):view(uri) end) : case 'doc.generic.name' : call(function (source, infer) return ('<%s>'):format(source[1]) end) : case 'doc.type.array' - : call(function (source, infer) + : call(function (source, infer, uri) infer._hasClass = true - local view = vm.getInfer(source.node):view() + local view = vm.getInfer(source.node):view(uri) if source.node.type == 'doc.type' then view = '(' .. view .. ')' end return view .. '[]' end) : case 'doc.type.sign' - : call(function (source, infer) + : call(function (source, infer, uri) infer._hasClass = true local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = vm.getInfer(sign):view() + buf[i] = vm.getInfer(sign):view(uri) + end + if infer._drop then + local node = vm.compileNode(source) + for c in node:eachObject() do + if guide.isLiteral(c) then + infer._drop[c] = true + end + end end return ('%s<%s>'):format(source.node[1], table.concat(buf, ', ')) end) : case 'doc.type.table' - : call(function (source, infer) - infer._hasTable = true + : call(function (source, infer, uri) + if #source.fields == 0 then + infer._hasTable = true + return + end + if infer._drop and infer._drop[source] then + infer._hasTable = true + return + end + infer._hasClass = true + local buf = {} + buf[#buf+1] = '{ ' + for i, field in ipairs(source.fields) do + if i > 1 then + buf[#buf+1] = ', ' + end + local key = field.name + if key.type == 'doc.type' then + buf[#buf+1] = ('[%s]: '):format(vm.getInfer(key):view(uri)) + elseif type(key[1]) == 'string' then + buf[#buf+1] = key[1] .. ': ' + else + buf[#buf+1] = ('[%q]: '):format(key[1]) + end + buf[#buf+1] = vm.getInfer(field.extends):view(uri) + end + buf[#buf+1] = ' }' + return table.concat(buf) end) : case 'doc.type.string' : call(function (source, infer) @@ -134,8 +167,12 @@ local viewNodeSwitch = util.switch() : call(function (source, infer) return ('%q'):format(source[1]) end) - : case 'doc.type.function' + : case 'doc.type.code' : call(function (source, infer) + return ('`%s`'):format(source[1]) + end) + : case 'doc.type.function' + : call(function (source, infer, uri) infer._hasDocFunction = true local args = {} local rets = {} @@ -148,31 +185,53 @@ local viewNodeSwitch = util.switch() argNode = argNode:copy() argNode:removeOptional() end - args[i] = string.format('%s%s: %s' + args[i] = string.format('%s%s%s%s' , arg.name[1] , isOptional and '?' or '' - , vm.getInfer(argNode):view() + , arg.name[1] == '...' and '' or ': ' + , vm.getInfer(argNode):view(uri) ) end if #args > 0 then argView = table.concat(args, ', ') end + local needReturnParen for i, ret in ipairs(source.returns) do - rets[i] = vm.getInfer(ret):view() + local retType = vm.getInfer(ret):view(uri) + if ret.name then + if ret.name[1] == '...' then + rets[i] = ('%s%s'):format(ret.name[1], retType) + else + needReturnParen = true + rets[i] = ('%s: %s'):format(ret.name[1], retType) + end + else + rets[i] = retType + end end if #rets > 0 then - regView = ':' .. table.concat(rets, ', ') + if needReturnParen then + regView = (':(%s)'):format(table.concat(rets, ', ')) + else + regView = (':%s'):format(table.concat(rets, ', ')) + end end return ('fun(%s)%s'):format(argView, regView) end) ----@param source parser.object | vm.node +---@class vm.node +---@field lastInfer? vm.infer + +---@param source vm.object | vm.node ---@return vm.infer function vm.getInfer(source) + ---@type vm.node local node if source.type == 'vm.node' then + ---@cast source vm.node node = source else + ---@cast source vm.object node = vm.compileNode(source) end if node.lastInfer then @@ -180,7 +239,7 @@ function vm.getInfer(source) end local infer = setmetatable({ node = node, - uri = source.type ~= 'vm.node' and guide.getUri(source), + _drop = {}, }, mt) node.lastInfer = infer @@ -188,9 +247,6 @@ function vm.getInfer(source) end function mt:_trim() - if self._hasNumber then - self.views['integer'] = nil - end if self._hasDocFunction then if self._hasFunctionDef then for view in pairs(self.views) do @@ -205,6 +261,13 @@ function mt:_trim() if self._hasTable and not self._hasClass then self.views['table'] = true end + if self.views['number'] then + self.views['integer'] = nil + end + if self.views['boolean'] then + self.views['true'] = nil + self.views['false'] = nil + end end ---@param uri uri @@ -214,46 +277,86 @@ function mt:_eraseAlias(uri) local expandAlias = config.get(uri, 'Lua.hover.expandAlias') for n in self.node:eachObject() do if n.type == 'global' and n.cate == 'type' then + if LOCK[n.name] then + goto CONTINUE + end + LOCK[n.name] = true for _, set in ipairs(n:getSets(uri)) do if set.type == 'doc.alias' then if expandAlias then drop[n.name] = true + local newInfer = {} + for _, ext in ipairs(set.extends.types) do + viewNodeSwitch(ext.type, ext, newInfer, uri) + end + if newInfer._hasTable then + self.views['table'] = true + end else for _, ext in ipairs(set.extends.types) do - local view = viewNodeSwitch(ext.type, ext, {}) + local view = viewNodeSwitch(ext.type, ext, {}, uri) if view and view ~= n.name then drop[view] = true end end end end + if set.type == 'doc.class' then + if set.extends then + for _, ext in ipairs(set.extends) do + if ext.type == 'doc.extends.name' then + local view = ext[1] + drop[view] = true + end + end + end + end end + LOCK[n.name] = nil + ::CONTINUE:: end end return drop end +---@param uri uri ---@param tp string ---@return boolean -function mt:hasType(tp) - self:_computeViews() +function mt:hasType(uri, tp) + self:_computeViews(uri) return self.views[tp] == true end +---@param uri uri +function mt:hasUnknown(uri) + self:_computeViews(uri) + return not next(self.views) + or self.views['unknown'] == true +end + +---@param uri uri +function mt:hasAny(uri) + self:_computeViews(uri) + return self.views['any'] == true +end + +---@param uri uri ---@return boolean -function mt:hasClass() - self:_computeViews() +function mt:hasClass(uri) + self:_computeViews(uri) return self._hasClass == true end +---@param uri uri ---@return boolean -function mt:hasFunction() - self:_computeViews() +function mt:hasFunction(uri) + self:_computeViews(uri) return self.views['function'] == true or self._hasDocFunction == true end -function mt:_computeViews() +---@param uri uri +function mt:_computeViews(uri) if self.views then return end @@ -261,7 +364,7 @@ function mt:_computeViews() self.views = {} for n in self.node:eachObject() do - local view = viewNodeSwitch(n.type, n, self) + local view = viewNodeSwitch(n.type, n, self, uri) if view then self.views[view] = true end @@ -270,11 +373,11 @@ function mt:_computeViews() self:_trim() end +---@param uri uri ---@param default? string ----@param uri? uri ---@return string -function mt:view(default, uri) - self:_computeViews() +function mt:view(uri, default) + self:_computeViews(uri) if self.views['any'] then return 'any' @@ -282,7 +385,7 @@ function mt:view(default, uri) local drop if self._hasClass then - drop = self:_eraseAlias(uri or self.uri) + drop = self:_eraseAlias(uri) end local array = {} @@ -302,7 +405,7 @@ function mt:view(default, uri) end) local max = #array - local limit = config.get(uri or self.uri, 'Lua.hover.enumsLimit') + local limit = config.get(uri, 'Lua.hover.enumsLimit') local view if #array == 0 then @@ -329,8 +432,9 @@ function mt:view(default, uri) return view end -function mt:eachView() - self:_computeViews() +---@param uri uri +function mt:eachView(uri) + self:_computeViews(uri) return next, self.views end @@ -346,7 +450,6 @@ function mt:merge(other) local infer = setmetatable({ node = vm.createNode(self.node, other.node), - uri = self.uri, }, mt) return infer @@ -365,7 +468,7 @@ function mt:viewLiterals() or n.type == 'integer' or n.type == 'boolean' then local literal = util.viewLiteral(n[1]) - if not mark[literal] then + if literal and not mark[literal] then literals[#literals+1] = literal mark[literal] = true end @@ -374,7 +477,14 @@ function mt:viewLiterals() if #literals == 0 then return nil end - table.sort(literals) + table.sort(literals, function (a, b) + local sa = inferSorted[a] or 0 + local sb = inferSorted[b] or 0 + if sa == sb then + return a < b + end + return sa < sb + end) return table.concat(literals, '|') end @@ -401,8 +511,9 @@ function mt:viewClass() return table.concat(class, '|') end ----@param source parser.object +---@param source vm.node.object +---@param uri uri ---@return string? -function vm.viewObject(source) - return viewNodeSwitch(source.type, source, {}) +function vm.viewObject(source, uri) + return viewNodeSwitch(source.type, source, {}, uri) end diff --git a/script/vm/init.lua b/script/vm/init.lua index f5003c11..87c046f0 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -1,6 +1,6 @@ local vm = require 'vm.vm' ----@alias vm.object parser.object | vm.global | vm.generic +---@alias vm.object parser.object | vm.generic require 'vm.compiler' require 'vm.value' @@ -17,4 +17,6 @@ require 'vm.generic' require 'vm.sign' require 'vm.local-id' require 'vm.global' +require 'vm.function' +require 'vm.operator' return vm diff --git a/script/vm/local-id.lua b/script/vm/local-id.lua index 80c68769..9168d680 100644 --- a/script/vm/local-id.lua +++ b/script/vm/local-id.lua @@ -4,8 +4,8 @@ local guide = require 'parser.guide' local vm = require 'vm.vm' ---@class parser.object ----@field _localID string ----@field _localIDs table<string, parser.object[]> +---@field _localID string|false +---@field _localIDs table<string, { sets: parser.object[], gets: parser.object[] }> local compileLocalID, getLocal @@ -21,6 +21,7 @@ local compileSwitch = util.switch() compileLocalID(ref) end end) + : case 'setlocal' : case 'getlocal' : call(function (source) source._localID = ('%d'):format(source.node.start) @@ -114,10 +115,19 @@ end function vm.insertLocalID(id, source) local root = guide.getRoot(source) if not root._localIDs then - root._localIDs = util.multiTable(2) + root._localIDs = util.multiTable(2, function () + return { + sets = {}, + gets = {}, + } + end) end local sources = root._localIDs[id] - sources[#sources+1] = source + if guide.isSet(source) then + sources.sets[#sources.sets+1] = source + else + sources.gets[#sources.gets+1] = source + end end function compileLocalID(source) @@ -137,7 +147,7 @@ function compileLocalID(source) end ---@param source parser.object ----@return string? +---@return string|false function vm.getLocalID(source) if source._localID ~= nil then return source._localID @@ -154,7 +164,28 @@ end ---@param source parser.object ---@param key? string ---@return parser.object[]? -function vm.getLocalSources(source, key) +function vm.getLocalSourcesSets(source, key) + local id = vm.getLocalID(source) + if not id then + return nil + end + local root = guide.getRoot(source) + if not root._localIDs then + return nil + end + if key then + if type(key) ~= 'string' then + return nil + end + id = id .. vm.ID_SPLITE .. key + end + return root._localIDs[id].sets +end + +---@param source parser.object +---@param key? string +---@return parser.object[]? +function vm.getLocalSourcesGets(source, key) local id = vm.getLocalID(source) if not id then return nil @@ -169,12 +200,13 @@ function vm.getLocalSources(source, key) end id = id .. vm.ID_SPLITE .. key end - return root._localIDs[id] + return root._localIDs[id].gets end ---@param source parser.object ----@return parser.object[] -function vm.getLocalFields(source) +---@param includeGets boolean +---@return parser.object[]? +function vm.getLocalFields(source, includeGets) local id = vm.getLocalID(source) if not id then return nil @@ -192,9 +224,14 @@ function vm.getLocalFields(source) and lid:sub(#id + 1, #id + 1) == vm.ID_SPLITE -- only one field and not lid:find(vm.ID_SPLITE, #id + 2) then - for _, src in ipairs(sources) do + for _, src in ipairs(sources.sets) do fields[#fields+1] = src end + if includeGets then + for _, src in ipairs(sources.gets) do + fields[#fields+1] = src + end + end end end local cost = os.clock() - clock diff --git a/script/vm/node.lua b/script/vm/node.lua index e76542aa..49207b13 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -2,28 +2,32 @@ local files = require 'files' ---@class vm local vm = require 'vm.vm' local ws = require 'workspace.workspace' +local guide = require 'parser.guide' ---@type table<vm.object, vm.node> vm.nodeCache = {} +---@alias vm.node.object vm.object | vm.global + ---@class vm.node ----@field [integer] vm.object +---@field [integer] vm.node.object +---@field [vm.node.object] true local mt = {} mt.__index = mt mt.id = 0 mt.type = 'vm.node' mt.optional = nil -mt.lastInfer = nil mt.data = nil ----@param node vm.node | vm.object +---@param node vm.node | vm.node.object +---@return vm.node function mt:merge(node) if not node then - return + return self end if node.type == 'vm.node' then if node == self then - return + return self end if node:isOptional() then self.optional = true @@ -35,11 +39,13 @@ function mt:merge(node) end end else + ---@cast node -vm.node if not self[node] then self[node] = true self[#self+1] = node end end + return self end ---@return boolean @@ -56,7 +62,7 @@ function mt:clear() end ---@param n integer ----@return vm.object? +---@return vm.node.object? function mt:get(n) return self[n] end @@ -68,6 +74,7 @@ function mt:setData(k, v) self.data[k] = v end +---@return any function mt:getData(k) if not self.data then return nil @@ -81,6 +88,7 @@ end function mt:removeOptional() self:remove 'nil' + return self end ---@return boolean @@ -106,6 +114,19 @@ function mt:hasFalsy() end ---@return boolean +function mt:hasKnownType() + for _, c in ipairs(self) do + if c.type == 'global' and c.cate == 'type' then + return true + end + if guide.isLiteral(c) then + return true + end + end + return false +end + +---@return boolean function mt:isNullable() if self.optional then return true @@ -116,7 +137,8 @@ function mt:isNullable() for _, c in ipairs(self) do if c.type == 'nil' or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') - or (c.type == 'global' and c.cate == 'type' and c.name == 'any') then + or (c.type == 'global' and c.cate == 'type' and c.name == 'any') + or (c.type == 'global' and c.cate == 'type' and c.name == '...') then return true end end @@ -140,18 +162,25 @@ function mt:setTruthy() self[c] = nil goto CONTINUE end - if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean') - or (c.type == 'boolean' or c.type == 'doc.type.boolean') then + if c.type == 'global' and c.cate == 'type' and c.name == 'boolean' then hasBoolean = true table.remove(self, index) self[c] = nil goto CONTINUE end + if c.type == 'boolean' or c.type == 'doc.type.boolean' then + if c[1] == false then + table.remove(self, index) + self[c] = nil + goto CONTINUE + end + end ::CONTINUE:: end if hasBoolean then - self[#self+1] = vm.declareGlobal('type', 'true') + self:merge(vm.declareGlobal('type', 'true')) end + return self end ---@return vm.node @@ -165,21 +194,39 @@ function mt:setFalsy() if c.type == 'nil' or (c.type == 'global' and c.cate == 'type' and c.name == 'nil') or (c.type == 'global' and c.cate == 'type' and c.name == 'false') - or (c.type == 'boolean' and c[1] == true) - or (c.type == 'doc.type.boolean' and c[1] == true) then + or (c.type == 'boolean' and c[1] == false) + or (c.type == 'doc.type.boolean' and c[1] == false) then goto CONTINUE end - if (c.type == 'global' and c.cate == 'type' and c.name == 'boolean') - or (c.type == 'boolean' or c.type == 'doc.type.boolean') then + if c.type == 'global' and c.cate == 'type' and c.name == 'boolean' then hasBoolean = true table.remove(self, index) self[c] = nil + goto CONTINUE + end + if c.type == 'boolean' or c.type == 'doc.type.boolean' then + if c[1] == true then + table.remove(self, index) + self[c] = nil + goto CONTINUE + end + end + if (c.type == 'global' and c.cate == 'type') then + table.remove(self, index) + self[c] = nil + goto CONTINUE + end + if guide.isLiteral(c) then + table.remove(self, index) + self[c] = nil + goto CONTINUE end ::CONTINUE:: end if hasBoolean then - self[#self+1] = vm.declareGlobal('type', 'false') + self:merge(vm.declareGlobal('type', 'false')) end + return self end ---@param name string @@ -193,25 +240,144 @@ function mt:remove(name) or (c.type == name) or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer')) or (c.type == 'doc.type.boolean' and name == 'boolean') + or (c.type == 'doc.type.boolean' and name == 'true' and c[1] == true) + or (c.type == 'doc.type.boolean' and name == 'false' and c[1] == false) or (c.type == 'doc.type.table' and name == 'table') or (c.type == 'doc.type.array' and name == 'table') + or (c.type == 'doc.type.sign' and name == c.node[1]) or (c.type == 'doc.type.function' and name == 'function') then table.remove(self, index) self[c] = nil end end + return self +end + +---@param name string +function mt:narrow(name) + if name ~= 'nil' and self.optional == true then + self.optional = nil + end + for index = #self, 1, -1 do + local c = self[index] + if (c.type == name) + or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer')) + or (c.type == 'doc.type.boolean' and name == 'boolean') + or (c.type == 'doc.type.table' and name == 'table') + or (c.type == 'doc.type.array' and name == 'table') + or (c.type == 'doc.type.sign' and name == c.node[1]) + or (c.type == 'doc.type.function' and name == 'function') then + goto CONTINUE + end + if c.type == 'global' and c.cate == 'type' then + if (c.name == name) + or (c.name == 'integer' and name == 'number') then + goto CONTINUE + end + end + table.remove(self, index) + self[c] = nil + ::CONTINUE:: + end + if #self == 0 then + self[#self+1] = vm.getGlobal('type', name) + end + return self +end + +---@param obj vm.object +function mt:removeObject(obj) + for index, c in ipairs(self) do + if c == obj then + table.remove(self, index) + self[c] = nil + return + end + end end ---@param node vm.node function mt:removeNode(node) for _, c in ipairs(node) do if c.type == 'global' and c.cate == 'type' then + ---@cast c vm.global self:remove(c.name) + elseif c.type == 'nil' then + self:remove 'nil' + elseif c.type == 'boolean' + or c.type == 'doc.type.boolean' then + if c[1] == true then + self:remove 'true' + else + self:remove 'false' + end + else + ---@cast c -vm.global + self:removeObject(c) end end end ----@return fun():vm.object +---@param name string +---@return boolean +function mt:hasType(name) + for _, c in ipairs(self) do + if c.type == 'global' and c.cate == 'type' and c.name == name then + return true + end + end + return false +end + +---@param name string +---@return boolean +function mt:hasName(name) + if name == 'nil' and self.optional == true then + return true + end + for _, c in ipairs(self) do + if c.type == 'global' and c.cate == 'type' and c.name == name then + return true + end + if c.type == name then + return true + end + -- TODO + end + return false +end + +---@return vm.node +function mt:asTable() + self.optional = nil + for index = #self, 1, -1 do + local c = self[index] + if c.type == 'table' + or c.type == 'doc.type.table' + or c.type == 'doc.type.array' then + goto CONTINUE + end + if c.type == 'doc.type.sign' then + if c.node[1] == 'table' + or not guide.isBasicType(c.node[1]) then + goto CONTINUE + end + end + if c.type == 'global' and c.cate == 'type' then + ---@cast c vm.global + if c.name == 'table' + or not guide.isBasicType(c.name) then + goto CONTINUE + end + end + table.remove(self, index) + self[c] = nil + ::CONTINUE:: + end + return self +end + +---@return fun():vm.node.object function mt:eachObject() local i = 0 return function () @@ -226,8 +392,9 @@ function mt:copy() end ---@param source vm.object ----@param node vm.node | vm.object +---@param node vm.node | vm.node.object ---@param cover? boolean +---@return vm.node function vm.setNode(source, node, cover) if not node then if TEST then @@ -236,23 +403,23 @@ function vm.setNode(source, node, cover) log.error('Can not set nil node') end end - if source.type == 'global' then - error('Can not set node to global') - end if cover then + ---@cast node vm.node vm.nodeCache[source] = node - return + return node end local me = vm.nodeCache[source] if me then me:merge(node) else if node.type == 'vm.node' then - vm.nodeCache[source] = node:copy() + me = node:copy() else - vm.nodeCache[source] = vm.createNode(node) + me = vm.createNode(node) end + vm.nodeCache[source] = me end + return me end ---@param source vm.object @@ -291,8 +458,8 @@ end local ID = 0 ----@param a? vm.node | vm.object ----@param b? vm.node | vm.object +---@param a? vm.node | vm.node.object +---@param b? vm.node | vm.node.object ---@return vm.node function vm.createNode(a, b) ID = ID + 1 @@ -315,3 +482,9 @@ files.watch(function (ev, uri) end end end) + +ws.watch(function (ev, uri) + if ev == 'reload' then + vm.clearNodeCache() + end +end) diff --git a/script/vm/operator.lua b/script/vm/operator.lua new file mode 100644 index 00000000..9dea01c1 --- /dev/null +++ b/script/vm/operator.lua @@ -0,0 +1,368 @@ +---@class vm +local vm = require 'vm.vm' +local util = require 'utility' +local guide = require 'parser.guide' +local config = require 'config' + +vm.UNARY_OP = { + 'unm', + 'bnot', + 'len', +} +vm.BINARY_OP = { + 'add', + 'sub', + 'mul', + 'div', + 'mod', + 'pow', + 'idiv', + 'band', + 'bor', + 'bxor', + 'shl', + 'shr', + 'concat', +} +vm.OTHER_OP = { + 'call', +} + +local unaryMap = { + ['-'] = 'unm', + ['~'] = 'bnot', + ['#'] = 'len', +} + +local binaryMap = { + ['+'] = 'add', + ['-'] = 'sub', + ['*'] = 'mul', + ['/'] = 'div', + ['%'] = 'mod', + ['^'] = 'pow', + ['//'] = 'idiv', + ['&'] = 'band', + ['|'] = 'bor', + ['~'] = 'bxor', + ['<<'] = 'shl', + ['>>'] = 'shr', + ['..'] = 'concat', +} + +local otherMap = { + ['()'] = 'call', +} + +vm.OP_UNARY_MAP = util.revertMap(unaryMap) +vm.OP_BINARY_MAP = util.revertMap(binaryMap) +vm.OP_OTHER_MAP = util.revertMap(otherMap) + +---@param operators parser.object[] +---@param op string +---@param value? parser.object +---@param result? vm.node +---@return vm.node? +local function checkOperators(operators, op, value, result) + for _, operator in ipairs(operators) do + if operator.op[1] ~= op + or not operator.extends then + goto CONTINUE + end + if value and operator.exp then + local valueNode = vm.compileNode(value) + local expNode = vm.compileNode(operator.exp) + local uri = guide.getUri(operator) + if not vm.isSubType(uri, valueNode, expNode) then + goto CONTINUE + end + end + if not result then + result = vm.createNode() + end + result:merge(vm.compileNode(operator.extends)) + ::CONTINUE:: + end + return result +end + +---@param op string +---@param exp parser.object +---@param value? parser.object +---@return vm.node? +function vm.runOperator(op, exp, value) + local uri = guide.getUri(exp) + local node = vm.compileNode(exp) + local result + for c in node:eachObject() do + if c.type == 'string' + or c.type == 'doc.type.string' then + c = vm.declareGlobal('type', 'string') + end + if c.type == 'global' and c.cate == 'type' then + ---@cast c vm.global + for _, set in ipairs(c:getSets(uri)) do + if set.operators and #set.operators > 0 then + result = checkOperators(set.operators, op, value, result) + end + end + end + end + return result +end + +vm.unarySwich = util.switch() + : case 'not' + : call(function (source) + local result = vm.testCondition(source[1]) + if result == nil then + vm.setNode(source, vm.declareGlobal('type', 'boolean')) + else + vm.setNode(source, { + type = 'boolean', + start = source.start, + finish = source.finish, + parent = source, + [1] = not result, + }) + end + end) + : case '#' + : call(function (source) + local node = vm.runOperator('len', source[1]) + vm.setNode(source, node or vm.declareGlobal('type', 'integer')) + end) + : case '-' + : call(function (source) + local v = vm.getNumber(source[1]) + if v == nil then + local uri = guide.getUri(source) + local infer = vm.getInfer(source[1]) + if infer:hasType(uri, 'integer') then + vm.setNode(source, vm.declareGlobal('type', 'integer')) + elseif infer:hasType(uri, 'number') then + vm.setNode(source, vm.declareGlobal('type', 'number')) + else + local node = vm.runOperator('unm', source[1]) + vm.setNode(source, node or vm.declareGlobal('type', 'number')) + end + else + vm.setNode(source, { + type = 'number', + start = source.start, + finish = source.finish, + parent = source, + [1] = -v, + }) + end + end) + : case '~' + : call(function (source) + local v = vm.getInteger(source[1]) + if v == nil then + local node = vm.runOperator('bnot', source[1]) + vm.setNode(source, node or vm.declareGlobal('type', 'integer')) + else + vm.setNode(source, { + type = 'integer', + start = source.start, + finish = source.finish, + parent = source, + [1] = ~v, + }) + end + end) + +vm.binarySwitch = util.switch() + : case 'and' + : call(function (source) + local node1 = vm.compileNode(source[1]) + local node2 = vm.compileNode(source[2]) + local r1 = vm.testCondition(source[1]) + if r1 == true then + vm.setNode(source, node2) + elseif r1 == false then + vm.setNode(source, node1) + else + local node = node1:copy():setFalsy():merge(node2) + vm.setNode(source, node) + end + end) + : case 'or' + : call(function (source) + local node1 = vm.compileNode(source[1]) + local node2 = vm.compileNode(source[2]) + local r1 = vm.testCondition(source[1]) + if r1 == true then + vm.setNode(source, node1) + elseif r1 == false then + vm.setNode(source, node2) + else + local node = node1:copy():setTruthy():merge(node2) + vm.setNode(source, node) + end + end) + : case '==' + : case '~=' + : call(function (source) + local result = vm.equal(source[1], source[2]) + if result == nil then + vm.setNode(source, vm.declareGlobal('type', 'boolean')) + else + if source.op.type == '~=' then + result = not result + end + vm.setNode(source, { + type = 'boolean', + start = source.start, + finish = source.finish, + parent = source, + [1] = result, + }) + end + end) + : case '<<' + : case '>>' + : case '&' + : case '|' + : case '~' + : call(function (source) + local a = vm.getInteger(source[1]) + local b = vm.getInteger(source[2]) + local op = source.op.type + if a and b then + local result = op == '<<' and a << b + or op == '>>' and a >> b + or op == '&' and a & b + or op == '|' and a | b + or op == '~' and a ~ b + vm.setNode(source, { + type = 'integer', + start = source.start, + finish = source.finish, + parent = source, + [1] = result, + }) + else + local node = vm.runOperator(binaryMap[op], source[1], source[2]) + vm.setNode(source, node or vm.declareGlobal('type', 'integer')) + end + end) + : case '+' + : case '-' + : case '*' + : case '/' + : case '%' + : case '//' + : case '^' + : call(function (source) + local a = vm.getNumber(source[1]) + local b = vm.getNumber(source[2]) + local op = source.op.type + local zero = b == 0 + and ( op == '%' + or op == '/' + or op == '//' + ) + if a and b and not zero then + local result = op == '+' and a + b + or op == '-' and a - b + or op == '*' and a * b + or op == '/' and a / b + or op == '%' and a % b + or op == '//' and a // b + or op == '^' and a ^ b + vm.setNode(source, { + type = math.type(result) == 'integer' and 'integer' or 'number', + start = source.start, + finish = source.finish, + parent = source, + [1] = result, + }) + else + local node = vm.runOperator(binaryMap[op], source[1], source[2]) + if node then + vm.setNode(source, node) + return + end + if op == '+' + or op == '-' + or op == '*' + or op == '//' + or op == '%' then + local uri = guide.getUri(source) + local infer1 = vm.getInfer(source[1]) + local infer2 = vm.getInfer(source[2]) + if infer1:hasType(uri, 'integer') + or infer2:hasType(uri, 'integer') then + if not infer1:hasType(uri, 'number') + and not infer2:hasType(uri, 'number') then + vm.setNode(source, vm.declareGlobal('type', 'integer')) + return + end + end + end + vm.setNode(source, node or vm.declareGlobal('type', 'number')) + end + end) + : case '..' + : call(function (source) + local a = vm.getString(source[1]) + or vm.getNumber(source[1]) + local b = vm.getString(source[2]) + or vm.getNumber(source[2]) + if a and b then + if type(a) == 'number' or type(b) == 'number' then + local uri = guide.getUri(source) + local version = config.get(uri, 'Lua.runtime.version') + if math.tointeger(a) and math.type(a) == 'float' then + if version == 'Lua 5.3' or version == 'Lua 5.4' then + a = ('%.1f'):format(a) + else + a = ('%.0f'):format(a) + end + end + if math.tointeger(b) and math.type(b) == 'float' then + if version == 'Lua 5.3' or version == 'Lua 5.4' then + b = ('%.1f'):format(b) + else + b = ('%.0f'):format(b) + end + end + end + vm.setNode(source, { + type = 'string', + start = source.start, + finish = source.finish, + parent = source, + [1] = a .. b, + }) + else + local node = vm.runOperator(binaryMap[source.op.type], source[1], source[2]) + vm.setNode(source, node or vm.declareGlobal('type', 'string')) + end + end) + : case '>' + : case '<' + : case '>=' + : case '<=' + : call(function (source) + local a = vm.getNumber(source[1]) + local b = vm.getNumber(source[2]) + if a and b then + local op = source.op.type + local result = op == '>' and a > b + or op == '<' and a < b + or op == '>=' and a >= b + or op == '<=' and a <= b + vm.setNode(source, { + type = 'boolean', + start = source.start, + finish = source.finish, + parent = source, + [1] =result, + }) + else + vm.setNode(source, vm.declareGlobal('type', 'boolean')) + end + end) diff --git a/script/vm/ref.lua b/script/vm/ref.lua index 545c294a..0135d11f 100644 --- a/script/vm/ref.lua +++ b/script/vm/ref.lua @@ -9,58 +9,7 @@ local lang = require 'language' local simpleSwitch -local function searchGetLocal(source, node, pushResult) - local key = guide.getKeyName(source) - for _, ref in ipairs(node.node.ref) do - if ref.type == 'getlocal' - and ref.next - and guide.getKeyName(ref.next) == key then - pushResult(ref.next) - end - end -end - simpleSwitch = util.switch() - : case 'local' - : call(function (source, pushResult) - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' - or ref.type == 'getlocal' then - pushResult(ref) - end - end - end - end) - : case 'getlocal' - : case 'setlocal' - : call(function (source, pushResult) - simpleSwitch('local', source.node, pushResult) - end) - : case 'field' - : call(function (source, pushResult) - local parent = source.parent - if parent.type ~= 'tablefield' then - simpleSwitch(parent.type, parent, pushResult) - end - end) - : case 'setfield' - : case 'getfield' - : call(function (source, pushResult) - local node = source.node - if node.type == 'getlocal' then - searchGetLocal(source, node, pushResult) - return - end - end) - : case 'getindex' - : case 'setindex' - : call(function (source, pushResult) - local node = source.node - if node.type == 'getlocal' then - searchGetLocal(source, node, pushResult) - end - end) : case 'goto' : call(function (source, pushResult) if source.node then @@ -142,21 +91,21 @@ local function searchField(source, pushResult, defMap, fileNotify) return end ---@async - guide.eachSourceType(state.ast, 'getfield', function (src) + guide.eachSourceTypes(state.ast, {'getfield', 'setfield'}, function (src) if src.field and src.field[1] == key then checkDef(src) await.delay() end end) ---@async - guide.eachSourceType(state.ast, 'getmethod', function (src) + guide.eachSourceTypes(state.ast, {'getmethod', 'setmethod'}, function (src) if src.method and src.method[1] == key then checkDef(src) await.delay() end end) ---@async - guide.eachSourceType(state.ast, 'getindex', function (src) + guide.eachSourceTypes(state.ast, {'getindex', 'setindex'}, function (src) if src.index and src.index.type == 'string' and src.index[1] == key then checkDef(src) await.delay() @@ -240,19 +189,24 @@ end ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchByLocalID(source, pushResult) - local idSources = vm.getLocalSources(source) - if not idSources then - return + local sourceSets = vm.getLocalSourcesSets(source) + if sourceSets then + for _, src in ipairs(sourceSets) do + pushResult(src) + end end - for _, src in ipairs(idSources) do - pushResult(src) + local sourceGets = vm.getLocalSourcesGets(source) + if sourceGets then + for _, src in ipairs(sourceGets) do + pushResult(src) + end end end ---@async ---@param source parser.object ---@param pushResult fun(src: parser.object) ----@param fileNotify fun(uri: uri): boolean +---@param fileNotify? fun(uri: uri): boolean function searchByParentNode(source, pushResult, defMap, fileNotify) nodeSwitch(source.type, source, pushResult, defMap, fileNotify) end @@ -279,10 +233,22 @@ local function searchByDef(source, pushResult) defMap[source] = true return defMap end - local defs = vm.getDefs(source) - for _, def in ipairs(defs) do - pushResult(def) - defMap[def] = true + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + defMap[source] = true + if guide.isSet(source) then + local defs = vm.getDefs(source) + for _, def in ipairs(defs) do + pushResult(def) + end + else + local defs = vm.getDefs(source) + for _, def in ipairs(defs) do + pushResult(def) + defMap[def] = true + end end return defMap end diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 9fe0f172..2f047983 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -2,257 +2,22 @@ local vm = require 'vm.vm' local guide = require 'parser.guide' +---@alias vm.runner.callback fun(src: parser.object, node?: vm.node) + ---@class vm.runner ----@field loc parser.object ----@field mainBlock parser.object ----@field blocks table<parser.object, true> ----@field steps vm.runner.step[] +---@field _loc parser.object +---@field _casts parser.object[] +---@field _callback vm.runner.callback +---@field _mark table +---@field _has table<parser.object, true> +---@field _main parser.object local mt = {} mt.__index = mt -mt.index = 1 - ----@class parser.object ----@field _casts parser.object[] - ----@class vm.runner.step ----@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast' ----@field pos integer ----@field order? integer ----@field node? vm.node ----@field object? parser.object ----@field name? string ----@field cast? parser.object ----@field tag? string ----@field copy? boolean ----@field new? boolean ----@field ref1? vm.runner.step ----@field ref2? vm.runner.step - ----@param filter parser.object ----@param outStep vm.runner.step ----@param blockStep vm.runner.step -function mt:_compileNarrowByFilter(filter, outStep, blockStep) - if not filter then - return - end - if filter.type == 'paren' then - if filter.exp then - self:_compileNarrowByFilter(filter.exp, outStep, blockStep) - end - return - end - if filter.type == 'unary' then - if not filter.op - or not filter[1] then - return - end - if filter.op.type == 'not' then - local exp = filter[1] - if exp.type == 'getlocal' and exp.node == self.loc then - self.steps[#self.steps+1] = { - type = 'falsy', - pos = filter.finish, - new = true, - } - self.steps[#self.steps+1] = { - type = 'truthy', - pos = filter.finish, - ref1 = outStep, - } - end - end - elseif filter.type == 'binary' then - if not filter.op - or not filter[1] - or not filter[2] then - return - end - if filter.op.type == 'and' then - local dummyStep = { - type = 'save', - copy = true, - ref1 = outStep, - pos = filter.start - 1, - } - self.steps[#self.steps+1] = dummyStep - self:_compileNarrowByFilter(filter[1], dummyStep, blockStep) - self:_compileNarrowByFilter(filter[2], dummyStep, blockStep) - end - if filter.op.type == 'or' then - self:_compileNarrowByFilter(filter[1], outStep, blockStep) - local dummyStep = { - type = 'push', - copy = true, - ref1 = outStep, - pos = filter.op.finish, - } - self.steps[#self.steps+1] = dummyStep - self:_compileNarrowByFilter(filter[2], outStep, dummyStep) - self.steps[#self.steps+1] = { - type = 'push', - tag = 'or reset', - ref1 = blockStep, - pos = filter.finish, - } - end - if filter.op.type == '==' - or filter.op.type == '~=' then - local loc, exp - for i = 1, 2 do - loc = filter[i] - if loc.type == 'getlocal' and loc.node == self.loc then - exp = filter[i % 2 + 1] - break - end - end - if not loc or not exp then - return - end - if guide.isLiteral(exp) then - if filter.op.type == '==' then - self.steps[#self.steps+1] = { - type = 'remove', - name = exp.type, - pos = filter.finish, - ref1 = outStep, - } - self.steps[#self.steps+1] = { - type = 'as', - name = exp.type, - pos = filter.finish, - new = true, - } - end - if filter.op.type == '~=' then - self.steps[#self.steps+1] = { - type = 'as', - name = exp.type, - pos = filter.finish, - ref1 = outStep, - } - self.steps[#self.steps+1] = { - type = 'remove', - name = exp.type, - pos = filter.finish, - new = true, - } - end - end - end - else - if filter.type == 'getlocal' and filter.node == self.loc then - self.steps[#self.steps+1] = { - type = 'truthy', - pos = filter.finish, - new = true, - } - self.steps[#self.steps+1] = { - type = 'falsy', - pos = filter.finish, - ref1 = outStep, - } - end - end -end - ----@param block parser.object -function mt:_compileBlock(block) - if self.blocks[block] then - return - end - self.blocks[block] = true - if block == self.mainBlock then - return - end - - local parentBlock = guide.getParentBlock(block) - self:_compileBlock(parentBlock) - - if block.type == 'if' then - ---@type vm.runner.step[] - local finals = {} - for i, childBlock in ipairs(block) do - local blockStep = { - type = 'save', - tag = 'block', - copy = true, - pos = childBlock.start, - } - local outStep = { - type = 'save', - tag = 'out', - copy = true, - pos = childBlock.start, - } - self.steps[#self.steps+1] = blockStep - self.steps[#self.steps+1] = outStep - self.steps[#self.steps+1] = { - type = 'push', - ref1 = blockStep, - pos = childBlock.start, - } - self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep) - if not childBlock.hasReturn - and not childBlock.hasGoTo - and not childBlock.hasBreak then - local finalStep = { - type = 'save', - pos = childBlock.finish, - tag = 'final #' .. i, - } - finals[#finals+1] = finalStep - self.steps[#self.steps+1] = finalStep - end - self.steps[#self.steps+1] = { - type = 'push', - tag = 'reset child', - ref1 = outStep, - pos = childBlock.finish, - } - end - self.steps[#self.steps+1] = { - type = 'push', - tag = 'reset if', - pos = block.finish, - copy = true, - } - for _, final in ipairs(finals) do - self.steps[#self.steps+1] = { - type = 'merge', - ref2 = final, - pos = block.finish, - } - end - end - - if block.type == 'function' - or block.type == 'while' - or block.type == 'loop' - or block.type == 'in' - or block.type == 'repeat' - or block.type == 'for' then - local savePoint = { - type = 'save', - copy = true, - pos = block.start, - } - self.steps[#self.steps+1] = { - type = 'push', - copy = true, - pos = block.start, - } - self.steps[#self.steps+1] = savePoint - self.steps[#self.steps+1] = { - type = 'push', - pos = block.finish, - ref1 = savePoint, - } - end -end +mt._index = 1 ---@return parser.object[] function mt:_getCasts() - local root = guide.getRoot(self.loc) + local root = guide.getRoot(self._loc) if not root._casts then root._casts = {} local docs = root.docs @@ -265,180 +30,332 @@ function mt:_getCasts() return root._casts end -function mt:_preCompile() - local startPos = self.loc.start - local finishPos = 0 - - for _, ref in ipairs(self.loc.ref) do - self.steps[#self.steps+1] = { - type = 'object', - object = ref, - pos = ref.range or ref.start, - } - if ref.start > finishPos then - finishPos = ref.start +---@param obj parser.object +function mt:_markHas(obj) + while true do + if self._has[obj] then + return end - local block = guide.getParentBlock(ref) - self:_compileBlock(block) + self._has[obj] = true + if obj == self._main then + return + end + obj = obj.parent end +end - for i, step in ipairs(self.steps) do - if step.type ~= 'object' then - step.order = i +function mt:_collect() + local startPos = self._loc.start + local finishPos = 0 + + for _, ref in ipairs(self._loc.ref) do + if ref.type == 'getlocal' + or ref.type == 'setlocal' then + self:_markHas(ref) + if ref.finish > finishPos then + finishPos = ref.finish + end end end local casts = self:_getCasts() for _, cast in ipairs(casts) do - if cast.loc[1] == self.loc[1] + if cast.loc[1] == self._loc[1] and cast.start > startPos and cast.finish < finishPos - and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then - self.steps[#self.steps+1] = { - type = 'cast', - cast = cast, - pos = cast.start, - } + and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then + self._casts[#self._casts+1] = cast end end - - table.sort(self.steps, function (a, b) - if a.pos == b.pos then - return (a.order or 0) < (b.order or 0) - else - return a.pos < b.pos - end - end) end ----@param loc parser.object ----@param node vm.node +---@param pos integer +---@param topNode vm.node ---@return vm.node -local function checkAssert(loc, node) - local parent = loc.parent - if parent.type == 'binary' then - if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then - local exp - for i = 1, 2 do - if parent[i] == loc then - exp = parent[i % 2 + 1] +function mt:_fastWardCasts(pos, topNode) + for i = self._index, #self._casts do + local action = self._casts[i] + if action.start > pos then + self._index = i + return topNode + end + topNode = topNode:copy() + for _, cast in ipairs(action.casts) do + if cast.mode == '+' then + if cast.optional then + topNode:addOptional() end - end - if exp and guide.isLiteral(exp) then - local callargs = parent.parent - if callargs.type == 'callargs' - and callargs.parent.node.special == 'assert' - and callargs[1] == parent then - if parent.op.type == '~=' then - node:remove(exp.type) - end - if parent.op.type == '==' then - node = vm.compileNode(exp) - end + if cast.extends then + topNode:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + topNode:removeOptional() + end + if cast.extends then + topNode:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + topNode:clear() + topNode:merge(vm.compileNode(cast.extends)) end end end end - if parent.type == 'callargs' - and parent.parent.node.special == 'assert' - and parent[1] == loc then - node:setTruthy() - end - return node + self._index = self._index + 1 + return topNode end ----@param callback fun(src: parser.object, node: vm.node) -function mt:launch(callback) - local topNode = vm.getNode(self.loc):copy() - for _, step in ipairs(self.steps) do - local node = step.ref1 and step.ref1.node or topNode - if step.type == 'truthy' then - if step.new then - node = node:copy() - topNode = node - end - node:setTruthy() - elseif step.type == 'falsy' then - if step.new then - node = node:copy() - topNode = node - end - node:setFalsy() - elseif step.type == 'as' then - if step.new then - topNode = vm.createNode(vm.getGlobal('type', step.name)) - else - node:clear() - node:merge(vm.getGlobal('type', step.name)) - end - elseif step.type == 'add' then - if step.new then - node = node:copy() - topNode = node - end - node:merge(vm.getGlobal('type', step.name)) - elseif step.type == 'remove' then - if step.new then - node = node:copy() - topNode = node - end - node:remove(step.name) - elseif step.type == 'object' then - topNode = callback(step.object, node) or node - if step.object.type == 'getlocal' then - topNode = checkAssert(step.object, node) +---@param action parser.object +---@param topNode vm.node +---@param outNode? vm.node +---@return vm.node topNode +---@return vm.node outNode +function mt:_lookIntoChild(action, topNode, outNode) + if not self._has[action] + or self._mark[action] then + return topNode, topNode or outNode + end + self._mark[action] = true + topNode = self:_fastWardCasts(action.start, topNode) + if action.type == 'getlocal' then + if action.node == self._loc then + self._callback(action, topNode) + if outNode then + topNode = topNode:copy():setTruthy() + outNode = outNode:copy():setFalsy() end - elseif step.type == 'save' then - if step.copy then - node = node:copy() + end + elseif action.type == 'function' then + self:_lookIntoBlock(action, topNode:copy()) + elseif action.type == 'unary' then + if not action[1] then + goto RETURN + end + if action.op.type == 'not' then + outNode = outNode or topNode:copy() + outNode, topNode = self:_lookIntoChild(action[1], topNode, outNode) + outNode = outNode:copy() + end + elseif action.type == 'binary' then + if not action[1] or not action[2] then + goto RETURN + end + if action.op.type == 'and' then + topNode = self:_lookIntoChild(action[1], topNode, topNode:copy()) + topNode = self:_lookIntoChild(action[2], topNode, topNode:copy()) + elseif action.op.type == 'or' then + outNode = outNode or topNode:copy() + local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode) + local topNode2, outNode2 = self:_lookIntoChild(action[2], outNode1, outNode1:copy()) + topNode = vm.createNode(topNode1, topNode2) + outNode = outNode2:copy() + elseif action.op.type == '==' + or action.op.type == '~=' then + local handler, checker + for i = 1, 2 do + if guide.isLiteral(action[i]) then + checker = action[i] + handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` + end end - step.node = node - elseif step.type == 'push' then - if step.copy then - node = node:copy() + if not handler then + goto RETURN end - topNode = node - elseif step.type == 'merge' then - node:merge(step.ref2.node) - elseif step.type == 'cast' then - topNode = node:copy() - for _, cast in ipairs(step.cast.casts) do - if cast.mode == '+' then - if cast.optional then - topNode:addOptional() + if handler.type == 'getlocal' + and handler.node == self._loc then + -- if x == y then + topNode = self:_lookIntoChild(handler, topNode, outNode) + local checkerNode = vm.compileNode(checker) + if action.op.type == '==' then + topNode = checkerNode + if outNode then + outNode:removeNode(topNode) end - if cast.extends then - topNode:merge(vm.compileNode(cast.extends)) - end - elseif cast.mode == '-' then - if cast.optional then - topNode:removeOptional() + else + topNode:removeNode(checkerNode) + if outNode then + outNode = checkerNode end - if cast.extends then - topNode:removeNode(vm.compileNode(cast.extends)) + end + elseif handler.type == 'call' + and checker.type == 'string' + and handler.node.special == 'type' + and handler.args + and handler.args[1] + and handler.args[1].type == 'getlocal' + and handler.args[1].node == self._loc then + -- if type(x) == 'string' then + self:_lookIntoChild(handler, topNode:copy()) + if action.op.type == '==' then + topNode:narrow(checker[1]) + if outNode then + outNode:remove(checker[1]) end else - if cast.extends then - topNode:clear() - topNode:merge(vm.compileNode(cast.extends)) + topNode:remove(checker[1]) + if outNode then + outNode:narrow(checker[1]) + end + end + elseif handler.type == 'getlocal' + and checker.type == 'string' then + local nodeValue = vm.getObjectValue(handler.node) + if nodeValue + and nodeValue.type == 'select' + and nodeValue.sindex == 1 then + local call = nodeValue.vararg + if call + and call.type == 'call' + and call.node.special == 'type' + and call.args + and call.args[1] + and call.args[1].type == 'getlocal' + and call.args[1].node == self._loc then + -- `local tp = type(x);if tp == 'string' then` + if action.op.type == '==' then + topNode:narrow(checker[1]) + if outNode then + outNode:remove(checker[1]) + end + else + topNode:remove(checker[1]) + if outNode then + outNode:narrow(checker[1]) + end + end + end + end + end + end + elseif action.type == 'loop' + or action.type == 'in' + or action.type == 'repeat' + or action.type == 'for' then + topNode = self:_lookIntoBlock(action, topNode:copy()) + elseif action.type == 'while' then + local blockNode, mainNode + if action.filter then + blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy()) + else + blockNode = topNode:copy() + mainNode = topNode:copy() + end + blockNode = self:_lookIntoBlock(action, blockNode:copy()) + topNode = mainNode:merge(blockNode) + if action.filter then + -- look into filter again + guide.eachSource(action.filter, function (src) + self._mark[src] = nil + end) + blockNode, topNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy()) + end + elseif action.type == 'if' then + local hasElse + local mainNode = topNode:copy() + local blockNodes = {} + for _, subBlock in ipairs(action) do + local blockNode = mainNode:copy() + if subBlock.filter then + blockNode, mainNode = self:_lookIntoChild(subBlock.filter, blockNode, mainNode) + else + hasElse = true + mainNode:clear() + end + blockNode = self:_lookIntoBlock(subBlock, blockNode:copy()) + local neverReturn = subBlock.hasReturn + or subBlock.hasGoTo + or subBlock.hasBreak + or subBlock.hasError + if not neverReturn then + blockNodes[#blockNodes+1] = blockNode + end + end + if not hasElse and not topNode:hasKnownType() then + mainNode:merge(vm.declareGlobal('type', 'unknown')) + end + for _, blockNode in ipairs(blockNodes) do + mainNode:merge(blockNode) + end + topNode = mainNode + elseif action.type == 'call' then + if action.node.special == 'assert' and action.args and action.args[1] then + topNode = self:_lookIntoChild(action.args[1], topNode, topNode:copy()) + end + elseif action.type == 'paren' then + topNode, outNode = self:_lookIntoChild(action.exp, topNode, outNode) + elseif action.type == 'setlocal' then + if action.node == self._loc then + if action.value then + self:_lookIntoChild(action.value, topNode) + end + topNode = self._callback(action, topNode) + end + elseif action.type == 'local' then + if action.value + and action.ref + and action.value.type == 'select' then + local index = action.value.sindex + local call = action.value.vararg + if index == 1 + and call.type == 'call' + and call.node + and call.node.special == 'type' + and call.args then + local getLoc = call.args[1] + if getLoc + and getLoc.type == 'getlocal' + and getLoc.node == self._loc then + for _, ref in ipairs(action.ref) do + self:_markHas(ref) end end end end end + ::RETURN:: + guide.eachChild(action, function (src) + if self._has[src] then + self:_lookIntoChild(src, topNode) + end + end) + return topNode, outNode or topNode +end + +---@param block parser.object +---@param topNode vm.node +---@return vm.node topNode +function mt:_lookIntoBlock(block, topNode) + if not self._has[block] then + return topNode + end + for _, action in ipairs(block) do + if self._has[action] then + topNode = self:_lookIntoChild(action, topNode) + end + end + topNode = self:_fastWardCasts(block.finish, topNode) + return topNode end ---@param loc parser.object ----@return vm.runner -function vm.createRunner(loc) +---@param callback vm.runner.callback +function vm.launchRunner(loc, callback) + local main = guide.getParentBlock(loc) + if not main then + return + end local self = setmetatable({ - loc = loc, - mainBlock = guide.getParentBlock(loc), - blocks = {}, - steps = {}, + _loc = loc, + _casts = {}, + _mark = {}, + _has = {}, + _main = main, + _callback = callback, }, mt) - self:_preCompile() + self:_collect() - return self + self:_lookIntoBlock(main, vm.getNode(loc):copy()) end diff --git a/script/vm/sign.lua b/script/vm/sign.lua index fe112bc2..7c95fd08 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -17,14 +17,14 @@ end ---@param uri uri ---@param args parser.object ---@param removeGeneric true? ----@return table<string, vm.node> +---@return table<string, vm.node>? function mt:resolve(uri, args, removeGeneric) if not args then return nil end local resolved = {} - ---@param object parser.object + ---@param object vm.node.object ---@param node vm.node local function resolve(object, node) if object.type == 'doc.generic.name' then @@ -33,6 +33,7 @@ function mt:resolve(uri, args, removeGeneric) -- 'number' -> `T` for n in node:eachObject() do if n.type == 'string' then + ---@cast n parser.object local type = vm.declareGlobal('type', n[1], guide.getUri(n)) resolved[key] = vm.createNode(type, resolved[key]) end @@ -50,40 +51,59 @@ function mt:resolve(uri, args, removeGeneric) end if n.type == 'doc.type.table' then -- { [integer]: number } -> T[] - local tvalueNode = vm.getTableValue(uri, node, 'integer') + local tvalueNode = vm.getTableValue(uri, node, 'integer', true) if tvalueNode then resolve(object.node, tvalueNode) end end if n.type == 'global' and n.cate == 'type' then -- ---@field [integer]: number -> T[] + ---@cast n vm.global vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), false, function (field) resolve(object.node, vm.compileNode(field.extends)) end) end + if n.type == 'table' and #n >= 1 then + -- { x } / { ... } -> T[] + resolve(object.node, vm.compileNode(n[1])) + end end end if object.type == 'doc.type.table' then for _, ufield in ipairs(object.fields) do local ufieldNode = vm.compileNode(ufield.name) local uvalueNode = vm.compileNode(ufield.extends) - if ufieldNode:get(1).type == 'doc.generic.name' and uvalueNode:get(1).type == 'doc.generic.name' then + local firstField = ufieldNode:get(1) + local firstValue = uvalueNode:get(1) + if not firstField or not firstValue then + goto CONTINUE + end + if firstField.type == 'doc.generic.name' and firstValue.type == 'doc.generic.name' then -- { [number]: number} -> { [K]: V } - local tfieldNode = vm.getTableKey(uri, node, 'any') - local tvalueNode = vm.getTableValue(uri, node, 'any') - resolve(ufieldNode:get(1), tfieldNode) - resolve(uvalueNode:get(1), tvalueNode) + local tfieldNode = vm.getTableKey(uri, node, 'any', true) + local tvalueNode = vm.getTableValue(uri, node, 'any', true) + if tfieldNode then + resolve(firstField, tfieldNode) + end + if tvalueNode then + resolve(firstValue, tvalueNode) + end else if ufieldNode:get(1).type == 'doc.generic.name' then -- { [number]: number}|number[] -> { [K]: number } - local tnode = vm.getTableKey(uri, node, uvalueNode) - resolve(ufieldNode:get(1), tnode) + local tnode = vm.getTableKey(uri, node, uvalueNode, true) + if tnode then + resolve(firstField, tnode) + end elseif uvalueNode:get(1).type == 'doc.generic.name' then -- { [number]: number}|number[] -> { [number]: V } - local tnode = vm.getTableValue(uri, node, ufieldNode) - resolve(uvalueNode:get(1), tnode) + local tnode = vm.getTableValue(uri, node, ufieldNode, true) + if tnode then + resolve(firstValue, tnode) + end end end + ::CONTINUE:: end end end @@ -102,6 +122,7 @@ function mt:resolve(uri, args, removeGeneric) if obj.type == 'doc.type.table' or obj.type == 'doc.type.function' or obj.type == 'doc.type.array' then + ---@cast obj parser.object local hasGeneric guide.eachSourceType(obj, 'doc.generic.name', function (src) hasGeneric = true @@ -111,7 +132,7 @@ function mt:resolve(uri, args, removeGeneric) goto CONTINUE end end - local view = vm.viewObject(obj) + local view = vm.viewObject(obj, uri) if view then knownTypes[view] = true end @@ -122,21 +143,31 @@ function mt:resolve(uri, args, removeGeneric) -- remove un-generic type ---@param argNode vm.node + ---@param sign vm.node ---@param knownTypes table<string, true> ---@return vm.node - local function buildArgNode(argNode, knownTypes) + local function buildArgNode(argNode, sign, knownTypes) local newArgNode = vm.createNode() + local needRemoveNil = sign:hasFalsy() for n in argNode:eachObject() do - if argNode:hasFalsy() then - goto CONTINUE + if needRemoveNil then + if n.type == 'nil' then + goto CONTINUE + end + if n.type == 'global' and n.cate == 'type' and n.name == 'nil' then + goto CONTINUE + end end - local view = vm.viewObject(n) + local view = vm.viewObject(n, uri) if knownTypes[view] then goto CONTINUE end newArgNode:merge(n) ::CONTINUE:: end + if not needRemoveNil and argNode:isOptional() then + newArgNode:addOptional() + end return newArgNode end @@ -158,7 +189,7 @@ function mt:resolve(uri, args, removeGeneric) local argNode = vm.compileNode(arg) local knownTypes, genericNames = getSignInfo(sign) if not isAllResolved(genericNames) then - local newArgNode = buildArgNode(argNode, knownTypes) + local newArgNode = buildArgNode(argNode,sign, knownTypes) for n in sign:eachObject() do resolve(n, newArgNode) end diff --git a/script/vm/type.lua b/script/vm/type.lua index c3264993..d112be2c 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -1,67 +1,252 @@ ---@class vm local vm = require 'vm.vm' +local guide = require 'parser.guide' +local config = require 'config.config' +local util = require 'utility' ----@param uri uri ----@param child vm.node|string ----@param parent vm.node|string ----@param mark? table ----@return boolean -function vm.isSubType(uri, child, parent, mark) - if type(parent) == 'string' then - parent = vm.createNode(vm.getGlobal('type', parent)) +---@param object vm.node.object +---@return string? +local function getNodeName(object) + if object.type == 'global' and object.cate == 'type' then + ---@cast object vm.global + return object.name end - if type(child) == 'string' then - child = vm.createNode(vm.getGlobal('type', child)) + if object.type == 'nil' + or object.type == 'boolean' + or object.type == 'number' + or object.type == 'string' + or object.type == 'table' + or object.type == 'function' + or object.type == 'integer' then + return object.type + end + if object.type == 'doc.type.boolean' then + return 'boolean' end + if object.type == 'doc.type.integer' then + return 'integer' + end + if object.type == 'doc.type.function' then + return 'function' + end + if object.type == 'doc.type.table' then + return 'table' + end + if object.type == 'doc.type.array' then + return 'table' + end + if object.type == 'doc.type.string' then + return 'string' + end + return nil +end - if not child or not parent then - return false +---@param parentName string +---@param child vm.node.object +---@param uri uri +---@return boolean? +local function checkEnum(parentName, child, uri) + local parentClass = vm.getGlobal('type', parentName) + if not parentClass then + return nil + end + for _, set in ipairs(parentClass:getSets(uri)) do + if set.type == 'doc.enum' then + if not set._enums then + return false + end + if child.type ~= 'string' + and child.type ~= 'doc.type.string' + and child.type ~= 'integer' + and child.type ~= 'number' + and child.type ~= 'doc.type.integer' then + return false + end + return util.arrayHas(set._enums, child[1]) + end + end + + return nil +end + +---@param parent vm.node.object +---@param child vm.node.object +---@return boolean +local function checkValue(parent, child) + if parent.type == 'doc.type.integer' then + if child.type == 'integer' + or child.type == 'doc.type.integer' + or child.type == 'number' then + return parent[1] == child[1] + end + elseif parent.type == 'doc.type.string' then + if child.type == 'string' + or child.type == 'doc.type.string' then + return parent[1] == child[1] + end end + return true +end + +---@param uri uri +---@param child vm.node|string|vm.node.object +---@param parent vm.node|string|vm.node.object +---@param mark? table +---@return boolean +function vm.isSubType(uri, child, parent, mark) mark = mark or {} - for obj in child:eachObject() do - if obj.type ~= 'global' - or obj.cate ~= 'type' then - goto CONTINUE_CHILD + + if type(child) == 'string' then + local global = vm.getGlobal('type', child) + if not global then + return false + end + child = global + elseif child.type == 'vm.node' then + if config.get(uri, 'Lua.type.weakUnionCheck') then + local hasKnownType + for n in child:eachObject() do + if getNodeName(n) then + hasKnownType = true + if vm.isSubType(uri, n, parent, mark) then + return true + end + end + end + return not hasKnownType + else + local weakNil = config.get(uri, 'Lua.type.weakNilCheck') + for n in child:eachObject() do + local nodeName = getNodeName(n) + if nodeName + and not (nodeName == 'nil' and weakNil) + and not vm.isSubType(uri, n, parent, mark) then + return false + end + end + if not weakNil and child:isOptional() then + if not vm.isSubType(uri, 'nil', parent, mark) then + return false + end + end + return true end - if mark[obj.name] then + end + + if type(parent) == 'string' then + local global = vm.getGlobal('type', parent) + if not global then return false end - mark[obj.name] = true - for parentNode in parent:eachObject() do - if parentNode.type ~= 'global' - or parentNode.cate ~= 'type' then - goto CONTINUE_PARENT + parent = global + elseif parent.type == 'vm.node' then + for n in parent:eachObject() do + if getNodeName(n) + and vm.isSubType(uri, child, n, mark) then + return true end - if parentNode.name == 'any' or obj.name == 'any' then + if n.type == 'doc.generic.name' then return true end - - if parentNode.name == obj.name then + end + if parent:isOptional() then + if vm.isSubType(uri, child, 'nil', mark) then return true end + end + return false + end + + ---@cast child vm.node.object + ---@cast parent vm.node.object - for _, set in ipairs(obj:getSets(uri)) do + local childName = getNodeName(child) + local parentName = getNodeName(parent) + if childName == 'any' + or parentName == 'any' + or childName == 'unknown' + or parentName == 'unknown' + or not childName + or not parentName then + return true + end + + if childName == parentName then + if not checkValue(parent, child) then + return false + end + return true + end + + if parentName == 'number' and childName == 'integer' then + return true + end + + if parentName == 'integer' and childName == 'number' then + if config.get(uri, 'Lua.type.castNumberToInteger') then + return true + end + if child.type == 'number' + and child[1] + and not math.tointeger(child[1]) then + return false + end + if child.type == 'global' + and child.cate == 'type' then + return false + end + return true + end + + local isEnum = checkEnum(parentName, child, uri) + if isEnum ~= nil then + return isEnum + end + + -- TODO: check duck + if parentName == 'table' and not guide.isBasicType(childName) then + return true + end + if childName == 'table' and not guide.isBasicType(parentName) then + return true + end + + -- check class parent + if childName and not mark[childName] then + mark[childName] = true + local isBasicType = guide.isBasicType(childName) + local childClass = vm.getGlobal('type', childName) + if childClass then + for _, set in ipairs(childClass:getSets(uri)) do if set.type == 'doc.class' and set.extends then for _, ext in ipairs(set.extends) do if ext.type == 'doc.extends.name' - and vm.isSubType(uri, ext[1], parentNode.name, mark) then + and (not isBasicType or guide.isBasicType(ext[1])) + and vm.isSubType(uri, ext[1], parent, mark) then return true end end end - if set.type == 'doc.alias' and set.extends then - for _, ext in ipairs(set.extends.types) do - if ext.type == 'doc.type.name' - and vm.isSubType(uri, ext[1], parentNode.name, mark) then - return true - end - end + if set.type == 'doc.alias' + or set.type == 'doc.enum' then + return true end end - ::CONTINUE_PARENT:: end - ::CONTINUE_CHILD:: + mark[childName] = nil + end + + --[[ + ---@class A: string + + ---@type A + local x = '' --> `string` set to `A` + ]] + if guide.isBasicType(childName) + and guide.isLiteral(child) + and vm.isSubType(uri, parentName, childName) then + return true end return false @@ -69,16 +254,24 @@ end ---@param uri uri ---@param tnode vm.node ----@param knode vm.node +---@param knode vm.node|string +---@param inversion? boolean ---@return vm.node? -function vm.getTableValue(uri, tnode, knode) +function vm.getTableValue(uri, tnode, knode, inversion) local result = vm.createNode() for tn in tnode:eachObject() do if tn.type == 'doc.type.table' then for _, field in ipairs(tn.fields) do - if vm.isSubType(uri, vm.compileNode(field.name), knode) then - if field.extends then - result:merge(vm.compileNode(field.extends)) + if field.name.type ~= 'doc.field.name' + and field.extends then + if inversion then + if vm.isSubType(uri, vm.compileNode(field.name), knode) then + result:merge(vm.compileNode(field.extends)) + end + else + if vm.isSubType(uri, knode, vm.compileNode(field.name)) then + result:merge(vm.compileNode(field.extends)) + end end end end @@ -88,25 +281,38 @@ function vm.getTableValue(uri, tnode, knode) end if tn.type == 'table' then for _, field in ipairs(tn) do - if field.type == 'tableindex' then - if field.value then - result:merge(vm.compileNode(field.value)) - end + if field.type == 'tableindex' + and field.value then + result:merge(vm.compileNode(field.value)) end - if field.type == 'tablefield' then - if vm.isSubType(uri, knode, 'string') then - if field.value then + if field.type == 'tablefield' + and field.value then + if inversion then + if vm.isSubType(uri, 'string', knode) then + result:merge(vm.compileNode(field.value)) + end + else + if vm.isSubType(uri, knode, 'string') then result:merge(vm.compileNode(field.value)) end end end - if field.type == 'tableexp' then - if vm.isSubType(uri, knode, 'integer') and field.tindex == 1 then - if field.value then + if field.type == 'tableexp' + and field.value + and field.tindex == 1 then + if inversion then + if vm.isSubType(uri, 'integer', knode) then + result:merge(vm.compileNode(field.value)) + end + else + if vm.isSubType(uri, knode, 'integer') then result:merge(vm.compileNode(field.value)) end end end + if field.type == 'varargs' then + result:merge(vm.compileNode(field)) + end end end end @@ -118,16 +324,24 @@ end ---@param uri uri ---@param tnode vm.node ----@param vnode vm.node +---@param vnode vm.node|string|vm.object +---@param reverse? boolean ---@return vm.node? -function vm.getTableKey(uri, tnode, vnode) +function vm.getTableKey(uri, tnode, vnode, reverse) local result = vm.createNode() for tn in tnode:eachObject() do if tn.type == 'doc.type.table' then for _, field in ipairs(tn.fields) do - if field.extends then - if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then - result:merge(vm.compileNode(field.name)) + if field.name.type ~= 'doc.field.name' + and field.extends then + if reverse then + if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then + result:merge(vm.compileNode(field.name)) + end + else + if vm.isSubType(uri, vnode, vm.compileNode(field.extends)) then + result:merge(vm.compileNode(field.name)) + end end end end @@ -156,3 +370,49 @@ function vm.getTableKey(uri, tnode, vnode) end return result end + +---@param uri uri +---@param defNode vm.node +---@param refNode vm.node +---@return boolean +function vm.canCastType(uri, defNode, refNode) + local defInfer = vm.getInfer(defNode) + local refInfer = vm.getInfer(refNode) + + if defInfer:hasAny(uri) then + return true + end + if refInfer:hasAny(uri) then + return true + end + if defInfer:view(uri) == 'unknown' then + return true + end + if refInfer:view(uri) == 'unknown' then + return true + end + + if vm.isSubType(uri, refNode, 'nil') then + -- allow `local x = {};x = nil`, + -- but not allow `local x ---@type table;x = nil` + if defInfer:hasType(uri, 'table') + and not defNode:hasType 'table' then + return true + end + end + + if vm.isSubType(uri, refNode, 'number') then + -- allow `local x = 0;x = 1.0`, + -- but not allow `local x ---@type integer;x = 1.0` + if defInfer:hasType(uri, 'integer') + and not defNode:hasType 'integer' then + return true + end + end + + if vm.isSubType(uri, refNode, defNode) then + return true + end + + return false +end diff --git a/script/vm/value.lua b/script/vm/value.lua index d29ca9d0..7eab4a8e 100644 --- a/script/vm/value.lua +++ b/script/vm/value.lua @@ -4,11 +4,14 @@ local vm = require 'vm.vm' ---@param source parser.object? ---@return boolean|nil -function vm.test(source) +function vm.testCondition(source) if not source then return nil end local node = vm.compileNode(source) + if node.optional then + return nil + end local hasTrue, hasFalse for n in node:eachObject() do if n.type == 'boolean' @@ -19,24 +22,20 @@ function vm.test(source) if n[1] == false then hasFalse = true end - end - if n.type == 'global' and n.cate == 'type' then - if n.name == 'true' then - hasTrue = true + elseif n.type == 'global' and n.cate == 'type' then + if n.name == 'boolean' + or n.name == 'unknown' then + return nil end if n.name == 'false' or n.name == 'nil' then hasFalse = true + else + hasTrue = true end - end - if n.type == 'nil' then + elseif n.type == 'nil' then hasFalse = true - end - if n.type == 'string' - or n.type == 'number' - or n.type == 'integer' - or n.type == 'table' - or n.type == 'function' then + elseif guide.isLiteral(n) then hasTrue = true end end @@ -50,8 +49,8 @@ function vm.test(source) end end ----@param v vm.object ----@return string? +---@param v vm.node.object +---@return string|false local function getUnique(v) if v.type == 'boolean' then if v[1] == nil then @@ -72,16 +71,18 @@ local function getUnique(v) return ('num:%s'):format(v[1]) end if v.type == 'table' then + ---@cast v parser.object return ('table:%s@%d'):format(guide.getUri(v), v.start) end if v.type == 'function' then + ---@cast v parser.object return ('func:%s@%d'):format(guide.getUri(v), v.start) end return false end ----@param a vm.object? ----@param b vm.object? +---@param a parser.object? +---@param b parser.object? ---@return boolean|nil function vm.equal(a, b) if not a or not b then @@ -141,7 +142,7 @@ function vm.getInteger(v) end ---@param v vm.object? ----@return integer? +---@return string? function vm.getString(v) if not v then return nil diff --git a/script/vm/vm.lua b/script/vm/vm.lua index 8117d311..5437b632 100644 --- a/script/vm/vm.lua +++ b/script/vm/vm.lua @@ -64,6 +64,20 @@ function m.getObjectValue(source) return nil end +---@param source parser.object +---@return parser.object? +function m.getObjectFunctionValue(source) + local value = m.getObjectValue(source) + if value == nil then return end + if value.type == 'function' or value.type == 'doc.type.function' then + return value + end + if value.type == 'getlocal' then + return m.getObjectFunctionValue(value.node) + end + return value +end + m.cacheTracker = setmetatable({}, weakMT) function m.flushCache() diff --git a/script/workspace/require-path.lua b/script/workspace/require-path.lua index aec298a6..1b7b25f9 100644 --- a/script/workspace/require-path.lua +++ b/script/workspace/require-path.lua @@ -3,30 +3,43 @@ local files = require 'files' local furi = require 'file-uri' local workspace = require "workspace" local config = require 'config' -local collector = require 'core.collector' local scope = require 'workspace.scope' +local util = require 'utility' ---@class require-path local m = {} -local function addRequireName(suri, uri, name) - local separator = config.get(uri, 'Lua.completion.requireSeparator') - local fsname = name:gsub('%' .. separator, '/') - local scp = scope.getScope(suri) - ---@type collector - local clt = scp:get('requireName') or scp:set('requireName', collector()) - clt:subscribe(uri, fsname, name) +---@class require-manager +---@field scp scope +---@field nameMap table<string, string> +---@field visibleCache table<string, require-manager.visibleResult[]> +local mt = {} +mt.__index = mt + +---@alias require-manager.visibleResult { searcher: string, name: string } + +---@param scp scope +---@return require-manager +local function createRequireManager(scp) + return setmetatable({ + scp = scp, + nameMap = {}, + visibleCache = {}, + }, mt) end --- `aaa/bbb/ccc.lua` 与 `?.lua` 将返回 `aaa.bbb.cccc` -local function getOnePath(uri, path, searcher) - local separator = config.get(uri, 'Lua.completion.requireSeparator') +---@param path string +---@param searcher string +---@return string? +function mt:getRequireNameByPath(path, searcher) + local separator = config.get(self.scp.uri, 'Lua.completion.requireSeparator') local stemPath = path : gsub('%.[^%.]+$', '') : gsub('[/\\%.]+', separator) local stemSearcher = searcher : gsub('%.[^%.]+$', '') - : gsub('[/\\%.]+', separator) + : gsub('[/\\]+', separator) local start = stemSearcher:match '()%?' or 1 if stemPath:sub(1, start - 1) ~= stemSearcher:sub(1, start - 1) then return nil @@ -41,162 +54,188 @@ local function getOnePath(uri, path, searcher) return nil end -function m.getVisiblePath(suri, path) - local searchers = config.get(suri, 'Lua.runtime.path') - local strict = config.get(suri, 'Lua.runtime.pathStrict') - path = workspace.normalize(path) +---@param path string +---@return require-manager.visibleResult[] +function mt:getRequireResultByPath(path) local uri = furi.encode(path) - local scp = scope.getScope(suri) - if not scp:isChildUri(uri) - and not scp:isLinkedUri(uri) then - return {} - end - local libraryPath = furi.decode(files.getLibraryUri(suri, uri)) - local cache = scp:get('visiblePath') or scp:set('visiblePath', {}) - local result = cache[path] - if not result then - result = {} - cache[path] = result - for _, searcher in ipairs(searchers) do - local isAbsolute = searcher:match '^[/\\]' - or searcher:match '^%a+%:' - searcher = workspace.normalize(searcher) - local cutedPath = path - local currentPath = path - local head - local pos = 1 - if not isAbsolute then - if libraryPath then - currentPath = currentPath:sub(#libraryPath + 2) - else - currentPath = workspace.getRelativePath(uri) - end + local searchers = config.get(self.scp.uri, 'Lua.runtime.path') + local strict = config.get(self.scp.uri, 'Lua.runtime.pathStrict') + local libUri = files.getLibraryUri(self.scp.uri, uri) + local libraryPath = libUri and furi.decode(libUri) + local result = {} + for _, searcher in ipairs(searchers) do + local isAbsolute = searcher:match '^[/\\]' + or searcher:match '^%a+%:' + searcher = workspace.normalize(searcher) + if searcher:sub(1, 1) == '.' then + strict = true + end + local cutedPath = path + local currentPath = path + local head + local pos = 1 + if not isAbsolute then + if libraryPath then + currentPath = currentPath:sub(#libraryPath + 2) + else + currentPath = workspace.getRelativePath(uri) end - repeat - cutedPath = currentPath:sub(pos) - head = currentPath:sub(1, pos - 1) - pos = currentPath:match('[/\\]+()', pos) - if platform.OS == 'Windows' then - searcher = searcher :gsub('[/\\]+', '\\') - else - searcher = searcher :gsub('[/\\]+', '/') - end - local expect = getOnePath(suri, cutedPath, searcher) - if expect then - local mySearcher = searcher - if head then - mySearcher = head .. searcher - end - result[#result+1] = { - searcher = mySearcher, - expect = expect, - } - addRequireName(suri, uri, expect) - end - until not pos or strict end - end - return result -end ---- 查找符合指定require path的所有uri ----@param path string -function m.findUrisByRequirePath(suri, path) - if type(path) ~= 'string' then - return {} - end - local separator = config.get(suri, 'Lua.completion.requireSeparator') - local fspath = path:gsub('%' .. separator, '/') - tracy.ZoneBeginN('findUrisByRequirePath') - local results = {} - local searchers = {} - for uri in files.eachDll() do - local opens = files.getDllOpens(uri) or {} - for _, open in ipairs(opens) do - if open == fspath then - results[#results+1] = uri + -- handle `../?.lua` + local parentCount = 0 + for _ = 1, 1000 do + if searcher:match '^%.%.[/\\]' then + parentCount = parentCount + 1 + searcher = searcher:sub(4) + else + break end end - end - - ---@type collector - local clt = scope.getScope(suri):get('requireName') - if clt then - for _, uri in clt:each(suri, fspath) do - if uri ~= suri then - local infos = m.getVisiblePath(suri, furi.decode(uri)) - for _, info in ipairs(infos) do - local fsexpect = info.expect:gsub('%' .. separator, '/') - if fsexpect == fspath then - results[#results+1] = uri - searchers[uri] = info.searcher - end + if parentCount > 0 then + local parentPath = libraryPath + or (self.scp.uri and furi.decode(self.scp.uri)) + if parentPath then + local tail + for _ = 1, parentCount do + parentPath, tail = parentPath:match '^(.+)[/\\]([^/\\]*)$' + currentPath = tail .. '/' .. currentPath end end end + + repeat + cutedPath = currentPath:sub(pos) + head = currentPath:sub(1, pos - 1) + pos = currentPath:match('[/\\]+()', pos) + if platform.OS == 'Windows' then + searcher = searcher :gsub('[/\\]+', '\\') + else + searcher = searcher :gsub('[/\\]+', '/') + end + local name = self:getRequireNameByPath(cutedPath, searcher) + if name then + local mySearcher = searcher + if head then + mySearcher = head .. searcher + end + result[#result+1] = { + name = name, + searcher = mySearcher, + } + end + until not pos or strict end + return result +end - tracy.ZoneEnd() - return results, searchers +---@param name string +function mt:addName(name) + local separator = config.get(self.scp.uri, 'Lua.completion.requireSeparator') + local fsname = name:gsub('%' .. separator, '/') + self.nameMap[fsname] = name end -local function createVisiblePath(uri) - for _, scp in ipairs(workspace.folders) do - m.getVisiblePath(scp.uri, furi.decode(uri)) +---@return require-manager.visibleResult[] +function mt:getVisiblePath(path) + local uri = furi.encode(path) + if not self.scp:isChildUri(uri) + and not self.scp:isLinkedUri(uri) then + return {} + end + path = workspace.normalize(path) + local result = self.visibleCache[path] + if not result then + result = self:getRequireResultByPath(path) + self.visibleCache[path] = result end - m.getVisiblePath(nil, furi.decode(uri)) + return result end -local function removeVisiblePath(uri) - local path = furi.decode(uri) - path = workspace.normalize(path) - if not path then - return +--- 查找符合指定require name的所有uri +---@param suri uri +---@param name string +---@return uri[] +---@return table<uri, string>? +function mt:findUrisByRequireName(suri, name) + if type(name) ~= 'string' then + return {} end - for _, scp in ipairs(workspace.folders) do - if scp:get('visiblePath') then - scp:get('visiblePath')[path] = nil + local searchers = config.get(self.scp.uri, 'Lua.runtime.path') + local strict = config.get(self.scp.uri, 'Lua.runtime.pathStrict') + local separator = config.get(self.scp.uri, 'Lua.completion.requireSeparator') + local path = name:gsub('%' .. separator, '/') + local results = {} + local searcherMap = {} + + for _, searcher in ipairs(searchers) do + local fspath = searcher:gsub('%?', (path:gsub('%%', '%%%%'))) + local fullPath = workspace.getAbsolutePath(self.scp.uri, fspath) + if fullPath then + local fullUri = furi.encode(fullPath) + if files.exists(fullUri) + and fullUri ~= suri then + results[#results+1] = fullUri + searcherMap[fullUri] = searcher + end end - ---@type collector - local clt = scp:get('requireName') - if clt then - clt:dropUri(uri) + if not strict then + local tail = '/' .. furi.encode(fspath):gsub('^file:[/]*', '') + for uri in files.eachFile(self.scp.uri) do + if not searcherMap[uri] + and suri ~= uri + and util.stringEndWith(uri, tail) then + results[#results+1] = uri + local parentUri = files.getLibraryUri(self.scp.uri, uri) or self.scp.uri + if parentUri == nil or parentUri == '' then + parentUri = furi.encode '' + end + local relative = uri:sub(#parentUri + 1):sub(1, - #tail) + searcherMap[uri] = workspace.normalize(relative .. searcher) + end + end end end - if scope.fallback:get('visiblePath') then - scope.fallback:get('visiblePath')[path] = nil - end - ---@type collector - local clt = scope.fallback:get('requireName') - if clt then - clt:dropUri(uri) + + for uri in files.eachDll() do + local opens = files.getDllOpens(uri) or {} + for _, open in ipairs(opens) do + if open == path then + results[#results+1] = uri + end + end end + + return results, searcherMap end -function m.flush(suri) - local scp = scope.getScope(suri) - scp:set('visiblePath', {}) - ---@type collector - local clt = scp:get('requireName') - if clt then - clt:dropAll() - end - for uri in files.eachFile(suri) do - m.getVisiblePath(scp.uri, furi.decode(uri)) - end +---@param uri uri +---@param path string +---@return require-manager.visibleResult[] +function m.getVisiblePath(uri, path) + local scp = scope.getScope(uri) + ---@type require-manager + local mgr = scp:get 'requireManager' + or scp:set('requireManager', createRequireManager(scp)) + return mgr:getVisiblePath(path) end -for _, scp in ipairs(scope.folders) do - m.flush(scp.uri) +---@param uri uri +---@param name string +function m.findUrisByRequireName(uri, name) + local scp = scope.getScope(uri) + ---@type require-manager + local mgr = scp:get 'requireManager' + or scp:set('requireManager', createRequireManager(scp)) + return mgr:findUrisByRequireName(uri, name) end -m.flush(nil) files.watch(function (ev, uri) - if ev == 'create' then - createVisiblePath(uri) - end - if ev == 'remove' then - removeVisiblePath(uri) + if ev == 'create' or ev == 'delete' then + for _, scp in ipairs(workspace.folders) do + scp:set('requireManager', nil) + end + scope.fallback:set('requireManager', nil) end end) @@ -204,7 +243,8 @@ config.watch(function (uri, key, value, oldValue) if key == 'Lua.completion.requireSeparator' or key == 'Lua.runtime.path' or key == 'Lua.runtime.pathStrict' then - m.flush(uri) + local scp = scope.getScope(uri) + scp:set('requireManager', nil) end end) diff --git a/script/workspace/scope.lua b/script/workspace/scope.lua index a0f4fbf7..4649d354 100644 --- a/script/workspace/scope.lua +++ b/script/workspace/scope.lua @@ -11,6 +11,7 @@ local m = {} ---@field _links table<uri, boolean> ---@field _data table<string, any> ---@field _gc gc +---@field _removed? true local mt = {} mt.__index = mt @@ -85,7 +86,7 @@ function mt:getLinkedUri(uri) end ---@param uri uri ----@return uri +---@return uri? function mt:getRootUri(uri) if self:isChildUri(uri) then return self.uri @@ -117,9 +118,30 @@ end function mt:flushGC() self._gc:remove() + if self._removed then + return + end self._gc = gc() end +function mt:remove() + if self._removed then + return + end + self._removed = true + for i, scp in ipairs(m.folders) do + if scp == self then + table.remove(m.folders, i) + break + end + end + self:flushGC() +end + +function mt:isRemoved() + return self._removed == true +end + ---@param scopeType scope.type ---@return scope local function createScope(scopeType) @@ -164,7 +186,7 @@ function m.createFolder(uri) end ---@param uri uri ----@return scope +---@return scope? function m.getFolder(uri) for _, scope in ipairs(m.folders) do if scope:isChildUri(uri) then @@ -175,7 +197,7 @@ function m.getFolder(uri) end ---@param uri uri ----@return scope +---@return scope? function m.getLinkedScope(uri) if m.override and m.override:isLinkedUri(uri) then return m.override @@ -188,6 +210,7 @@ function m.getLinkedScope(uri) if m.fallback:isLinkedUri(uri) then return m.fallback end + return nil end ---@param uri uri diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua index 33f8784d..9d2ad637 100644 --- a/script/workspace/workspace.lua +++ b/script/workspace/workspace.lua @@ -7,12 +7,12 @@ local glob = require 'glob' local platform = require 'bee.platform' local await = require 'await' local client = require 'client' -local plugin = require 'plugin' local util = require 'utility' local fw = require 'filewatch' local scope = require 'workspace.scope' local loading = require 'workspace.loading' local inspect = require 'inspect' +local lang = require 'language' ---@class workspace local m = {} @@ -45,13 +45,35 @@ end --- 初始化工作区 function m.create(uri) + if furi.isValid(uri) then + uri = furi.normalize(uri) + end log.info('Workspace create: ', uri) - local path = m.normalize(furi.decode(uri)) - fw.watch(path) + if uri == furi.encode '/' + or uri == furi.encode(os.getenv 'HOME' or '') then + client.showMessage('Error', lang.script('WORKSPACE_NOT_ALLOWED', furi.decode(uri))) + return + end local scp = scope.createFolder(uri) m.folders[#m.folders+1] = scp end +function m.remove(uri) + log.info('Workspace remove: ', uri) + for i, scp in ipairs(m.folders) do + if scp.uri == uri then + scp:remove() + table.remove(m.folders, i) + scp:set('ready', false) + scp:set('nativeMatcher', nil) + scp:set('libraryMatcher', nil) + scp:removeAllLinks() + m.flushFiles(scp) + return + end + end +end + function m.reset() ---@type scope[] m.folders = {} @@ -135,7 +157,7 @@ function m.getNativeMatcher(scp) end end end - for path in pairs(config.get(scp.uri, 'Lua.workspace.library')) do + for _, path in ipairs(config.get(scp.uri, 'Lua.workspace.library')) do path = m.getAbsolutePath(scp.uri, path) if path then log.debug('Ignore by library:', path) @@ -148,7 +170,7 @@ function m.getNativeMatcher(scp) end local matcher = glob.gitignore(pattern, { - root = furi.decode(scp.uri), + root = scp.uri and furi.decode(scp.uri), ignoreCase = platform.OS == 'Windows', }, globInteferFace) @@ -177,7 +199,7 @@ function m.getLibraryMatchers(scp) end local librarys = {} - for path in pairs(config.get(scp.uri, 'Lua.workspace.library')) do + for _, path in ipairs(config.get(scp.uri, 'Lua.workspace.library')) do path = m.getAbsolutePath(scp.uri, path) if path then librarys[m.normalize(path)] = true @@ -273,6 +295,10 @@ function m.awaitPreload(scp) scp:flushGC() + if scp:isRemoved() then + return + end + local ld <close> = loading.create(scp) scp:set('loading', ld) @@ -283,22 +309,35 @@ function m.awaitPreload(scp) if scp.uri then log.info('Scan files at:', scp:getName()) + local count = 0 ---@async native:scan(furi.decode(scp.uri), function (path) local uri = files.getRealUri(furi.encode(path)) scp:get('cachedUris')[uri] = true ld:loadFile(uri) + end, function () ---@async + count = count + 1 + if count == 100000 then + client.showMessage('Warning', lang.script('WORKSPACE_SCAN_TOO_MUCH', count, furi.decode(scp.uri))) + end end) + scp:gc(fw.watch(m.normalize(furi.decode(scp.uri)))) end for _, libMatcher in ipairs(librarys) do log.info('Scan library at:', libMatcher.uri) + local count = 0 scp:addLink(libMatcher.uri) ---@async libMatcher.matcher:scan(furi.decode(libMatcher.uri), function (path) local uri = files.getRealUri(furi.encode(path)) scp:get('cachedUris')[uri] = true ld:loadFile(uri, libMatcher.uri) + end, function () ---@async + count = count + 1 + if count == 100000 then + client.showMessage('Warning', lang.script('WORKSPACE_SCAN_TOO_MUCH', count, furi.decode(libMatcher.uri))) + end end) scp:gc(fw.watch(furi.decode(libMatcher.uri))) end @@ -336,9 +375,6 @@ end ---@param path string ---@return string function m.normalize(path) - if not path then - return nil - end path = path:gsub('%$%{(.-)%}', function (key) if key == '3rd' then return (ROOT / 'meta' / '3rd'):string() @@ -350,9 +386,20 @@ function m.normalize(path) end) path = util.expandPath(path) path = path:gsub('^%.[/\\]+', '') + for _ = 1, 1000 do + if path:sub(1, 2) == '..' then + break + end + local count + path, count = path:gsub('[^/\\]+[/\\]+%.%.[/\\]', '/', 1) + if count == 0 then + break + end + end if platform.OS == 'Windows' then path = path:gsub('[/\\]+', '\\') :gsub('[/\\]+$', '') + :gsub('^(%a:)$', '%1\\') else path = path:gsub('[/\\]+', '/') :gsub('[/\\]+$', '') @@ -360,11 +407,10 @@ function m.normalize(path) return path end ----@return string +---@param folderUri? uri +---@param path string +---@return string? function m.getAbsolutePath(folderUri, path) - if not path or path == '' then - return nil - end path = m.normalize(path) if fs.path(path):is_relative() then if not folderUri then @@ -378,6 +424,7 @@ end ---@param uriOrPath uri|string ---@return string +---@return boolean suc function m.getRelativePath(uriOrPath) local path, uri if uriOrPath:sub(1, 5) == 'file:' then @@ -427,16 +474,24 @@ function m.flushFiles(scp) for uri in pairs(cachedUris) do files.delRef(uri) end + collectgarbage() + collectgarbage() + -- TODO: wait maillist + collectgarbage 'restart' end ---@param scp scope function m.resetFiles(scp) local cachedUris = scp:get 'cachedUris' - if not cachedUris then - return + if cachedUris then + for uri in pairs(cachedUris) do + files.resetText(uri) + end end - for uri in pairs(cachedUris) do - files.resetText(uri) + for uri in pairs(files.openMap) do + if scope.getScope(uri) == scp then + files.resetText(uri) + end end end @@ -448,7 +503,7 @@ function m.awaitReload(scp) scp:set('libraryMatcher', nil) scp:removeAllLinks() m.flushFiles(scp) - plugin.init(scp) + m.onWatch('startReload', scp.uri) m.awaitPreload(scp) scp:set('ready', true) local waiting = scp:get('waitingReady') @@ -501,6 +556,7 @@ end config.watch(function (uri, key, value, oldValue) if key:find '^Lua.runtime' or key:find '^Lua.workspace' + or key:find '^Lua.type' or key:find '^files' then if value ~= oldValue then m.reload(scope.getScope(uri)) |