diff options
Diffstat (limited to 'script')
115 files changed, 22490 insertions, 0 deletions
diff --git a/script/await.lua b/script/await.lua new file mode 100644 index 00000000..d8e2a9ad --- /dev/null +++ b/script/await.lua @@ -0,0 +1,227 @@ +local timer = require 'timer' +local util = require 'utility' + +---@class await +local m = {} +m.type = 'await' + +m.coMap = setmetatable({}, { __mode = 'k' }) +m.idMap = {} +m.delayQueue = {} +m.delayQueueIndex = 1 +m.watchList = {} +m._enable = true + +--- 设置错误处理器 +---@param errHandle function {comment = '当有错误发生时,会以错误堆栈为参数调用该函数'} +function m.setErrorHandle(errHandle) + m.errorHandle = errHandle +end + +function m.checkResult(co, ...) + local suc, err = ... + if not suc and m.errorHandle then + m.errorHandle(debug.traceback(co, err)) + end + return ... +end + +--- 创建一个任务 +function m.call(callback, ...) + local co = coroutine.create(callback) + local closers = {} + m.coMap[co] = { + closers = closers, + priority = false, + } + for i = 1, select('#', ...) do + local id = select(i, ...) + if not id then + break + end + m.setID(id, co) + end + + local currentCo = coroutine.running() + local current = m.coMap[currentCo] + if current then + for closer in pairs(current.closers) do + closers[closer] = true + closer(co) + end + end + return m.checkResult(co, coroutine.resume(co)) +end + +--- 创建一个任务,并挂起当前线程,当任务完成后再延续当前线程/若任务被关闭,则返回nil +function m.await(callback, ...) + if not coroutine.isyieldable() then + return callback(...) + end + return m.wait(function (waker, ...) + m.call(function () + local returnNil <close> = util.defer(waker) + waker(callback()) + end, ...) + end, ...) +end + +--- 设置一个id,用于批量关闭任务 +function m.setID(id, co) + co = co or coroutine.running() + if not m.idMap[id] then + m.idMap[id] = setmetatable({}, { __mode = 'k' }) + end + m.idMap[id][co] = true +end + +--- 根据id批量关闭任务 +function m.close(id) + local map = m.idMap[id] + if not map then + return + end + local count = 0 + for co in pairs(map) do + map[co] = nil + coroutine.close(co) + count = count + 1 + end + log.debug('Close await:', id, count) +end + +function m.hasID(id, co) + co = co or coroutine.running() + return m.idMap[id] and m.idMap[id][co] ~= nil +end + +--- 休眠一段时间 +---@param time number +function m.sleep(time) + if not coroutine.isyieldable() then + if m.errorHandle then + m.errorHandle(debug.traceback('Cannot yield')) + end + return + end + local co = coroutine.running() + timer.wait(time, function () + if coroutine.status(co) ~= 'suspended' then + return + end + return m.checkResult(co, coroutine.resume(co)) + end) + return coroutine.yield() +end + +--- 等待直到唤醒 +---@param callback function +function m.wait(callback, ...) + if not coroutine.isyieldable() then + return + end + local co = coroutine.running() + local waked + callback(function (...) + if waked then + return + end + waked = true + if coroutine.status(co) ~= 'suspended' then + return + end + return m.checkResult(co, coroutine.resume(co, ...)) + end, ...) + return coroutine.yield() +end + +--- 延迟 +function m.delay() + if not m._enable then + return + end + if not coroutine.isyieldable() then + return + end + local co = coroutine.running() + local current = m.coMap[co] + if m.onWatch('delay', co) == false then + return + end + -- TODO + if current.priority then + return + end + m.delayQueue[#m.delayQueue+1] = function () + if coroutine.status(co) ~= 'suspended' then + return + end + return m.checkResult(co, coroutine.resume(co)) + end + return coroutine.yield() +end + +local function warnStepTime(passed, waker) + if passed < 1 then + log.warn(('Await step takes [%.3f] sec.'):format(passed)) + return + end + for i = 1, 100 do + local name, v = debug.getupvalue(waker, i) + if not name then + return + end + if name == 'co' then + log.warn(debug.traceback(v, ('[fire]Await step takes [%.3f] sec.'):format(passed))) + return + end + end +end + +--- 步进 +function m.step() + local waker = m.delayQueue[m.delayQueueIndex] + if waker then + m.delayQueue[m.delayQueueIndex] = false + m.delayQueueIndex = m.delayQueueIndex + 1 + local clock = os.clock() + waker() + local passed = os.clock() - clock + if passed > 0.1 then + warnStepTime(passed, waker) + end + return true + else + m.delayQueue = {} + m.delayQueueIndex = 1 + return false + end +end + +function m.setPriority(n) + m.coMap[coroutine.running()].priority = true +end + +function m.enable() + m._enable = true +end + +function m.disable() + m._enable = false +end + +--- 注册事件 +function m.watch(callback) + m.watchList[#m.watchList+1] = callback +end + +function m.onWatch(ev, ...) + for _, callback in ipairs(m.watchList) do + local res = callback(ev, ...) + if res ~= nil then + return res + end + end +end + +return m diff --git a/script/brave/brave.lua b/script/brave/brave.lua new file mode 100644 index 00000000..08909074 --- /dev/null +++ b/script/brave/brave.lua @@ -0,0 +1,70 @@ +local thread = require 'bee.thread' + +---@class pub_brave +local m = {} +m.type = 'brave' +m.ability = {} +m.queue = {} + +--- 注册成为勇者 +function m.register(id) + m.taskpad = thread.channel('taskpad' .. id) + m.waiter = thread.channel('waiter' .. id) + m.id = id + + if #m.queue > 0 then + for _, info in ipairs(m.queue) do + m.waiter:push(info.name, info.params) + end + end + m.queue = nil + + m.start() +end + +--- 注册能力 +function m.on(name, callback) + m.ability[name] = callback +end + +--- 报告 +function m.push(name, params) + if m.waiter then + m.waiter:push(name, params) + else + m.queue[#m.queue+1] = { + name = name, + params = params, + } + end +end + +--- 开始找工作 +function m.start() + m.push('mem', collectgarbage 'count') + while true do + local suc, name, id, params = m.taskpad:pop() + if not suc then + -- 找不到工作的勇者,只好睡觉 + thread.sleep(0.001) + goto CONTINUE + end + local ability = m.ability[name] + -- TODO + if not ability then + m.waiter:push(id) + log.error('Brave can not handle this work: ' .. name) + goto CONTINUE + end + local ok, res = xpcall(ability, log.error, params) + if ok then + m.waiter:push(id, res) + else + m.waiter:push(id) + end + m.push('mem', collectgarbage 'count') + ::CONTINUE:: + end +end + +return m diff --git a/script/brave/init.lua b/script/brave/init.lua new file mode 100644 index 00000000..24c2e412 --- /dev/null +++ b/script/brave/init.lua @@ -0,0 +1,4 @@ +local brave = require 'brave.brave' +require 'brave.work' + +return brave diff --git a/script/brave/log.lua b/script/brave/log.lua new file mode 100644 index 00000000..18ea7853 --- /dev/null +++ b/script/brave/log.lua @@ -0,0 +1,54 @@ +local brave = require 'brave' + +local tablePack = table.pack +local tostring = tostring +local tableConcat = table.concat +local debugTraceBack = debug.traceback +local debugGetInfo = debug.getinfo +local osClock = os.clock + +_ENV = nil + +local function pushLog(level, ...) + local t = tablePack(...) + for i = 1, t.n do + t[i] = tostring(t[i]) + end + local str = tableConcat(t, '\t', 1, t.n) + if level == 'error' then + str = str .. '\n' .. debugTraceBack(nil, 3) + end + local info = debugGetInfo(3, 'Sl') + brave.push('log', { + level = level, + msg = str, + src = info.source, + line = info.currentline, + clock = osClock(), + }) + return str +end + +local m = {} + +function m.info(...) + pushLog('info', ...) +end + +function m.debug(...) + pushLog('debug', ...) +end + +function m.trace(...) + pushLog('trace', ...) +end + +function m.warn(...) + pushLog('warn', ...) +end + +function m.error(...) + pushLog('error', ...) +end + +return m diff --git a/script/brave/work.lua b/script/brave/work.lua new file mode 100644 index 00000000..5ec8178f --- /dev/null +++ b/script/brave/work.lua @@ -0,0 +1,60 @@ +local brave = require 'brave.brave' +local parser = require 'parser' +local fs = require 'bee.filesystem' +local furi = require 'file-uri' +local util = require 'utility' +local thread = require 'bee.thread' + +brave.on('loadProto', function () + local jsonrpc = require 'jsonrpc' + while true do + local proto, err = jsonrpc.decode(io.read, log.error) + --log.debug('loaded proto', proto.method) + if not proto then + brave.push('protoerror', err) + return + end + brave.push('proto', proto) + thread.sleep(0.001) + end +end) + +brave.on('compile', function (text) + local state, err = parser:compile(text, 'lua', 'Lua 5.4') + if not state then + log.error(err) + return + end + local lines = parser:lines(text) + return { + root = state.root, + value = state.value, + errs = state.errs, + lines = lines, + } +end) + +brave.on('listDirectory', function (uri) + local path = fs.path(furi.decode(uri)) + local uris = {} + for child in path:list_directory() do + local childUri = furi.encode(child:string()) + uris[#uris+1] = childUri + end + return uris +end) + +brave.on('isDirectory', function (uri) + local path = fs.path(furi.decode(uri)) + return fs.is_directory(path) +end) + +brave.on('loadFile', function (uri) + local filename = furi.decode(uri) + return util.loadFile(filename) +end) + +brave.on('saveFile', function (params) + local filename = furi.decode(params.uri) + return util.saveFile(filename, params.text) +end) diff --git a/script/config.lua b/script/config.lua new file mode 100644 index 00000000..0544c317 --- /dev/null +++ b/script/config.lua @@ -0,0 +1,218 @@ +local util = require 'utility' +local define = require 'proto.define' + +local m = {} +m.version = 0 + +local function Boolean(v) + if type(v) == 'boolean' then + return true, v + end + return false +end + +local function Integer(v) + if type(v) == 'number' then + return true, math.floor(v) + end + return false +end + +local function String(v) + return true, tostring(v) +end + +local function Str2Hash(sep) + return function (v) + if type(v) == 'string' then + local t = {} + for s in v:gmatch('[^'..sep..']+') do + t[s] = true + end + return true, t + end + if type(v) == 'table' then + local t = {} + for _, s in ipairs(v) do + if type(s) == 'string' then + t[s] = true + end + end + return true, t + end + return false + end +end + +local function Array(checker) + return function (tbl) + if type(tbl) ~= 'table' then + return false + end + local t = {} + for _, v in ipairs(tbl) do + local ok, result = checker(v) + if ok then + t[#t+1] = result + end + end + return true, t + end +end + +local function Hash(keyChecker, valueChecker) + return function (tbl) + if type(tbl) ~= 'table' then + return false + end + local t = {} + for k, v in pairs(tbl) do + local ok1, key = keyChecker(k) + local ok2, value = valueChecker(v) + if ok1 and ok2 then + t[key] = value + end + end + if not next(t) then + return false + end + return true, t + end +end + +local function Or(...) + local checkers = {...} + return function (obj) + for _, checker in ipairs(checkers) do + local suc, res = checker(obj) + if suc then + return true, res + end + end + return false + end +end + +local ConfigTemplate = { + runtime = { + version = {'Lua 5.4', String}, + library = {{}, Str2Hash ';'}, + path = {{ + "?.lua", + "?/init.lua", + "?/?.lua" + }, Array(String)}, + special = {{}, Hash(String, String)}, + meta = {'${version} ${language}', String}, + }, + diagnostics = { + enable = {true, Boolean}, + globals = {{}, Str2Hash ';'}, + disable = {{}, Str2Hash ';'}, + severity = { + util.deepCopy(define.DiagnosticDefaultSeverity), + Hash(String, String), + }, + workspaceDelay = {0, Integer}, + workspaceRate = {100, Integer}, + }, + workspace = { + ignoreDir = {{}, Str2Hash ';'}, + ignoreSubmodules= {true, Boolean}, + useGitIgnore = {true, Boolean}, + maxPreload = {1000, Integer}, + preloadFileSize = {100, Integer}, + library = {{}, Hash( + String, + Or(Boolean, Array(String)) + )} + }, + completion = { + enable = {true, Boolean}, + callSnippet = {'Disable', String}, + keywordSnippet = {'Replace', String}, + displayContext = {6, Integer}, + }, + signatureHelp = { + enable = {true, Boolean}, + }, + hover = { + enable = {true, Boolean}, + viewString = {true, Boolean}, + viewStringMax = {1000, Integer}, + viewNumber = {true, Boolean}, + fieldInfer = {3000, Integer}, + }, + color = { + mode = {'Semantic', String}, + }, + luadoc = { + enable = {true, Boolean}, + }, + plugin = { + enable = {false, Boolean}, + path = {'.vscode/lua-plugin/*.lua', String}, + }, + intelliSense = { + searchDepth = {0, Integer}, + fastGlobal = {true, Boolean}, + }, +} + +local OtherTemplate = { + associations = {{}, Hash(String, String)}, + exclude = {{}, Hash(String, Boolean)}, +} + +local function init() + if m.config then + return + end + + m.config = {} + for c, t in pairs(ConfigTemplate) do + m.config[c] = {} + for k, info in pairs(t) do + m.config[c][k] = info[1] + end + end + + m.other = {} + for k, info in pairs(OtherTemplate) do + m.other[k] = info[1] + end +end + +function m.setConfig(config, other) + m.version = m.version + 1 + xpcall(function () + for c, t in pairs(config) do + for k, v in pairs(t) do + local region = ConfigTemplate[c] + if region then + local info = region[k] + local suc, v = info[2](v) + if suc then + m.config[c][k] = v + else + m.config[c][k] = info[1] + end + end + end + end + for k, v in pairs(other) do + local info = OtherTemplate[k] + local suc, v = info[2](v) + if suc then + m.other[k] = v + else + m.other[k] = info[1] + end + end + log.debug('Config update: ', util.dump(m.config), util.dump(m.other)) + end, log.error) +end + +init() + +return m diff --git a/script/core/code-action.lua b/script/core/code-action.lua new file mode 100644 index 00000000..69304f98 --- /dev/null +++ b/script/core/code-action.lua @@ -0,0 +1,269 @@ +local files = require 'files' +local lang = require 'language' +local define = require 'proto.define' +local guide = require 'parser.guide' +local util = require 'utility' +local sp = require 'bee.subprocess' + +local function disableDiagnostic(uri, code, results) + results[#results+1] = { + title = lang.script('ACTION_DISABLE_DIAG', code), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_DISABLE_DIAG, + command = 'lua.config', + arguments = { + { + key = 'Lua.diagnostics.disable', + action = 'add', + value = code, + uri = uri, + } + } + } + } +end + +local function markGlobal(uri, name, results) + results[#results+1] = { + title = lang.script('ACTION_MARK_GLOBAL', name), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_MARK_GLOBAL, + command = 'lua.config', + arguments = { + { + key = 'Lua.diagnostics.globals', + action = 'add', + value = name, + uri = uri, + } + } + } + } +end + +local function changeVersion(uri, version, results) + results[#results+1] = { + title = lang.script('ACTION_RUNTIME_VERSION', version), + kind = 'quickfix', + command = { + title = lang.script.COMMAND_RUNTIME_VERSION, + command = 'lua.config', + arguments = { + { + key = 'Lua.runtime.version', + action = 'set', + value = version, + uri = uri, + } + } + }, + } +end + +local function solveUndefinedGlobal(uri, diag, results) + local ast = files.getAst(uri) + local text = files.getText(uri) + local lines = files.getLines(uri) + local offset = define.offsetOfWord(lines, text, diag.range.start) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type ~= 'getglobal' then + return + end + + local name = guide.getName(source) + markGlobal(uri, name, results) + + -- TODO check other version + end) +end + +local function solveLowercaseGlobal(uri, diag, results) + local ast = files.getAst(uri) + local text = files.getText(uri) + local lines = files.getLines(uri) + local offset = define.offsetOfWord(lines, text, diag.range.start) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type ~= 'setglobal' then + return + end + + local name = guide.getName(source) + markGlobal(uri, name, results) + end) +end + +local function findSyntax(uri, diag) + local ast = files.getAst(uri) + local text = files.getText(uri) + local lines = files.getLines(uri) + for _, err in ipairs(ast.errs) do + if err.type:lower():gsub('_', '-') == diag.code then + local range = define.range(lines, text, err.start, err.finish) + if util.equal(range, diag.range) then + return err + end + end + end + return nil +end + +local function solveSyntaxByChangeVersion(uri, err, results) + if type(err.version) == 'table' then + for _, version in ipairs(err.version) do + changeVersion(uri, version, results) + end + else + changeVersion(uri, err.version, results) + end +end + +local function solveSyntaxByAddDoEnd(uri, err, results) + local text = files.getText(uri) + local lines = files.getLines(uri) + results[#results+1] = { + title = lang.script.ACTION_ADD_DO_END, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + range = define.range(lines, text, err.start, err.finish), + newText = ('do %s end'):format(text:sub(err.start, err.finish)), + }, + } + } + } + } +end + +local function solveSyntaxByFix(uri, err, results) + local text = files.getText(uri) + local lines = files.getLines(uri) + local changes = {} + for _, fix in ipairs(err.fix) do + changes[#changes+1] = { + range = define.range(lines, text, fix.start, fix.finish), + newText = fix.text, + } + end + results[#results+1] = { + title = lang.script['ACTION_' .. err.fix.title], + kind = 'quickfix', + edit = { + changes = { + [uri] = changes, + } + } + } +end + +local function solveSyntax(uri, diag, results) + local err = findSyntax(uri, diag) + if not err then + return + end + if err.version then + solveSyntaxByChangeVersion(uri, err, results) + end + if err.type == 'ACTION_AFTER_BREAK' or err.type == 'ACTION_AFTER_RETURN' then + solveSyntaxByAddDoEnd(uri, err, results) + end + if err.fix then + solveSyntaxByFix(uri, err, results) + end +end + +local function solveNewlineCall(uri, diag, results) + local text = files.getText(uri) + local lines = files.getLines(uri) + results[#results+1] = { + title = lang.script.ACTION_ADD_SEMICOLON, + kind = 'quickfix', + edit = { + changes = { + [uri] = { + { + range = { + start = diag.range.start, + ['end'] = diag.range.start, + }, + newText = ';', + } + } + } + } + } +end + +local function solveAmbiguity1(uri, diag, results) + results[#results+1] = { + title = lang.script.ACTION_ADD_BRACKETS, + kind = 'quickfix', + command = { + title = lang.script.COMMAND_ADD_BRACKETS, + command = 'lua.solve:' .. sp:get_id(), + arguments = { + { + name = 'ambiguity-1', + uri = uri, + range = diag.range, + } + } + }, + } +end + +local function solveTrailingSpace(uri, diag, results) + results[#results+1] = { + title = lang.script.ACTION_REMOVE_SPACE, + kind = 'quickfix', + command = { + title = lang.script.COMMAND_REMOVE_SPACE, + command = 'lua.removeSpace:' .. sp:get_id(), + arguments = { + { + uri = uri, + } + } + }, + } +end + +local function solveDiagnostic(uri, diag, results) + if diag.source == lang.script.DIAG_SYNTAX_CHECK then + solveSyntax(uri, diag, results) + return + end + if not diag.code then + return + end + if diag.code == 'undefined-global' then + solveUndefinedGlobal(uri, diag, results) + elseif diag.code == 'lowercase-global' then + solveLowercaseGlobal(uri, diag, results) + elseif diag.code == 'newline-call' then + solveNewlineCall(uri, diag, results) + elseif diag.code == 'ambiguity-1' then + solveAmbiguity1(uri, diag, results) + elseif diag.code == 'trailing-space' then + solveTrailingSpace(uri, diag, results) + end + disableDiagnostic(uri, diag.code, results) +end + +return function (uri, range, diagnostics) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local results = {} + + for _, diag in ipairs(diagnostics) do + solveDiagnostic(uri, diag, results) + end + + return results +end diff --git a/script/core/command/removeSpace.lua b/script/core/command/removeSpace.lua new file mode 100644 index 00000000..e8b09932 --- /dev/null +++ b/script/core/command/removeSpace.lua @@ -0,0 +1,56 @@ +local files = require 'files' +local define = require 'proto.define' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' + +local function isInString(ast, offset) + return guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'string' then + return true + end + end) or false +end + +return function (data) + local uri = data.uri + local lines = files.getLines(uri) + local text = files.getText(uri) + local ast = files.getAst(uri) + if not lines then + return + end + + local textEdit = {} + for i = 1, #lines do + local line = guide.lineContent(lines, text, i, true) + local pos = line:find '[ \t]+$' + if pos then + local start, finish = guide.lineRange(lines, i, true) + start = start + pos - 1 + if isInString(ast, start) then + goto NEXT_LINE + end + textEdit[#textEdit+1] = { + range = define.range(lines, text, start, finish), + newText = '', + } + goto NEXT_LINE + end + + ::NEXT_LINE:: + end + + if #textEdit == 0 then + return + end + + proto.awaitRequest('workspace/applyEdit', { + label = lang.script.COMMAND_REMOVE_SPACE, + edit = { + changes = { + [uri] = textEdit, + } + }, + }) +end diff --git a/script/core/command/solve.lua b/script/core/command/solve.lua new file mode 100644 index 00000000..d3b8f94e --- /dev/null +++ b/script/core/command/solve.lua @@ -0,0 +1,96 @@ +local files = require 'files' +local define = require 'proto.define' +local guide = require 'parser.guide' +local proto = require 'proto' +local lang = require 'language' + +local opMap = { + ['+'] = true, + ['-'] = true, + ['*'] = true, + ['/'] = true, + ['//'] = true, + ['^'] = true, + ['<<'] = true, + ['>>'] = true, + ['&'] = true, + ['|'] = true, + ['~'] = true, + ['..'] = true, +} + +local literalMap = { + ['number'] = true, + ['boolean'] = true, + ['string'] = true, + ['table'] = true, +} + +return function (data) + local uri = data.uri + local lines = files.getLines(uri) + local text = files.getText(uri) + local ast = files.getAst(uri) + if not ast then + return + end + + local start = define.offsetOfWord(lines, text, data.range.start) + local finish = define.offsetOfWord(lines, text, data.range['end']) + + local result = guide.eachSourceContain(ast.ast, start, function (source) + if source.start ~= start + or source.finish ~= finish then + return + end + if not source.op or source.op.type ~= 'or' then + return + end + local first = source[1] + local second = source[2] + -- a + b or 0 --> a + (b or 0) + do + if first.op + and opMap[first.op.type] + and first.type ~= 'unary' + and not second.op + and literalMap[second.type] then + return { + start = source[1][2].start, + finish = source[2].finish, + } + end + end + -- a or b + c --> (a or b) + c + do + if second.op + and opMap[second.op.type] + and second.type ~= 'unary' + and not first.op + and literalMap[second[1].type] then + return { + start = source[1].start, + finish = source[2][1].finish, + } + end + end + end) + + if not result then + return + end + + proto.awaitRequest('workspace/applyEdit', { + label = lang.script.COMMAND_REMOVE_SPACE, + edit = { + changes = { + [uri] = { + { + range = define.range(lines, text, result.start, result.finish), + newText = ('(%s)'):format(text:sub(result.start, result.finish)), + } + }, + } + }, + }) +end diff --git a/script/core/completion.lua b/script/core/completion.lua new file mode 100644 index 00000000..44874b39 --- /dev/null +++ b/script/core/completion.lua @@ -0,0 +1,1284 @@ +local define = require 'proto.define' +local files = require 'files' +local guide = require 'parser.guide' +local matchKey = require 'core.matchkey' +local vm = require 'vm' +local getLabel = require 'core.hover.label' +local getName = require 'core.hover.name' +local getArg = require 'core.hover.arg' +local getDesc = require 'core.hover.description' +local getHover = require 'core.hover' +local config = require 'config' +local util = require 'utility' +local markdown = require 'provider.markdown' +local findSource = require 'core.find-source' +local await = require 'await' +local parser = require 'parser' +local keyWordMap = require 'core.keyword' +local workspace = require 'workspace' +local furi = require 'file-uri' +local rpath = require 'workspace.require-path' +local lang = require 'language' + +local stackID = 0 +local stacks = {} +local function stack(callback) + stackID = stackID + 1 + stacks[stackID] = callback + return stackID +end + +local function clearStack() + stacks = {} +end + +local function resolveStack(id) + local callback = stacks[id] + if not callback then + return nil + end + + -- 当进行新的 resolve 时,放弃当前的 resolve + await.close('completion.resove') + return await.await(callback, 'completion.resove') +end + +local function trim(str) + return str:match '^%s*(%S+)%s*$' +end + +local function isSpace(char) + if char == ' ' + or char == '\n' + or char == '\r' + or char == '\t' then + return true + end + return false +end + +local function skipSpace(text, offset) + for i = offset, 1, -1 do + local char = text:sub(i, i) + if not isSpace(char) then + return i + end + end + return 0 +end + +local function findWord(text, offset) + for i = offset, 1, -1 do + if not text:sub(i, i):match '[%w_]' then + if i == offset then + return nil + end + return text:sub(i+1, offset), i+1 + end + end + return text:sub(1, offset), 1 +end + +local function findSymbol(text, offset) + for i = offset, 1, -1 do + local char = text:sub(i, i) + if isSpace(char) then + goto CONTINUE + end + if char == '.' + or char == ':' + or char == '(' then + return char, i + else + return nil + end + ::CONTINUE:: + end + return nil +end + +local function findAnyPos(text, offset) + for i = offset, 1, -1 do + if not isSpace(text:sub(i, i)) then + return i + end + end + return nil +end + +local function findParent(ast, text, offset) + for i = offset, 1, -1 do + local char = text:sub(i, i) + if isSpace(char) then + goto CONTINUE + end + local oop + if char == '.' then + -- `..` 的情况 + if text:sub(i-1, i-1) == '.' then + return nil, nil + end + oop = false + elseif char == ':' then + oop = true + else + return nil, nil + end + local anyPos = findAnyPos(text, i-1) + if not anyPos then + return nil, nil + end + local parent = guide.eachSourceContain(ast.ast, anyPos, function (source) + if source.finish == anyPos then + return source + end + end) + if parent then + return parent, oop + end + ::CONTINUE:: + end + return nil, nil +end + +local function findParentInStringIndex(ast, text, offset) + local near, nearStart + guide.eachSourceContain(ast.ast, offset, function (source) + local start = guide.getStartFinish(source) + if not start then + return + end + if not nearStart or nearStart < start then + near = source + nearStart = start + end + end) + if not near or near.type ~= 'string' then + return + end + local parent = near.parent + if not parent or parent.index ~= near then + return + end + -- index不可能是oop模式 + return parent.node, false +end + +local function buildFunctionSnip(source, oop) + local name = getName(source):gsub('^.-[$.:]', '') + local defs = vm.getDefs(source, 'deep') + local args = '' + for _, def in ipairs(defs) do + local defArgs = getArg(def, oop) + if defArgs ~= '' then + args = defArgs + break + end + end + local id = 0 + args = args:gsub('[^,]+', function (arg) + id = id + 1 + return arg:gsub('^(%s*)(.+)', function (sp, word) + return ('%s${%d:%s}'):format(sp, id, word) + end) + end) + return ('%s(%s)'):format(name, args) +end + +local function buildDetail(source) + local types = vm.getInferType(source, 'deep') + local literals = vm.getInferLiteral(source, 'deep') + if literals then + return types .. ' = ' .. literals + else + return types + end +end + +local function getSnip(source) + local context = config.config.completion.displayContext + if context <= 0 then + return nil + end + local defs = vm.getRefs(source, 'deep') + for _, def in ipairs(defs) do + def = guide.getObjectValue(def) or def + if def ~= source and def.type == 'function' then + local uri = guide.getUri(def) + local text = files.getText(uri) + local lines = files.getLines(uri) + if not text then + goto CONTINUE + end + if vm.isMetaFile(uri) then + goto CONTINUE + end + local row = guide.positionOf(lines, def.start) + local firstRow = lines[row] + local lastRow = lines[math.min(row + context - 1, #lines)] + local snip = text:sub(firstRow.start, lastRow.finish) + return snip + end + ::CONTINUE:: + end +end + +local function buildDesc(source) + local hover = getHover.get(source) + local md = markdown() + md:add('lua', hover.label) + md:add('md', hover.description) + local snip = getSnip(source) + if snip then + md:add('md', '-------------') + md:add('lua', snip) + end + return md:string() +end + +local function buildFunction(results, source, oop, data) + local snipType = config.config.completion.callSnippet + if snipType == 'Disable' or snipType == 'Both' then + results[#results+1] = data + end + if snipType == 'Both' or snipType == 'Replace' then + local snipData = util.deepCopy(data) + snipData.kind = define.CompletionItemKind.Snippet + snipData.label = snipData.label .. '()' + snipData.insertText = buildFunctionSnip(source, oop) + snipData.insertTextFormat = 2 + snipData.id = stack(function () + return { + detail = buildDetail(source), + description = buildDesc(source), + } + end) + results[#results+1] = snipData + end +end + +local function isSameSource(ast, source, pos) + if not files.eq(guide.getUri(source), guide.getUri(ast.ast)) then + return false + end + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + return source.start <= pos and source.finish >= pos +end + +local function checkLocal(ast, word, offset, results) + local locals = guide.getVisibleLocals(ast.ast, offset) + for name, source in pairs(locals) do + if isSameSource(ast, source, offset) then + goto CONTINUE + end + if not matchKey(word, name) then + goto CONTINUE + end + if vm.hasType(source, 'function') then + buildFunction(results, source, false, { + label = name, + kind = define.CompletionItemKind.Function, + id = stack(function () + return { + detail = buildDetail(source), + description = buildDesc(source), + } + end), + }) + else + results[#results+1] = { + label = name, + kind = define.CompletionItemKind.Variable, + id = stack(function () + return { + detail = buildDetail(source), + description = buildDesc(source), + } + end), + } + end + ::CONTINUE:: + end +end + +local function checkFieldFromFieldToIndex(name, parent, word, start, offset) + if name:match '^[%a_][%w_]*$' then + return nil + end + local textEdit, additionalTextEdits + local uri = guide.getUri(parent) + local text = files.getText(uri) + local wordStart + if word == '' then + wordStart = text:match('()%S', start + 1) or (offset + 1) + else + wordStart = offset - #word + 1 + end + textEdit = { + start = wordStart, + finish = offset, + newText = ('[%q]'):format(name), + } + local nxt = parent.next + if nxt then + local dotStart + if nxt.type == 'setfield' + or nxt.type == 'getfield' + or nxt.type == 'tablefield' then + dotStart = nxt.dot.start + elseif nxt.type == 'setmethod' + or nxt.type == 'getmethod' then + dotStart = nxt.colon.start + end + if dotStart then + additionalTextEdits = { + { + start = dotStart, + finish = dotStart, + newText = '', + } + } + end + else + if config.config.runtime.version == 'Lua 5.1' + or config.config.runtime.version == 'LuaJIT' then + textEdit.newText = '_G' .. textEdit.newText + else + textEdit.newText = '_ENV' .. textEdit.newText + end + end + return textEdit, additionalTextEdits +end + +local function checkFieldThen(name, src, word, start, offset, parent, oop, results) + local value = guide.getObjectValue(src) or src + local kind = define.CompletionItemKind.Field + if value.type == 'function' then + if oop then + kind = define.CompletionItemKind.Method + else + kind = define.CompletionItemKind.Function + end + buildFunction(results, src, oop, { + label = name, + kind = kind, + deprecated = vm.isDeprecated(src) or nil, + id = stack(function () + return { + detail = buildDetail(src), + description = buildDesc(src), + } + end), + }) + return + end + if oop then + return + end + local literal = guide.getLiteral(value) + if literal ~= nil then + kind = define.CompletionItemKind.Enum + end + local textEdit, additionalTextEdits + if parent.next and parent.next.index then + local str = parent.next.index + textEdit = { + start = str.start + #str[2], + finish = offset, + newText = name, + } + else + textEdit, additionalTextEdits = checkFieldFromFieldToIndex(name, parent, word, start, offset) + end + results[#results+1] = { + label = name, + kind = kind, + textEdit = textEdit, + additionalTextEdits = additionalTextEdits, + id = stack(function () + return { + detail = buildDetail(src), + description = buildDesc(src), + } + end) + } +end + +local function checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, isGlobal) + local fields = {} + local count = 0 + for _, src in ipairs(refs) do + local key = vm.getKeyName(src) + if not key or key:sub(1, 1) ~= 's' then + goto CONTINUE + end + if isSameSource(ast, src, start) then + -- 由于fastGlobal的优化,全局变量只会找出一个值,有可能找出自己 + -- 所以遇到自己的时候重新找一下有没有其他定义 + if not isGlobal then + goto CONTINUE + end + if #vm.getGlobals(key) <= 1 then + goto CONTINUE + elseif not vm.isSet(src) then + src = vm.getGlobalSets(key)[1] or src + end + end + local name = key:sub(3) + if locals and locals[name] then + goto CONTINUE + end + if not matchKey(word, name, count >= 100) then + goto CONTINUE + end + local last = fields[name] + if not last then + fields[name] = src + count = count + 1 + goto CONTINUE + end + if src.type == 'tablefield' + or src.type == 'setfield' + or src.type == 'tableindex' + or src.type == 'setindex' + or src.type == 'setmethod' + or src.type == 'setglobal' then + fields[name] = src + goto CONTINUE + end + ::CONTINUE:: + end + for name, src in util.sortPairs(fields) do + checkFieldThen(name, src, word, start, offset, parent, oop, results) + end +end + +local function checkField(ast, word, start, offset, parent, oop, results) + local refs = vm.getFields(parent, 'deep') + checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results) +end + +local function checkGlobal(ast, word, start, offset, parent, oop, results) + local locals = guide.getVisibleLocals(ast.ast, offset) + local refs = vm.getGlobalSets '*' + checkFieldOfRefs(refs, ast, word, start, offset, parent, oop, results, locals, 'global') +end + +local function checkTableField(ast, word, start, results) + local source = guide.eachSourceContain(ast.ast, start, function (source) + if source.start == start + and source.parent + and source.parent.type == 'table' then + return source + end + end) + if not source then + return + end + local used = {} + guide.eachSourceType(ast.ast, 'tablefield', function (src) + if not src.field then + return + end + local key = src.field[1] + if not used[key] + and matchKey(word, key) + and src ~= source then + used[key] = true + results[#results+1] = { + label = key, + kind = define.CompletionItemKind.Property, + } + end + end) +end + +local function checkCommon(word, text, offset, results) + local used = {} + for _, result in ipairs(results) do + used[result.label] = true + end + for _, data in ipairs(keyWordMap) do + used[data[1]] = true + end + for str, pos in text:gmatch '([%a_][%w_]*)()' do + if not used[str] and pos - 1 ~= offset then + used[str] = true + if matchKey(word, str) then + results[#results+1] = { + label = str, + kind = define.CompletionItemKind.Text, + } + end + end + end +end + +local function isInString(ast, offset) + return guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'string' then + return true + end + end) +end + +local function checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results) + local snipType = config.config.completion.keywordSnippet + for _, data in ipairs(keyWordMap) do + local key = data[1] + local eq + if hasSpace then + eq = word == key + else + eq = matchKey(word, key) + end + if afterLocal and key ~= 'function' then + eq = false + end + if eq then + local replaced + local extra + if snipType == 'Both' or snipType == 'Replace' then + local func = data[2] + if func then + replaced = func(hasSpace, results) + extra = true + end + end + if snipType == 'Both' then + replaced = false + end + if not replaced then + if not hasSpace then + local item = { + label = key, + kind = define.CompletionItemKind.Keyword, + } + if extra then + table.insert(results, #results, item) + else + results[#results+1] = item + end + end + end + local checkStop = data[3] + if checkStop then + local stop = checkStop(ast, start) + if stop then + return true + end + end + end + end +end + +local function checkProvideLocal(ast, word, start, results) + local block + guide.eachSourceContain(ast.ast, start, function (source) + if source.type == 'function' + or source.type == 'main' then + block = source + end + end) + if not block then + return + end + local used = {} + guide.eachSourceType(block, 'getglobal', function (source) + if source.start > start + and not used[source[1]] + and matchKey(word, source[1]) then + used[source[1]] = true + results[#results+1] = { + label = source[1], + kind = define.CompletionItemKind.Variable, + } + end + end) + guide.eachSourceType(block, 'getlocal', function (source) + if source.start > start + and not used[source[1]] + and matchKey(word, source[1]) then + used[source[1]] = true + results[#results+1] = { + label = source[1], + kind = define.CompletionItemKind.Variable, + } + end + end) +end + +local function checkFunctionArgByDocParam(ast, word, start, results) + local func = guide.eachSourceContain(ast.ast, start, function (source) + if source.type == 'function' then + return source + end + end) + if not func then + return + end + local docs = func.bindDocs + if not docs then + return + end + local params = {} + for _, doc in ipairs(docs) do + if doc.type == 'doc.param' then + params[#params+1] = doc + end + end + local firstArg = func.args and func.args[1] + if not firstArg + or firstArg.start <= start and firstArg.finish >= start then + local firstParam = params[1] + if firstParam and matchKey(word, firstParam.param[1]) then + local label = {} + for _, param in ipairs(params) do + label[#label+1] = param.param[1] + end + results[#results+1] = { + label = table.concat(label, ', '), + kind = define.CompletionItemKind.Snippet, + } + end + end + for _, doc in ipairs(params) do + if matchKey(word, doc.param[1]) then + results[#results+1] = { + label = doc.param[1], + kind = define.CompletionItemKind.Interface, + } + end + end +end + +local function isAfterLocal(text, start) + local pos = skipSpace(text, start-1) + local word = findWord(text, pos) + return word == 'local' +end + +local function checkUri(ast, text, offset, results) + local collect = {} + local myUri = guide.getUri(ast.ast) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type ~= 'string' then + return + end + local callargs = source.parent + if not callargs or callargs.type ~= 'callargs' then + return + end + if callargs[1] ~= source then + return + end + local call = callargs.parent + local func = call.node + local literal = guide.getLiteral(source) + local libName = vm.getLibraryName(func) + if not libName then + return + end + if libName == 'require' then + for uri in files.eachFile() do + uri = files.getOriginUri(uri) + if files.eq(myUri, uri) then + goto CONTINUE + end + if vm.isMetaFile(uri) then + goto CONTINUE + end + local path = workspace.getRelativePath(uri) + local infos = rpath.getVisiblePath(path, config.config.runtime.path) + for _, info in ipairs(infos) do + if matchKey(literal, info.expect) then + if not collect[info.expect] then + collect[info.expect] = { + textEdit = { + start = source.start + #source[2], + finish = source.finish - #source[2], + } + } + end + collect[info.expect][#collect[info.expect]+1] = ([=[* [%s](%s) %s]=]):format( + path, + uri, + lang.script('HOVER_USE_LUA_PATH', info.searcher) + ) + end + end + ::CONTINUE:: + end + elseif libName == 'dofile' + or libName == 'loadfile' then + for uri in files.eachFile() do + uri = files.getOriginUri(uri) + if files.eq(myUri, uri) then + goto CONTINUE + end + if vm.isMetaFile(uri) then + goto CONTINUE + end + local path = workspace.getRelativePath(uri) + if matchKey(literal, path) then + if not collect[path] then + collect[path] = { + textEdit = { + start = source.start + #source[2], + finish = source.finish - #source[2], + } + } + end + collect[path][#collect[path]+1] = ([=[[%s](%s)]=]):format( + path, + uri + ) + end + ::CONTINUE:: + end + end + end) + for label, infos in util.sortPairs(collect) do + local mark = {} + local des = {} + for _, info in ipairs(infos) do + if not mark[info] then + mark[info] = true + des[#des+1] = info + end + end + results[#results+1] = { + label = label, + kind = define.CompletionItemKind.Reference, + description = table.concat(des, '\n'), + textEdit = infos.textEdit, + } + end +end + +local function checkLenPlusOne(ast, text, offset, results) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'getindex' + or source.type == 'setindex' then + local _, pos = text:find('%s*%[%s*%#', source.node.finish) + if not pos then + return + end + local nodeText = text:sub(source.node.start, source.node.finish) + local writingText = trim(text:sub(pos + 1, offset - 1)) or '' + if not matchKey(writingText, nodeText) then + return + end + if source.parent == guide.getParentBlock(source) then + -- state + local label = text:match('%#[ \t]*', pos) .. nodeText .. '+1' + local eq = text:find('^%s*%]?%s*%=', source.finish) + local newText = label .. ']' + if not eq then + newText = newText .. ' = ' + end + results[#results+1] = { + label = label, + kind = define.CompletionItemKind.Snippet, + textEdit = { + start = pos, + finish = source.finish, + newText = newText, + }, + } + else + -- exp + local label = text:match('%#[ \t]*', pos) .. nodeText + local newText = label .. ']' + results[#results+1] = { + label = label, + kind = define.CompletionItemKind.Snippet, + textEdit = { + start = pos, + finish = source.finish, + newText = newText, + }, + } + end + end + end) +end + +local function isFuncArg(ast, offset) + return guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'funcargs' then + return true + end + end) +end + +local function trySpecial(ast, text, offset, results) + if isInString(ast, offset) then + checkUri(ast, text, offset, results) + return + end + -- x[#x+1] + checkLenPlusOne(ast, text, offset, results) +end + +local function tryIndex(ast, text, offset, results) + local parent, oop = findParentInStringIndex(ast, text, offset) + if not parent then + return + end + local word = parent.next.index[1] + checkField(ast, word, offset, offset, parent, oop, results) +end + +local function tryWord(ast, text, offset, results) + local finish = skipSpace(text, offset) + local word, start = findWord(text, finish) + if not word then + return nil + end + local hasSpace = finish ~= offset + if isInString(ast, offset) then + else + local parent, oop = findParent(ast, text, start - 1) + if parent then + if not hasSpace then + checkField(ast, word, start, offset, parent, oop, results) + end + elseif isFuncArg(ast, offset) then + checkProvideLocal(ast, word, start, results) + checkFunctionArgByDocParam(ast, word, start, results) + else + local afterLocal = isAfterLocal(text, start) + local stop = checkKeyWord(ast, text, start, word, hasSpace, afterLocal, results) + if stop then + return + end + if not hasSpace then + if afterLocal then + checkProvideLocal(ast, word, start, results) + else + checkLocal(ast, word, start, results) + checkTableField(ast, word, start, results) + local env = guide.getENV(ast.ast, start) + checkGlobal(ast, word, start, offset, env, false, results) + end + end + end + if not hasSpace then + checkCommon(word, text, offset, results) + end + end +end + +local function trySymbol(ast, text, offset, results) + local symbol, start = findSymbol(text, offset) + if not symbol then + return nil + end + if isInString(ast, offset) then + return nil + end + if symbol == '.' + or symbol == ':' then + local parent, oop = findParent(ast, text, start) + if parent then + checkField(ast, '', start, offset, parent, oop, results) + end + end + if symbol == '(' then + checkFunctionArgByDocParam(ast, '', start, results) + end +end + +local function getCallEnums(source, index) + if source.type == 'function' and source.bindDocs then + if not source.args then + return + end + local arg + if index <= #source.args then + arg = source.args[index] + else + local lastArg = source.args[#source.args] + if lastArg.type == '...' then + arg = lastArg + else + return + end + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.param' + and doc.param[1] == arg[1] then + local enums = {} + for _, enum in ipairs(vm.getDocEnums(doc.extends)) do + enums[#enums+1] = { + label = enum[1], + description = enum.comment, + kind = define.CompletionItemKind.EnumMember, + } + end + return enums + elseif doc.type == 'doc.vararg' + and arg.type == '...' then + local enums = {} + for _, enum in ipairs(vm.getDocEnums(doc.vararg)) do + enums[#enums+1] = { + label = enum[1], + description = enum.comment, + kind = define.CompletionItemKind.EnumMember, + } + end + return enums + end + end + end +end + +local function tryLabelInString(label, arg) + if not arg or arg.type ~= 'string' then + return label + end + local str = parser:grammar(label, 'String') + if not str then + return label + end + if not matchKey(arg[1], str[1]) then + return nil + end + return util.viewString(str[1], arg[2]) +end + +local function mergeEnums(a, b, text, arg) + local mark = {} + for _, enum in ipairs(a) do + mark[enum.label] = true + end + for _, enum in ipairs(b) do + local label = tryLabelInString(enum.label, arg) + if label and not mark[label] then + mark[label] = true + local result = { + label = label, + kind = define.CompletionItemKind.EnumMember, + description = enum.description, + textEdit = arg and { + start = arg.start, + finish = arg.finish, + newText = label, + }, + } + a[#a+1] = result + end + end +end + +local function findCall(ast, text, offset) + local call + guide.eachSourceContain(ast.ast, offset, function (src) + if src.type == 'call' then + if not call or call.start < src.start then + call = src + end + end + end) + return call +end + +local function getCallArgInfo(call, text, offset) + if not call.args then + return 1, nil + end + for index, arg in ipairs(call.args) do + if arg.start <= offset and arg.finish >= offset then + return index, arg + end + end + return #call.args + 1, nil +end + +local function tryCallArg(ast, text, offset, results) + local call = findCall(ast, text, offset) + if not call then + return + end + local myResults = {} + local argIndex, arg = getCallArgInfo(call, text, offset) + if arg and arg.type == 'function' then + return + end + local defs = vm.getDefs(call.node, 'deep') + for _, def in ipairs(defs) do + def = guide.getObjectValue(def) or def + local enums = getCallEnums(def, argIndex) + if enums then + mergeEnums(myResults, enums, text, arg) + end + end + for _, enum in ipairs(myResults) do + results[#results+1] = enum + end +end + +local function getComment(ast, offset) + for _, comm in ipairs(ast.comms) do + if offset >= comm.start and offset <= comm.finish then + return comm + end + end + return nil +end + +local function tryLuaDocCate(line, results) + local word = line:sub(3) + for _, docType in ipairs { + 'class', + 'type', + 'alias', + 'param', + 'return', + 'field', + 'generic', + 'vararg', + 'overload', + 'deprecated', + 'meta', + 'version', + } do + if matchKey(word, docType) then + results[#results+1] = { + label = docType, + kind = define.CompletionItemKind.Event, + } + end + end +end + +local function getLuaDocByContain(ast, offset) + local result + local range = math.huge + guide.eachSourceContain(ast.ast.docs, offset, function (src) + if not src.start then + return + end + if range >= offset - src.start + and offset <= src.finish then + range = offset - src.start + result = src + end + end) + return result +end + +local function getLuaDocByErr(ast, text, start, offset) + local targetError + for _, err in ipairs(ast.errs) do + if err.finish <= offset + and err.start >= start then + if not text:sub(err.finish + 1, offset):find '%S' then + targetError = err + break + end + end + end + if not targetError then + return nil + end + local targetDoc + for i = #ast.ast.docs, 1, -1 do + local doc = ast.ast.docs[i] + if doc.finish <= targetError.start then + targetDoc = doc + break + end + end + return targetError, targetDoc +end + +local function tryLuaDocBySource(ast, offset, source, results) + if source.type == 'doc.extends.name' then + if source.parent.type == 'doc.class' then + for _, doc in ipairs(vm.getDocTypes '*') do + if doc.type == 'doc.class.name' + and doc.parent ~= source.parent + and matchKey(source[1], doc[1]) then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + end + elseif source.type == 'doc.type.name' then + for _, doc in ipairs(vm.getDocTypes '*') do + if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') + and doc.parent ~= source.parent + and matchKey(source[1], doc[1]) then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + elseif source.type == 'doc.param.name' then + local funcs = {} + guide.eachSourceBetween(ast.ast, offset, math.huge, function (src) + if src.type == 'function' and src.start > offset then + funcs[#funcs+1] = src + end + end) + table.sort(funcs, function (a, b) + return a.start < b.start + end) + local func = funcs[1] + if not func or not func.args then + return + end + for _, arg in ipairs(func.args) do + if arg[1] and matchKey(source[1], arg[1]) then + results[#results+1] = { + label = arg[1], + kind = define.CompletionItemKind.Interface, + } + end + end + end +end + +local function tryLuaDocByErr(ast, offset, err, docState, results) + if err.type == 'LUADOC_MISS_CLASS_EXTENDS_NAME' then + for _, doc in ipairs(vm.getDocTypes '*') do + if doc.type == 'doc.class.name' + and doc.parent ~= docState then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + elseif err.type == 'LUADOC_MISS_TYPE_NAME' then + for _, doc in ipairs(vm.getDocTypes '*') do + if (doc.type == 'doc.class.name' or doc.type == 'doc.alias.name') then + results[#results+1] = { + label = doc[1], + kind = define.CompletionItemKind.Class, + } + end + end + elseif err.type == 'LUADOC_MISS_PARAM_NAME' then + local funcs = {} + guide.eachSourceBetween(ast.ast, offset, math.huge, function (src) + if src.type == 'function' and src.start > offset then + funcs[#funcs+1] = src + end + end) + table.sort(funcs, function (a, b) + return a.start < b.start + end) + local func = funcs[1] + if not func or not func.args then + return + end + local label = {} + local insertText = {} + for i, arg in ipairs(func.args) do + if arg[1] then + label[#label+1] = arg[1] + if i == 1 then + insertText[i] = ('%s ${%d:any}'):format(arg[1], i) + else + insertText[i] = ('---@param %s ${%d:any}'):format(arg[1], i) + end + end + end + results[#results+1] = { + label = table.concat(label, ', '), + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = table.concat(insertText, '\n'), + } + for i, arg in ipairs(func.args) do + if arg[1] then + results[#results+1] = { + label = arg[1], + kind = define.CompletionItemKind.Interface, + } + end + end + end +end + +local function tryLuaDocFeatures(line, ast, comm, offset, results) +end + +local function tryLuaDoc(ast, text, offset, results) + local comm = getComment(ast, offset) + local line = text:sub(comm.start, offset) + if not line then + return + end + if line:sub(1, 2) ~= '-@' then + return + end + -- 尝试 ---@$ + local cate = line:match('%a*', 3) + if #cate + 2 >= #line then + tryLuaDocCate(line, results) + return + end + -- 尝试一些其他特征 + if tryLuaDocFeatures(line, ast, comm, offset, results) then + return + end + -- 根据输入中的source来补全 + local source = getLuaDocByContain(ast, offset) + if source then + tryLuaDocBySource(ast, offset, source, results) + return + end + -- 根据附近的错误消息来补全 + local err, doc = getLuaDocByErr(ast, text, comm.start, offset) + if err then + tryLuaDocByErr(ast, offset, err, doc, results) + return + end +end + +local function completion(uri, offset) + local ast = files.getAst(uri) + local text = files.getText(uri) + local results = {} + clearStack() + if ast then + if getComment(ast, offset) then + tryLuaDoc(ast, text, offset, results) + else + trySpecial(ast, text, offset, results) + tryWord(ast, text, offset, results) + tryIndex(ast, text, offset, results) + trySymbol(ast, text, offset, results) + tryCallArg(ast, text, offset, results) + end + else + local word = findWord(text, offset) + if word then + checkCommon(word, text, offset, results) + end + end + + if #results == 0 then + return nil + end + return results +end + +local function resolve(id) + return resolveStack(id) +end + +return { + completion = completion, + resolve = resolve, +} diff --git a/script/core/definition.lua b/script/core/definition.lua new file mode 100644 index 00000000..c143939d --- /dev/null +++ b/script/core/definition.lua @@ -0,0 +1,156 @@ +local guide = require 'parser.guide' +local workspace = require 'workspace' +local files = require 'files' +local vm = require 'vm' +local findSource = require 'core.find-source' + +local function sortResults(results) + -- 先按照顺序排序 + table.sort(results, function (a, b) + local u1 = guide.getUri(a.target) + local u2 = guide.getUri(b.target) + if u1 == u2 then + return a.target.start < b.target.start + else + return u1 < u2 + end + end) + -- 如果2个结果处于嵌套状态,则取范围小的那个 + local lf, lu + for i = #results, 1, -1 do + local res = results[i].target + local f = res.finish + local uri = guide.getUri(res) + if lf and f > lf and uri == lu then + table.remove(results, i) + else + lu = uri + lf = f + end + end +end + +local accept = { + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['label'] = true, + ['goto'] = true, + ['field'] = true, + ['method'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['string'] = true, + ['boolean'] = true, + ['number'] = true, + + ['doc.type.name'] = true, + ['doc.class.name'] = true, + ['doc.extends.name'] = true, + ['doc.alias.name'] = true, +} + +local function checkRequire(source, offset) + if source.type ~= 'string' then + return nil + end + local callargs = source.parent + if callargs.type ~= 'callargs' then + return + end + if callargs[1] ~= source then + return + end + local call = callargs.parent + local func = call.node + local literal = guide.getLiteral(source) + local libName = vm.getLibraryName(func) + if not libName then + return nil + end + if libName == 'require' then + return workspace.findUrisByRequirePath(literal) + elseif libName == 'dofile' + or libName == 'loadfile' then + return workspace.findUrisByFilePath(literal) + end + return nil +end + +local function convertIndex(source) + if not source then + return + end + if source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' then + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + return parent + end + end + return source +end + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local source = convertIndex(findSource(ast, offset, accept)) + if not source then + return nil + end + + local results = {} + local uris = checkRequire(source) + if uris then + for i, uri in ipairs(uris) do + results[#results+1] = { + uri = files.getOriginUri(uri), + source = source, + target = { + start = 0, + finish = 0, + uri = uri, + } + } + end + end + + for _, src in ipairs(vm.getDefs(source, 'deep')) do + local root = guide.getRoot(src) + if not root then + goto CONTINUE + end + src = src.field or src.method or src.index or src + if src.type == 'table' and src.parent.type ~= 'return' then + goto CONTINUE + end + if src.type == 'doc.class.name' + and source.type ~= 'doc.type.name' + and source.type ~= 'doc.extends.name' then + goto CONTINUE + end + results[#results+1] = { + target = src, + uri = files.getOriginUri(root.uri), + source = source, + } + ::CONTINUE:: + end + + if #results == 0 then + return nil + end + + sortResults(results) + + return results +end diff --git a/script/core/diagnostics/ambiguity-1.lua b/script/core/diagnostics/ambiguity-1.lua new file mode 100644 index 00000000..37815fb5 --- /dev/null +++ b/script/core/diagnostics/ambiguity-1.lua @@ -0,0 +1,69 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +local opMap = { + ['+'] = true, + ['-'] = true, + ['*'] = true, + ['/'] = true, + ['//'] = true, + ['^'] = true, + ['<<'] = true, + ['>>'] = true, + ['&'] = true, + ['|'] = true, + ['~'] = true, + ['..'] = true, +} + +local literalMap = { + ['number'] = true, + ['boolean'] = true, + ['string'] = true, + ['table'] = true, +} + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local text = files.getText(uri) + guide.eachSourceType(ast.ast, 'binary', function (source) + if source.op.type ~= 'or' then + return + end + local first = source[1] + local second = source[2] + -- a + (b or 0) --> (a + b) or 0 + do + if opMap[first.op and first.op.type] + and first.type ~= 'unary' + and not second.op + and literalMap[second.type] + and not literalMap[first[2].type] + then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_AMBIGUITY_1', text:sub(first.start, first.finish)) + } + end + end + -- (a or 0) + c --> a or (0 + c) + do + if opMap[second.op and second.op.type] + and second.type ~= 'unary' + and not first.op + and literalMap[second[1].type] + then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_AMBIGUITY_1', text:sub(second.start, second.finish)) + } + end + end + end) +end diff --git a/script/core/diagnostics/circle-doc-class.lua b/script/core/diagnostics/circle-doc-class.lua new file mode 100644 index 00000000..55179447 --- /dev/null +++ b/script/core/diagnostics/circle-doc-class.lua @@ -0,0 +1,54 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.class' then + if not doc.extends then + goto CONTINUE + end + local myName = guide.getName(doc) + local list = { doc } + local mark = {} + for i = 1, 999 do + local current = list[i] + if not current then + goto CONTINUE + end + if current.extends then + local newName = current.extends[1] + if newName == myName then + callback { + start = doc.start, + finish = doc.finish, + message = lang.script('DIAG_CIRCLE_DOC_CLASS', myName) + } + goto CONTINUE + end + if not mark[newName] then + mark[newName] = true + local docs = vm.getDocTypes(newName) + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' then + list[#list+1] = otherDoc.parent + end + end + end + end + end + ::CONTINUE:: + end + end +end diff --git a/script/core/diagnostics/code-after-break.lua b/script/core/diagnostics/code-after-break.lua new file mode 100644 index 00000000..a2bac8a4 --- /dev/null +++ b/script/core/diagnostics/code-after-break.lua @@ -0,0 +1,34 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + local mark = {} + guide.eachSourceType(state.ast, 'break', function (source) + local list = source.parent + if mark[list] then + return + end + mark[list] = true + for i = #list, 1, -1 do + local src = list[i] + if src == source then + if i == #list then + return + end + callback { + start = list[i+1].start, + finish = list[#list].range or list[#list].finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_CODE_AFTER_BREAK, + } + end + end + end) +end diff --git a/script/core/diagnostics/doc-field-no-class.lua b/script/core/diagnostics/doc-field-no-class.lua new file mode 100644 index 00000000..f27bbb32 --- /dev/null +++ b/script/core/diagnostics/doc-field-no-class.lua @@ -0,0 +1,41 @@ +local files = require 'files' +local lang = require 'language' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type ~= 'doc.field' then + goto CONTINUE + end + local bindGroup = doc.bindGroup + if not bindGroup then + goto CONTINUE + end + local ok + for _, other in ipairs(bindGroup) do + if other.type == 'doc.class' then + ok = true + break + end + if other == doc then + break + end + end + if not ok then + callback { + start = doc.start, + finish = doc.finish, + message = lang.script('DIAG_DOC_FIELD_NO_CLASS'), + } + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/duplicate-doc-class.lua b/script/core/diagnostics/duplicate-doc-class.lua new file mode 100644 index 00000000..259c048b --- /dev/null +++ b/script/core/diagnostics/duplicate-doc-class.lua @@ -0,0 +1,46 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + local cache = {} + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.class' + or doc.type == 'doc.alias' then + local name = guide.getName(doc) + if not cache[name] then + local docs = vm.getDocTypes(name) + cache[name] = {} + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' + or otherDoc.type == 'doc.alias.name' then + cache[name][#cache[name]+1] = { + start = otherDoc.start, + finish = otherDoc.finish, + uri = guide.getUri(otherDoc), + } + end + end + end + if #cache[name] > 1 then + callback { + start = doc.start, + finish = doc.finish, + related = cache, + message = lang.script('DIAG_DUPLICATE_DOC_CLASS', name) + } + end + end + end +end diff --git a/script/core/diagnostics/duplicate-doc-field.lua b/script/core/diagnostics/duplicate-doc-field.lua new file mode 100644 index 00000000..b621fd9e --- /dev/null +++ b/script/core/diagnostics/duplicate-doc-field.lua @@ -0,0 +1,34 @@ +local files = require 'files' +local lang = require 'language' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + local mark + for _, group in ipairs(state.ast.docs.groups) do + for _, doc in ipairs(group) do + if doc.type == 'doc.class' then + mark = {} + elseif doc.type == 'doc.field' then + if mark then + local name = doc.field[1] + if mark[name] then + callback { + start = doc.field.start, + finish = doc.field.finish, + message = lang.script('DIAG_DUPLICATE_DOC_FIELD', name), + } + end + mark[name] = true + end + end + end + end +end diff --git a/script/core/diagnostics/duplicate-doc-param.lua b/script/core/diagnostics/duplicate-doc-param.lua new file mode 100644 index 00000000..676a6fb4 --- /dev/null +++ b/script/core/diagnostics/duplicate-doc-param.lua @@ -0,0 +1,37 @@ +local files = require 'files' +local lang = require 'language' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type ~= 'doc.param' then + goto CONTINUE + end + local name = doc.param[1] + local bindGroup = doc.bindGroup + if not bindGroup then + goto CONTINUE + end + for _, other in ipairs(bindGroup) do + if other ~= doc + and other.type == 'doc.param' + and other.param[1] == name then + callback { + start = doc.param.start, + finish = doc.param.finish, + message = lang.script('DIAG_DUPLICATE_DOC_PARAM', name) + } + goto CONTINUE + end + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/duplicate-index.lua b/script/core/diagnostics/duplicate-index.lua new file mode 100644 index 00000000..dabe1b3c --- /dev/null +++ b/script/core/diagnostics/duplicate-index.lua @@ -0,0 +1,63 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'table', function (source) + local mark = {} + for _, obj in ipairs(source) do + if obj.type == 'tablefield' + or obj.type == 'tableindex' then + local name = vm.getKeyName(obj) + if name then + if not mark[name] then + mark[name] = {} + end + mark[name][#mark[name]+1] = obj.field or obj.index + end + end + end + + for name, defs in pairs(mark) do + local sname = name:match '^.|(.+)$' + if #defs > 1 and sname then + local related = {} + for i = 1, #defs do + local def = defs[i] + related[i] = { + start = def.start, + finish = def.finish, + uri = uri, + } + end + for i = 1, #defs - 1 do + local def = defs[i] + callback { + start = def.start, + finish = def.finish, + related = related, + message = lang.script('DIAG_DUPLICATE_INDEX', sname), + level = define.DiagnosticSeverity.Hint, + tags = { define.DiagnosticTag.Unnecessary }, + } + end + for i = #defs, #defs do + local def = defs[i] + callback { + start = def.start, + finish = def.finish, + related = related, + message = lang.script('DIAG_DUPLICATE_INDEX', sname), + } + end + end + end + end) +end diff --git a/script/core/diagnostics/empty-block.lua b/script/core/diagnostics/empty-block.lua new file mode 100644 index 00000000..2024f4e3 --- /dev/null +++ b/script/core/diagnostics/empty-block.lua @@ -0,0 +1,49 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' + +-- 检查空代码块 +-- 但是排除忙等待(repeat/while) +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'if', function (source) + for _, block in ipairs(source) do + if #block > 0 then + return + end + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) + guide.eachSourceType(ast.ast, 'loop', function (source) + if #source > 0 then + return + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) + guide.eachSourceType(ast.ast, 'in', function (source) + if #source > 0 then + return + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_EMPTY_BLOCK, + } + end) +end diff --git a/script/core/diagnostics/global-in-nil-env.lua b/script/core/diagnostics/global-in-nil-env.lua new file mode 100644 index 00000000..9a0d4f35 --- /dev/null +++ b/script/core/diagnostics/global-in-nil-env.lua @@ -0,0 +1,66 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +-- TODO: 检查路径是否可达 +local function mayRun(path) + return true +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local root = guide.getRoot(ast.ast) + local env = guide.getENV(root) + + local nilDefs = {} + if not env.ref then + return + end + for _, ref in ipairs(env.ref) do + if ref.type == 'setlocal' then + if ref.value and ref.value.type == 'nil' then + nilDefs[#nilDefs+1] = ref + end + end + end + + if #nilDefs == 0 then + return + end + + local function check(source) + local node = source.node + if node.tag == '_ENV' then + local ok + for _, nilDef in ipairs(nilDefs) do + local mode, pathA = guide.getPath(nilDef, source) + if mode == 'before' + and mayRun(pathA) then + ok = nilDef + break + end + end + if ok then + callback { + start = source.start, + finish = source.finish, + uri = uri, + message = lang.script.DIAG_GLOBAL_IN_NIL_ENV, + related = { + { + start = ok.start, + finish = ok.finish, + uri = uri, + } + } + } + end + end + end + + guide.eachSourceType(ast.ast, 'getglobal', check) + guide.eachSourceType(ast.ast, 'setglobal', check) +end diff --git a/script/core/diagnostics/init.lua b/script/core/diagnostics/init.lua new file mode 100644 index 00000000..a6b61e12 --- /dev/null +++ b/script/core/diagnostics/init.lua @@ -0,0 +1,56 @@ +local files = require 'files' +local define = require 'proto.define' +local config = require 'config' +local await = require 'await' + +-- 把耗时最长的诊断放到最后面 +local diagLevel = { + ['redundant-parameter'] = 100, +} + +local diagList = {} +for k in pairs(define.DiagnosticDefaultSeverity) do + diagList[#diagList+1] = k +end +table.sort(diagList, function (a, b) + return (diagLevel[a] or 0) < (diagLevel[b] or 0) +end) + +local function check(uri, name, results) + if config.config.diagnostics.disable[name] then + return + end + local level = config.config.diagnostics.severity[name] + or define.DiagnosticDefaultSeverity[name] + if level == 'Hint' and not files.isOpen(uri) then + return + end + local severity = define.DiagnosticSeverity[level] + local clock = os.clock() + require('core.diagnostics.' .. name)(uri, function (result) + result.level = severity or result.level + result.code = name + results[#results+1] = result + end, name) + local passed = os.clock() - clock + if passed >= 0.5 then + log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed)) + end +end + +return function (uri, response) + local vm = require 'vm' + local ast = files.getAst(uri) + if not ast then + return nil + end + + local isOpen = files.isOpen(uri) + + for _, name in ipairs(diagList) do + await.delay() + local results = {} + check(uri, name, results) + response(results) + end +end diff --git a/script/core/diagnostics/lowercase-global.lua b/script/core/diagnostics/lowercase-global.lua new file mode 100644 index 00000000..fe5d1eca --- /dev/null +++ b/script/core/diagnostics/lowercase-global.lua @@ -0,0 +1,56 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local config = require 'config' +local vm = require 'vm' + +local function isDocClass(source) + if not source.bindDocs then + return false + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.class' then + return true + end + end + return false +end + +-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删) +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + local definedGlobal = {} + for name in pairs(config.config.diagnostics.globals) do + definedGlobal[name] = true + end + + guide.eachSourceType(ast.ast, 'setglobal', function (source) + local name = guide.getName(source) + if definedGlobal[name] then + return + end + local first = name:match '%w' + if not first then + return + end + if not first:match '%l' then + return + end + -- 如果赋值被标记为 doc.class ,则认为是允许的 + if isDocClass(source) then + return + end + if vm.isGlobalLibraryName(name) then + return + end + callback { + start = source.start, + finish = source.finish, + message = lang.script.DIAG_LOWERCASE_GLOBAL, + } + end) +end diff --git a/script/core/diagnostics/newfield-call.lua b/script/core/diagnostics/newfield-call.lua new file mode 100644 index 00000000..75681cbc --- /dev/null +++ b/script/core/diagnostics/newfield-call.lua @@ -0,0 +1,37 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + local lines = files.getLines(uri) + local text = files.getText(uri) + + guide.eachSourceType(ast.ast, 'table', function (source) + for i = 1, #source do + local field = source[i] + if field.type == 'call' then + local func = field.node + local args = field.args + if args then + local funcLine = guide.positionOf(lines, func.finish) + local argsLine = guide.positionOf(lines, args.start) + if argsLine > funcLine then + callback { + start = field.start, + finish = field.finish, + message = lang.script('DIAG_PREFIELD_CALL' + , text:sub(func.start, func.finish) + , text:sub(args.start, args.finish) + ) + } + end + end + end + end + end) +end diff --git a/script/core/diagnostics/newline-call.lua b/script/core/diagnostics/newline-call.lua new file mode 100644 index 00000000..cb318380 --- /dev/null +++ b/script/core/diagnostics/newline-call.lua @@ -0,0 +1,38 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local lines = files.getLines(uri) + + guide.eachSourceType(ast.ast, 'call', function (source) + local node = source.node + local args = source.args + if not args then + return + end + + -- 必须有其他人在继续使用当前对象 + if not source.next then + return + end + + local nodeRow = guide.positionOf(lines, node.finish) + local argRow = guide.positionOf(lines, args.start) + if nodeRow == argRow then + return + end + + if #args == 1 then + callback { + start = args.start, + finish = args.finish, + message = lang.script.DIAG_PREVIOUS_CALL, + } + end + end) +end diff --git a/script/core/diagnostics/redefined-local.lua b/script/core/diagnostics/redefined-local.lua new file mode 100644 index 00000000..5e53d837 --- /dev/null +++ b/script/core/diagnostics/redefined-local.lua @@ -0,0 +1,32 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'local', function (source) + local name = source[1] + if name == '_' + or name == ast.ENVMode then + return + end + local exist = guide.getLocal(source, name, source.start-1) + if exist then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_REDEFINED_LOCAL', name), + related = { + { + start = exist.start, + finish = exist.finish, + uri = uri, + } + }, + } + end + end) +end diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua new file mode 100644 index 00000000..2fae20e8 --- /dev/null +++ b/script/core/diagnostics/redundant-parameter.lua @@ -0,0 +1,82 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' +local define = require 'proto.define' +local await = require 'await' + +local function countCallArgs(source) + local result = 0 + if not source.args then + return 0 + end + if source.node and source.node.type == 'getmethod' then + result = result + 1 + end + result = result + #source.args + return result +end + +local function countFuncArgs(source) + local result = 0 + if source.parent and source.parent.type == 'setmethod' then + result = result + 1 + end + if not source.args then + return result + end + if source.args[#source.args].type == '...' then + return math.maxinteger + end + result = result + #source.args + return result +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'call', function (source) + local callArgs = countCallArgs(source) + if callArgs == 0 then + return + end + + local func = source.node + local funcArgs + local defs = vm.getDefs(func) + for _, def in ipairs(defs) do + if def.value then + def = def.value + end + if def.type == 'function' then + local args = countFuncArgs(def) + if not funcArgs or args > funcArgs then + funcArgs = args + end + end + end + + if not funcArgs then + return + end + + local delta = callArgs - funcArgs + if delta <= 0 then + return + end + for i = #source.args - delta + 1, #source.args do + local arg = source.args[i] + if arg then + callback { + start = arg.start, + finish = arg.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs) + } + end + end + end) +end diff --git a/script/core/diagnostics/redundant-value.lua b/script/core/diagnostics/redundant-value.lua new file mode 100644 index 00000000..be483448 --- /dev/null +++ b/script/core/diagnostics/redundant-value.lua @@ -0,0 +1,24 @@ +local files = require 'files' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback, code) + local ast = files.getAst(uri) + if not ast then + return + end + + local diags = ast.diags[code] + if not diags then + return + end + + for _, info in ipairs(diags) do + callback { + start = info.start, + finish = info.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_OVER_MAX_VALUES', info.max, info.passed) + } + end +end diff --git a/script/core/diagnostics/trailing-space.lua b/script/core/diagnostics/trailing-space.lua new file mode 100644 index 00000000..e54a6e60 --- /dev/null +++ b/script/core/diagnostics/trailing-space.lua @@ -0,0 +1,55 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' + +local function isInString(ast, offset) + local result = false + guide.eachSourceType(ast, 'string', function (source) + if offset >= source.start and offset <= source.finish then + result = true + end + end) + return result +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + local text = files.getText(uri) + local lines = files.getLines(uri) + for i = 1, #lines do + local start = lines[i].start + local range = lines[i].range + local lastChar = text:sub(range, range) + if lastChar ~= ' ' and lastChar ~= '\t' then + goto NEXT_LINE + end + if isInString(ast.ast, range) then + goto NEXT_LINE + end + local first = start + for n = range - 1, start, -1 do + local char = text:sub(n, n) + if char ~= ' ' and char ~= '\t' then + first = n + 1 + break + end + end + if first == start then + callback { + start = first, + finish = range, + message = lang.script.DIAG_LINE_ONLY_SPACE, + } + else + callback { + start = first, + finish = range, + message = lang.script.DIAG_LINE_POST_SPACE, + } + end + ::NEXT_LINE:: + end +end diff --git a/script/core/diagnostics/undefined-doc-class.lua b/script/core/diagnostics/undefined-doc-class.lua new file mode 100644 index 00000000..bbfdceec --- /dev/null +++ b/script/core/diagnostics/undefined-doc-class.lua @@ -0,0 +1,46 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + local cache = {} + for _, doc in ipairs(state.ast.docs) do + if doc.type == 'doc.class' then + local ext = doc.extends + if not ext then + goto CONTINUE + end + local name = ext[1] + local docs = vm.getDocTypes(name) + if cache[name] == nil then + cache[name] = false + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' then + cache[name] = true + break + end + end + end + if not cache[name] then + callback { + start = ext.start, + finish = ext.finish, + related = cache, + message = lang.script('DIAG_UNDEFINED_DOC_CLASS', name) + } + end + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua new file mode 100644 index 00000000..5c1e8fbf --- /dev/null +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -0,0 +1,60 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +local function hasNameOfClassOrAlias(name) + local docs = vm.getDocTypes(name) + for _, otherDoc in ipairs(docs) do + if otherDoc.type == 'doc.class.name' + or otherDoc.type == 'doc.alias.name' then + return true + end + end + return false +end + +local function hasNameOfGeneric(name, source) + if not source.typeGeneric then + return false + end + if not source.typeGeneric[name] then + return false + end + return true +end + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + guide.eachSource(state.ast.docs, function (source) + if source.type ~= 'doc.extends.name' + and source.type ~= 'doc.type.name' then + return + end + if source.parent.type == 'doc.class' then + return + end + local name = source[1] + if name == '...' then + return + end + if hasNameOfClassOrAlias(name) + or hasNameOfGeneric(name, source) then + return + end + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_UNDEFINED_DOC_NAME', name) + } + end) +end diff --git a/script/core/diagnostics/undefined-doc-param.lua b/script/core/diagnostics/undefined-doc-param.lua new file mode 100644 index 00000000..af3e07bc --- /dev/null +++ b/script/core/diagnostics/undefined-doc-param.lua @@ -0,0 +1,52 @@ +local files = require 'files' +local guide = require 'parser.guide' +local lang = require 'language' +local define = require 'proto.define' +local vm = require 'vm' + +local function hasParamName(func, name) + if not func.args then + return false + end + for _, arg in ipairs(func.args) do + if arg[1] == name then + return true + end + end + return false +end + +return function (uri, callback) + local state = files.getAst(uri) + if not state then + return + end + + if not state.ast.docs then + return + end + + for _, doc in ipairs(state.ast.docs) do + if doc.type ~= 'doc.param' then + goto CONTINUE + end + local binds = doc.bindSources + if not binds then + goto CONTINUE + end + local param = doc.param + local name = param[1] + for _, source in ipairs(binds) do + if source.type == 'function' then + if not hasParamName(source, name) then + callback { + start = param.start, + finish = param.finish, + message = lang.script('DIAG_UNDEFINED_DOC_PARAM', name) + } + end + end + end + ::CONTINUE:: + end +end diff --git a/script/core/diagnostics/undefined-env-child.lua b/script/core/diagnostics/undefined-env-child.lua new file mode 100644 index 00000000..6b8c62f0 --- /dev/null +++ b/script/core/diagnostics/undefined-env-child.lua @@ -0,0 +1,27 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'getglobal', function (source) + -- 单独验证自己是否在重载过的 _ENV 中有定义 + if source.node.tag == '_ENV' then + return + end + local defs = guide.requestDefinition(source) + if #defs > 0 then + return + end + local key = source[1] + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_UNDEF_ENV_CHILD', key), + } + end) +end diff --git a/script/core/diagnostics/undefined-global.lua b/script/core/diagnostics/undefined-global.lua new file mode 100644 index 00000000..778fc1f1 --- /dev/null +++ b/script/core/diagnostics/undefined-global.lua @@ -0,0 +1,40 @@ +local files = require 'files' +local vm = require 'vm' +local lang = require 'language' +local config = require 'config' +local guide = require 'parser.guide' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + -- 遍历全局变量,检查所有没有 set 模式的全局变量 + guide.eachSourceType(ast.ast, 'getglobal', function (src) + local key = guide.getName(src) + if not key then + return + end + if config.config.diagnostics.globals[key] then + return + end + if #vm.getGlobalSets(guide.getKeyName(src)) > 0 then + return + end + local message = lang.script('DIAG_UNDEF_GLOBAL', key) + -- TODO check other version + local otherVersion + local customVersion + if otherVersion then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_VERSION', table.concat(otherVersion, '/'), config.config.runtime.version)) + elseif customVersion then + message = ('%s(%s)'):format(message, lang.script('DIAG_DEFINED_CUSTOM', table.concat(customVersion, '/'))) + end + callback { + start = src.start, + finish = src.finish, + message = message, + } + end) +end diff --git a/script/core/diagnostics/unused-function.lua b/script/core/diagnostics/unused-function.lua new file mode 100644 index 00000000..f0bca613 --- /dev/null +++ b/script/core/diagnostics/unused-function.lua @@ -0,0 +1,40 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local define = require 'proto.define' +local lang = require 'language' +local await = require 'await' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + -- 只检查局部函数 + guide.eachSourceType(ast.ast, 'function', function (source) + local parent = source.parent + if not parent then + return + end + if parent.type ~= 'local' + and parent.type ~= 'setlocal' then + return + end + local hasGet + local refs = vm.getRefs(source) + for _, src in ipairs(refs) do + if vm.isGet(src) then + hasGet = true + break + end + end + if not hasGet then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_FUNCTION, + } + end + end) +end diff --git a/script/core/diagnostics/unused-label.lua b/script/core/diagnostics/unused-label.lua new file mode 100644 index 00000000..e6d998ba --- /dev/null +++ b/script/core/diagnostics/unused-label.lua @@ -0,0 +1,22 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'label', function (source) + if not source.ref then + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LABEL', source[1]), + } + end + end) +end diff --git a/script/core/diagnostics/unused-local.lua b/script/core/diagnostics/unused-local.lua new file mode 100644 index 00000000..873a70f2 --- /dev/null +++ b/script/core/diagnostics/unused-local.lua @@ -0,0 +1,93 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +local function hasGet(loc) + if not loc.ref then + return false + end + local weak + for _, ref in ipairs(loc.ref) do + if ref.type == 'getlocal' then + if not ref.next then + return 'strong' + end + local nextType = ref.next.type + if nextType ~= 'setmethod' + and nextType ~= 'setfield' + and nextType ~= 'setindex' then + return 'strong' + else + weak = true + end + end + end + if weak then + return 'weak' + else + return nil + end +end + +local function isMyTable(loc) + local value = loc.value + if value and value.type == 'table' then + return true + end + return false +end + +local function isClose(source) + if not source.attrs then + return false + end + for _, attr in ipairs(source.attrs) do + if attr[1] == 'close' then + return true + end + end + return false +end + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + guide.eachSourceType(ast.ast, 'local', function (source) + local name = source[1] + if name == '_' + or name == ast.ENVMode then + return + end + if isClose(source) then + return + end + local data = hasGet(source) + if data == 'strong' then + return + end + if data == 'weak' then + if not isMyTable(source) then + return + end + end + callback { + start = source.start, + finish = source.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LOCAL', name), + } + if source.ref then + for _, ref in ipairs(source.ref) do + callback { + start = ref.start, + finish = ref.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script('DIAG_UNUSED_LOCAL', name), + } + end + end + end) +end diff --git a/script/core/diagnostics/unused-vararg.lua b/script/core/diagnostics/unused-vararg.lua new file mode 100644 index 00000000..74cc08e7 --- /dev/null +++ b/script/core/diagnostics/unused-vararg.lua @@ -0,0 +1,31 @@ +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local lang = require 'language' + +return function (uri, callback) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSourceType(ast.ast, 'function', function (source) + local args = source.args + if not args then + return + end + + for _, arg in ipairs(args) do + if arg.type == '...' then + if not arg.ref then + callback { + start = arg.start, + finish = arg.finish, + tags = { define.DiagnosticTag.Unnecessary }, + message = lang.script.DIAG_UNUSED_VARARG, + } + end + end + end + end) +end diff --git a/script/core/document-symbol.lua b/script/core/document-symbol.lua new file mode 100644 index 00000000..7392b337 --- /dev/null +++ b/script/core/document-symbol.lua @@ -0,0 +1,307 @@ +local await = require 'await' +local files = require 'files' +local guide = require 'parser.guide' +local define = require 'proto.define' +local util = require 'utility' + +local function buildName(source, text) + if source.type == 'setmethod' + or source.type == 'getmethod' then + if source.method then + return text:sub(source.start, source.method.finish) + end + end + if source.type == 'setfield' + or source.type == 'tablefield' + or source.type == 'getfield' then + if source.field then + return text:sub(source.start, source.field.finish) + end + end + return text:sub(source.start, source.finish) +end + +local function buildFunctionParams(func) + if not func.args then + return '' + end + local params = {} + for i, arg in ipairs(func.args) do + if arg.type == '...' then + params[i] = '...' + else + params[i] = arg[1] or '' + end + end + return table.concat(params, ', ') +end + +local function buildFunction(source, text, symbols) + local name = buildName(source, text) + local func = source.value + if source.type == 'tablefield' + or source.type == 'setfield' then + source = source.field + if not source then + return + end + end + local range, kind + if func.start > source.finish then + -- a = function() + range = { source.start, func.finish } + else + -- function f() + range = { func.start, func.finish } + end + if source.type == 'setmethod' then + kind = define.SymbolKind.Method + else + kind = define.SymbolKind.Function + end + symbols[#symbols+1] = { + name = name, + detail = ('function (%s)'):format(buildFunctionParams(func)), + kind = kind, + range = range, + selectionRange = { source.start, source.finish }, + valueRange = { func.start, func.finish }, + } +end + +local function buildTable(tbl) + local buf = {} + for i = 1, 3 do + local field = tbl[i] + if not field then + break + end + if field.type == 'tablefield' then + buf[i] = ('%s'):format(field.field[1]) + end + end + return table.concat(buf, ', ') +end + +local function buildValue(source, text, symbols) + local name = buildName(source, text) + local range, sRange, valueRange, kind + local details = {} + if source.type == 'local' then + if source.parent.type == 'funcargs' then + details[1] = 'param' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Constant + else + details[1] = 'local' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Variable + end + elseif source.type == 'setlocal' then + details[1] = 'setlocal' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Variable + elseif source.type == 'setglobal' then + details[1] = 'global' + range = { source.start, source.finish } + sRange = { source.start, source.finish } + kind = define.SymbolKind.Class + elseif source.type == 'tablefield' then + if not source.field then + return + end + details[1] = 'field' + range = { source.field.start, source.field.finish } + sRange = { source.field.start, source.field.finish } + kind = define.SymbolKind.Property + elseif source.type == 'setfield' then + if not source.field then + return + end + details[1] = 'field' + range = { source.field.start, source.field.finish } + sRange = { source.field.start, source.field.finish } + kind = define.SymbolKind.Field + else + return + end + if source.value then + local literal = source.value[1] + if source.value.type == 'boolean' then + details[2] = ' boolean' + if literal ~= nil then + details[3] = ' = ' + details[4] = util.viewLiteral(source.value[1]) + end + elseif source.value.type == 'string' then + details[2] = ' string' + if literal ~= nil then + details[3] = ' = ' + details[4] = util.viewLiteral(source.value[1]) + end + elseif source.value.type == 'number' then + details[2] = ' number' + if literal ~= nil then + details[3] = ' = ' + details[4] = util.viewLiteral(source.value[1]) + end + elseif source.value.type == 'table' then + details[2] = ' {' + details[3] = buildTable(source.value) + details[4] = '}' + valueRange = { source.value.start, source.value.finish } + elseif source.value.type == 'select' then + if source.value.vararg and source.value.vararg.type == 'call' then + valueRange = { source.value.start, source.value.finish } + end + end + range = { range[1], source.value.finish } + end + symbols[#symbols+1] = { + name = name, + detail = table.concat(details), + kind = kind, + range = range, + selectionRange = sRange, + valueRange = valueRange, + } +end + +local function buildSet(source, text, used, symbols) + local value = source.value + if value and value.type == 'function' then + used[value] = true + buildFunction(source, text, symbols) + else + buildValue(source, text, symbols) + end +end + +local function buildAnonymousFunction(source, text, used, symbols) + if used[source] then + return + end + used[source] = true + local head = '' + local parent = source.parent + if parent.type == 'return' then + head = 'return ' + elseif parent.type == 'callargs' then + local call = parent.parent + local node = call.node + head = buildName(node, text) .. ' -> ' + end + symbols[#symbols+1] = { + name = '', + detail = ('%sfunction (%s)'):format(head, buildFunctionParams(source)), + kind = define.SymbolKind.Function, + range = { source.start, source.finish }, + selectionRange = { source.start, source.start }, + valueRange = { source.start, source.finish }, + } +end + +local function buildSource(source, text, used, symbols) + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'setfield' + or source.type == 'setmethod' + or source.type == 'tablefield' then + await.delay() + buildSet(source, text, used, symbols) + elseif source.type == 'function' then + await.delay() + buildAnonymousFunction(source, text, used, symbols) + end +end + +local function makeSymbol(uri) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local text = files.getText(uri) + local symbols = {} + local used = {} + guide.eachSource(ast.ast, function (source) + buildSource(source, text, used, symbols) + end) + + return symbols +end + +local function packChild(ranges, symbols) + await.delay() + table.sort(symbols, function (a, b) + return a.selectionRange[1] < b.selectionRange[1] + end) + await.delay() + local root = { + valueRange = { 0, math.maxinteger }, + children = {}, + } + local stacks = { root } + for _, symbol in ipairs(symbols) do + local parent = stacks[#stacks] + -- 移除已经超出生效范围的区间 + while symbol.selectionRange[1] > parent.valueRange[2] do + stacks[#stacks] = nil + parent = stacks[#stacks] + end + -- 向后看,找出当前可能生效的区间 + local nextRange + while #ranges > 0 + and symbol.selectionRange[1] >= ranges[#ranges].valueRange[1] do + if symbol.selectionRange[1] <= ranges[#ranges].valueRange[2] then + nextRange = ranges[#ranges] + end + ranges[#ranges] = nil + end + if nextRange then + stacks[#stacks+1] = nextRange + parent = nextRange + end + if parent == symbol then + -- function f() end 的情况,selectionRange 在 valueRange 内部, + -- 当前区间置为上一层 + parent = stacks[#stacks-1] + end + -- 把自己放到当前区间中 + if not parent.children then + parent.children = {} + end + parent.children[#parent.children+1] = symbol + end + return root.children +end + +local function packSymbols(symbols) + local ranges = {} + for _, symbol in ipairs(symbols) do + if symbol.valueRange then + ranges[#ranges+1] = symbol + end + end + await.delay() + table.sort(ranges, function (a, b) + return a.valueRange[1] > b.valueRange[1] + end) + -- 处理嵌套 + return packChild(ranges, symbols) +end + +return function (uri) + local symbols = makeSymbol(uri) + if not symbols then + return nil + end + + local packedSymbols = packSymbols(symbols) + + return packedSymbols +end diff --git a/script/core/find-source.lua b/script/core/find-source.lua new file mode 100644 index 00000000..32de102c --- /dev/null +++ b/script/core/find-source.lua @@ -0,0 +1,14 @@ +local guide = require 'parser.guide' + +return function (ast, offset, accept) + local len = math.huge + local result + guide.eachSourceContain(ast.ast, offset, function (source) + local start, finish = guide.getStartFinish(source) + if finish - start < len and accept[source.type] then + result = source + len = finish - start + end + end) + return result +end diff --git a/script/core/highlight.lua b/script/core/highlight.lua new file mode 100644 index 00000000..d7671df2 --- /dev/null +++ b/script/core/highlight.lua @@ -0,0 +1,252 @@ +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm' +local define = require 'proto.define' +local findSource = require 'core.find-source' + +local function eachRef(source, callback) + local results = guide.requestReference(source) + for i = 1, #results do + callback(results[i]) + end +end + +local function eachField(source, callback) + local isGlobal = guide.isGlobal(source) + local results = guide.requestReference(source) + for i = 1, #results do + local res = results[i] + if isGlobal == guide.isGlobal(res) then + callback(res) + end + end +end + +local function eachLocal(source, callback) + callback(source) + if source.ref then + for _, ref in ipairs(source.ref) do + callback(ref) + end + end +end + +local function find(source, uri, callback) + if source.type == 'local' then + eachLocal(source, callback) + elseif source.type == 'getlocal' + or source.type == 'setlocal' then + eachLocal(source.node, callback) + elseif source.type == 'field' + or source.type == 'method' then + eachField(source.parent, callback) + elseif source.type == 'getindex' + or source.type == 'setindex' + or source.type == 'tableindex' then + eachField(source, callback) + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + eachField(source, callback) + elseif source.type == 'goto' + or source.type == 'label' then + eachRef(source, callback) + elseif source.type == 'string' + and source.parent.index == source then + eachField(source.parent, callback) + elseif source.type == 'string' + or source.type == 'boolean' + or source.type == 'number' + or source.type == 'nil' then + callback(source) + end +end + +local function checkInIf(source, text, offset) + -- 检查 end + local endA = source.finish - #'end' + 1 + local endB = source.finish + if offset >= endA + and offset <= endB + and text:sub(endA, endB) == 'end' then + return true + end + -- 检查每个子模块 + for _, block in ipairs(source) do + for i = 1, #block.keyword, 2 do + local start = block.keyword[i] + local finish = block.keyword[i+1] + if offset >= start and offset <= finish then + return true + end + end + end + return false +end + +local function makeIf(source, text, callback) + -- end + local endA = source.finish - #'end' + 1 + local endB = source.finish + if text:sub(endA, endB) == 'end' then + callback(endA, endB) + end + -- 每个子模块 + for _, block in ipairs(source) do + for i = 1, #block.keyword, 2 do + local start = block.keyword[i] + local finish = block.keyword[i+1] + callback(start, finish) + end + end + return false +end + +local function findKeyWord(ast, text, offset, callback) + guide.eachSourceContain(ast.ast, offset, function (source) + if source.type == 'do' + or source.type == 'function' + or source.type == 'loop' + or source.type == 'in' + or source.type == 'while' + or source.type == 'repeat' then + local ok + for i = 1, #source.keyword, 2 do + local start = source.keyword[i] + local finish = source.keyword[i+1] + if offset >= start and offset <= finish then + ok = true + break + end + end + if ok then + for i = 1, #source.keyword, 2 do + local start = source.keyword[i] + local finish = source.keyword[i+1] + callback(start, finish) + end + end + elseif source.type == 'if' then + local ok = checkInIf(source, text, offset) + if ok then + makeIf(source, text, callback) + end + end + end) +end + +local accept = { + ['label'] = true, + ['goto'] = true, + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['field'] = true, + ['method'] = true, + ['tablefield'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['string'] = true, + ['boolean'] = true, + ['number'] = true, + ['nil'] = true, +} + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local text = files.getText(uri) + local results = {} + local mark = {} + + local source = findSource(ast, offset, accept) + if source then + find(source, uri, function (target) + local kind + if target.type == 'getfield' then + target = target.field + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setfield' + or target.type == 'tablefield' then + target = target.field + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getmethod' then + target = target.method + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setmethod' then + target = target.method + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getindex' then + target = target.index + kind = define.DocumentHighlightKind.Read + elseif target.type == 'field' then + if target.parent.type == 'getfield' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'method' then + if target.parent.type == 'getmethod' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'index' then + if target.parent.type == 'getindex' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'index' then + if target.parent.type == 'getindex' then + kind = define.DocumentHighlightKind.Read + else + kind = define.DocumentHighlightKind.Write + end + elseif target.type == 'setindex' + or target.type == 'tableindex' then + target = target.index + kind = define.DocumentHighlightKind.Write + elseif target.type == 'getlocal' + or target.type == 'getglobal' + or target.type == 'goto' then + kind = define.DocumentHighlightKind.Read + elseif target.type == 'setlocal' + or target.type == 'local' + or target.type == 'setglobal' + or target.type == 'label' then + kind = define.DocumentHighlightKind.Write + elseif target.type == 'string' + or target.type == 'boolean' + or target.type == 'number' + or target.type == 'nil' then + kind = define.DocumentHighlightKind.Text + else + return + end + if mark[target] then + return + end + mark[target] = true + results[#results+1] = { + start = target.start, + finish = target.finish, + kind = kind, + } + end) + end + + findKeyWord(ast, text, offset, function (start, finish) + results[#results+1] = { + start = start, + finish = finish, + kind = define.DocumentHighlightKind.Write + } + end) + + if #results == 0 then + return nil + end + return results +end diff --git a/script/core/hover/arg.lua b/script/core/hover/arg.lua new file mode 100644 index 00000000..9cd19f02 --- /dev/null +++ b/script/core/hover/arg.lua @@ -0,0 +1,71 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local function optionalArg(arg) + if not arg.bindDocs then + return false + end + local name = arg[1] + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.param' and doc.param[1] == name then + return doc.optional + end + end +end + +local function asFunction(source, oop) + if not source.args then + return '' + end + local args = {} + for i = 1, #source.args do + local arg = source.args[i] + local name = arg.name or guide.getName(arg) + if name then + args[i] = ('%s%s: %s'):format( + name, + optionalArg(arg) and '?' or '', + vm.getInferType(arg) + ) + else + args[i] = ('%s'):format(vm.getInferType(arg)) + end + end + local methodDef + local parent = source.parent + if parent and parent.type == 'setmethod' then + methodDef = true + end + if not methodDef and oop then + return table.concat(args, ', ', 2) + else + return table.concat(args, ', ') + end +end + +local function asDocFunction(source) + if not source.args then + return '' + end + local args = {} + for i = 1, #source.args do + local arg = source.args[i] + local name = arg.name[1] + args[i] = ('%s%s: %s'):format( + name, + arg.optional and '?' or '', + vm.getInferType(arg.extends) + ) + end + return table.concat(args, ', ') +end + +return function (source, oop) + if source.type == 'function' then + return asFunction(source, oop) + end + if source.type == 'doc.type.function' then + return asDocFunction(source) + end + return '' +end diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua new file mode 100644 index 00000000..7d89ee6c --- /dev/null +++ b/script/core/hover/description.lua @@ -0,0 +1,204 @@ +local vm = require 'vm' +local ws = require 'workspace' +local furi = require 'file-uri' +local files = require 'files' +local guide = require 'parser.guide' +local markdown = require 'provider.markdown' +local config = require 'config' +local lang = require 'language' + +local function asStringInRequire(source, literal) + local rootPath = ws.path or '' + local parent = source.parent + if parent and parent.type == 'callargs' then + local result, searchers + local call = parent.parent + local func = call.node + local libName = vm.getLibraryName(func) + if not libName then + return + end + if libName == 'require' then + result, searchers = ws.findUrisByRequirePath(literal) + elseif libName == 'dofile' + or libName == 'loadfile' then + result = ws.findUrisByFilePath(literal) + end + if result and #result > 0 then + for i, uri in ipairs(result) do + local searcher = searchers and furi.decode(searchers[uri]) + uri = files.getOriginUri(uri) + local path = furi.decode(uri) + if files.eq(path:sub(1, #rootPath), rootPath) then + path = path:sub(#rootPath + 1) + end + path = path:gsub('^[/\\]*', '') + if vm.isMetaFile(uri) then + result[i] = ('* [[meta]](%s)'):format(uri) + elseif searcher then + searcher = searcher:sub(#rootPath + 1) + searcher = ws.normalize(searcher) + result[i] = ('* [%s](%s) %s'):format(path, uri, lang.script('HOVER_USE_LUA_PATH', searcher)) + else + result[i] = ('* [%s](%s)'):format(path, uri) + end + end + table.sort(result) + local md = markdown() + md:add('md', table.concat(result, '\n')) + return md:string() + end + end +end + +local function asStringView(source, literal) + -- 内部包含转义符? + local rawLen = source.finish - source.start - 2 * #source[2] + 1 + if config.config.hover.viewString + and (source[2] == '"' or source[2] == "'") + and rawLen > #literal then + local view = literal + local max = config.config.hover.viewStringMax + if #view > max then + view = view:sub(1, max) .. '...' + end + local md = markdown() + md:add('txt', view) + return md:string() + end +end + +local function asString(source) + local literal = guide.getLiteral(source) + if type(literal) ~= 'string' then + return nil + end + return asStringInRequire(source, literal) + or asStringView(source, literal) +end + +local function getBindComment(docGroup, base) + local lines = {} + for _, doc in ipairs(docGroup) do + if doc.type == 'doc.comment' then + lines[#lines+1] = doc.comment.text:sub(2) + elseif #lines > 0 and not base then + break + elseif doc == base then + break + else + lines = {} + end + end + if #lines == 0 then + return nil + end + return table.concat(lines, '\n') +end + +local function buildEnumChunk(docType, name) + local enums = vm.getDocEnums(docType) + if #enums == 0 then + return + end + local types = {} + for _, tp in ipairs(docType.types) do + types[#types+1] = tp[1] + end + local lines = {} + lines[#lines+1] = ('%s: %s'):format(name, table.concat(types)) + for _, enum in ipairs(enums) do + lines[#lines+1] = (' %s %s%s'):format( + (enum.default and '->') + or (enum.additional and '+>') + or ' |', + enum[1], + enum.comment and (' -- %s'):format(enum.comment) or '' + ) + end + return table.concat(lines, '\n') +end + +local function getBindEnums(docGroup) + local mark = {} + local chunks = {} + local returnIndex = 0 + for _, doc in ipairs(docGroup) do + if doc.type == 'doc.param' then + local name = doc.param[1] + if mark[name] then + goto CONTINUE + end + mark[name] = true + chunks[#chunks+1] = buildEnumChunk(doc.extends, name) + elseif doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returnIndex = returnIndex + 1 + local name = rtn.name and rtn.name[1] or ('(return %d)'):format(returnIndex) + if mark[name] then + goto CONTINUE + end + mark[name] = true + chunks[#chunks+1] = buildEnumChunk(rtn, name) + end + end + ::CONTINUE:: + end + if #chunks == 0 then + return nil + end + return table.concat(chunks, '\n\n') +end + +local function tryDocFieldUpComment(source) + if source.type ~= 'doc.field' then + return + end + if not source.bindGroup then + return + end + local comment = getBindComment(source.bindGroup, source) + return comment +end + +local function tryDocComment(source) + if not source.bindDocs then + return + end + local comment = getBindComment(source.bindDocs) + local enums = getBindEnums(source.bindDocs) + local md = markdown() + if comment then + md:add('md', comment) + end + if enums then + md:add('lua', enums) + end + return md:string() +end + +local function tryDocOverloadToComment(source) + if source.type ~= 'doc.type.function' then + return + end + local doc = source.parent + if doc.type ~= 'doc.overload' + or not doc.bindSources then + return + end + for _, src in ipairs(doc.bindSources) do + local md = tryDocComment(src) + if md then + return md + end + end +end + +return function (source) + if source.type == 'string' then + return asString(source) + end + return tryDocOverloadToComment(source) + or tryDocFieldUpComment(source) + or tryDocComment(source) +end diff --git a/script/core/hover/init.lua b/script/core/hover/init.lua new file mode 100644 index 00000000..96e01ab5 --- /dev/null +++ b/script/core/hover/init.lua @@ -0,0 +1,164 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local getLabel = require 'core.hover.label' +local getDesc = require 'core.hover.description' +local util = require 'utility' +local findSource = require 'core.find-source' +local lang = require 'language' + +local function eachFunctionAndOverload(value, callback) + callback(value) + if not value.bindDocs then + return + end + for _, doc in ipairs(value.bindDocs) do + if doc.type == 'doc.overload' then + callback(doc.overload) + end + end +end + +local function getHoverAsFunction(source) + local values = vm.getDefs(source, 'deep') + local desc = getDesc(source) + local labels = {} + local defs = 0 + local protos = 0 + local other = 0 + local oop = source.type == 'method' + or source.type == 'getmethod' + or source.type == 'setmethod' + local mark = {} + for _, def in ipairs(values) do + def = guide.getObjectValue(def) or def + if def.type == 'function' + or def.type == 'doc.type.function' then + eachFunctionAndOverload(def, function (value) + if mark[value] then + return + end + mark[value] =true + local label = getLabel(value, oop) + if label then + defs = defs + 1 + labels[label] = (labels[label] or 0) + 1 + if labels[label] == 1 then + protos = protos + 1 + end + end + desc = desc or getDesc(value) + end) + elseif def.type == 'table' + or def.type == 'boolean' + or def.type == 'string' + or def.type == 'number' then + other = other + 1 + desc = desc or getDesc(def) + end + end + + if defs == 1 and other == 0 then + return { + label = next(labels), + source = source, + description = desc, + } + end + + local lines = {} + if defs > 1 then + lines[#lines+1] = lang.script('HOVER_MULTI_DEF_PROTO', defs, protos) + end + if other > 0 then + lines[#lines+1] = lang.script('HOVER_MULTI_PROTO_NOT_FUNC', other) + end + if defs > 1 then + for label, count in util.sortPairs(labels) do + lines[#lines+1] = ('(%d) %s'):format(count, label) + end + else + lines[#lines+1] = next(labels) + end + local label = table.concat(lines, '\n') + return { + label = label, + source = source, + description = desc, + } +end + +local function getHoverAsValue(source) + local oop = source.type == 'method' + or source.type == 'getmethod' + or source.type == 'setmethod' + local label = getLabel(source, oop) + local desc = getDesc(source) + if not desc then + local values = vm.getDefs(source, 'deep') + for _, def in ipairs(values) do + desc = getDesc(def) + if desc then + break + end + end + end + return { + label = label, + source = source, + description = desc, + } +end + +local function getHoverAsDocName(source) + local label = getLabel(source) + local desc = getDesc(source) + return { + label = label, + source = source, + description = desc, + } +end + +local function getHover(source) + if source.type == 'doc.type.name' then + return getHoverAsDocName(source) + end + local isFunction = vm.hasInferType(source, 'function', 'deep') + if isFunction then + return getHoverAsFunction(source) + else + return getHoverAsValue(source) + end +end + +local accept = { + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['field'] = true, + ['method'] = true, + ['string'] = true, + ['number'] = true, + ['doc.type.name'] = true, +} + +local function getHoverByUri(uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + local source = findSource(ast, offset, accept) + if not source then + return nil + end + local hover = getHover(source) + return hover +end + +return { + get = getHover, + byUri = getHoverByUri, +} diff --git a/script/core/hover/label.lua b/script/core/hover/label.lua new file mode 100644 index 00000000..d785bc27 --- /dev/null +++ b/script/core/hover/label.lua @@ -0,0 +1,211 @@ +local buildName = require 'core.hover.name' +local buildArg = require 'core.hover.arg' +local buildReturn = require 'core.hover.return' +local buildTable = require 'core.hover.table' +local vm = require 'vm' +local util = require 'utility' +local guide = require 'parser.guide' +local lang = require 'language' +local config = require 'config' +local files = require 'files' + +local function asFunction(source, oop) + local name = buildName(source, oop) + local arg = buildArg(source, oop) + local rtn = buildReturn(source) + local lines = {} + lines[1] = ('function %s(%s)'):format(name, arg) + lines[2] = rtn + return table.concat(lines, '\n') +end + +local function asDocFunction(source) + local name = buildName(source) + local arg = buildArg(source) + local rtn = buildReturn(source) + local lines = {} + lines[1] = ('function %s(%s)'):format(name, arg) + lines[2] = rtn + return table.concat(lines, '\n') +end + +local function asDocTypeName(source) + for _, doc in ipairs(vm.getDocTypes(source[1])) do + if doc.type == 'doc.class.name' then + return 'class ' .. source[1] + end + if doc.type == 'doc.alias.name' then + local extends = doc.parent.extends + return lang.script('HOVER_EXTENDS', vm.getInferType(extends)) + end + end +end + +local function asValue(source, title) + local name = buildName(source) + local infers = vm.getInfers(source, 'deep') + local type = vm.getInferType(source, 'deep') + local class = vm.getClass(source, 'deep') + local literal = vm.getInferLiteral(source, 'deep') + local cont + if type ~= 'string' and not type:find('%[%]$') then + if #vm.getFields(source, 'deep') > 0 + or vm.hasInferType(source, 'table', 'deep') then + cont = buildTable(source) + end + end + local pack = {} + pack[#pack+1] = title + pack[#pack+1] = name .. ':' + if cont and type == 'table' then + type = nil + end + if class then + pack[#pack+1] = class + else + pack[#pack+1] = type + end + if literal then + pack[#pack+1] = '=' + pack[#pack+1] = literal + end + if cont then + pack[#pack+1] = cont + end + return table.concat(pack, ' ') +end + +local function asLocal(source) + return asValue(source, 'local') +end + +local function asGlobal(source) + return asValue(source, 'global') +end + +local function isGlobalField(source) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + if source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' then + local node = source.node + if node.type == 'setglobal' + or node.type == 'getglobal' then + return true + end + return isGlobalField(node) + elseif source.type == 'tablefield' then + local parent = source.parent + if parent.type == 'setglobal' + or parent.type == 'getglobal' then + return true + end + return isGlobalField(parent) + else + return false + end +end + +local function asField(source) + if isGlobalField(source) then + return asGlobal(source) + end + return asValue(source, 'field') +end + +local function asDocField(source) + local name = source.field[1] + local class + for _, doc in ipairs(source.bindGroup) do + if doc.type == 'doc.class' then + class = doc + break + end + end + if not class then + return ('field ?.%s: %s'):format( + name, + vm.getInferType(source.extends) + ) + end + return ('field %s.%s: %s'):format( + class.class[1], + name, + vm.getInferType(source.extends) + ) +end + +local function asString(source) + local str = source[1] + if type(str) ~= 'string' then + return '' + end + local len = #str + local charLen = util.utf8Len(str, 1, -1) + if len == charLen then + return lang.script('HOVER_STRING_BYTES', len) + else + return lang.script('HOVER_STRING_CHARACTERS', len, charLen) + end +end + +local function formatNumber(n) + local str = ('%.10f'):format(n) + str = str:gsub('%.?0*$', '') + return str +end + +local function asNumber(source) + if not config.config.hover.viewNumber then + return nil + end + local num = source[1] + if type(num) ~= 'number' then + return nil + end + local uri = guide.getUri(source) + local text = files.getText(uri) + if not text then + return nil + end + local raw = text:sub(source.start, source.finish) + if not raw or not raw:find '[^%-%d%.]' then + return nil + end + return formatNumber(num) +end + +return function (source, oop) + if source.type == 'function' then + return asFunction(source, oop) + elseif source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' then + return asLocal(source) + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + return asGlobal(source) + elseif source.type == 'getfield' + or source.type == 'setfield' + or source.type == 'getmethod' + or source.type == 'setmethod' + or source.type == 'tablefield' + or source.type == 'field' + or source.type == 'method' then + return asField(source) + elseif source.type == 'string' then + return asString(source) + elseif source.type == 'number' then + return asNumber(source) + elseif source.type == 'doc.type.function' then + return asDocFunction(source) + elseif source.type == 'doc.type.name' then + return asDocTypeName(source) + elseif source.type == 'doc.field' then + return asDocField(source) + end +end diff --git a/script/core/hover/name.lua b/script/core/hover/name.lua new file mode 100644 index 00000000..9ad32e09 --- /dev/null +++ b/script/core/hover/name.lua @@ -0,0 +1,101 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local buildName + +local function asLocal(source) + local name = guide.getName(source) + if not source.attrs then + return name + end + local label = {} + label[#label+1] = name + for _, attr in ipairs(source.attrs) do + label[#label+1] = ('<%s>'):format(attr[1]) + end + return table.concat(label, ' ') +end + +local function asField(source, oop) + local class + if source.node.type ~= 'getglobal' then + class = vm.getClass(source.node, 'deep') + end + local node = class or guide.getName(source.node) or '?' + local method = guide.getName(source) + if oop then + return ('%s:%s'):format(node, method) + else + return ('%s.%s'):format(node, method) + end +end + +local function asTableField(source) + if not source.field then + return + end + return guide.getName(source.field) +end + +local function asGlobal(source) + return guide.getName(source) +end + +local function asDocFunction(source) + local doc = guide.getParentType(source, 'doc.type') + or guide.getParentType(source, 'doc.overload') + if not doc or not doc.bindSources then + return '' + end + for _, src in ipairs(doc.bindSources) do + local name = buildName(src) + if name ~= '' then + return name + end + end + return '' +end + +local function asDocField(source) + return source.field[1] +end + +function buildName(source, oop) + if oop == nil then + oop = source.type == 'setmethod' + or source.type == 'getmethod' + end + if source.type == 'local' + or source.type == 'getlocal' + or source.type == 'setlocal' then + return asLocal(source) or '' + end + if source.type == 'setglobal' + or source.type == 'getglobal' then + return asGlobal(source) or '' + end + if source.type == 'setmethod' + or source.type == 'getmethod' then + return asField(source, true) or '' + end + if source.type == 'setfield' + or source.type == 'getfield' then + return asField(source, oop) or '' + end + if source.type == 'tablefield' then + return asTableField(source) or '' + end + if source.type == 'doc.type.function' then + return asDocFunction(source) + end + if source.type == 'doc.field' then + return asDocField(source) + end + local parent = source.parent + if parent then + return buildName(parent, oop) + end + return '' +end + +return buildName diff --git a/script/core/hover/return.lua b/script/core/hover/return.lua new file mode 100644 index 00000000..3829dbed --- /dev/null +++ b/script/core/hover/return.lua @@ -0,0 +1,125 @@ +local guide = require 'parser.guide' +local vm = require 'vm' + +local function mergeTypes(returns) + if type(returns) == 'string' then + return returns + end + return guide.mergeTypes(returns) +end + +local function getReturnDualByDoc(source) + local docs = source.bindDocs + if not docs then + return + end + local dual + for _, doc in ipairs(docs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + if not dual then + dual = {} + end + dual[#dual+1] = { rtn } + end + end + end + return dual +end + +local function getReturnDualByGrammar(source) + if not source.returns then + return nil + end + local dual + for _, rtn in ipairs(source.returns) do + if not dual then + dual = {} + end + for n = 1, #rtn do + if not dual[n] then + dual[n] = {} + end + dual[n][#dual[n]+1] = rtn[n] + end + end + return dual +end + +local function asFunction(source) + local dual = getReturnDualByDoc(source) + or getReturnDualByGrammar(source) + if not dual then + return + end + local returns = {} + for i, rtn in ipairs(dual) do + local line = {} + local types = {} + if i == 1 then + line[#line+1] = ' -> ' + else + line[#line+1] = ('% 3d. '):format(i) + end + for n = 1, #rtn do + local values = vm.getInfers(rtn[n]) + for _, value in ipairs(values) do + if value.type then + for tp in value.type:gmatch '[^|]+' do + types[#types+1] = tp + end + end + end + end + if #types > 0 or rtn[1] then + local tp = mergeTypes(types) or 'any' + if rtn[1].name then + line[#line+1] = ('%s%s: %s'):format( + rtn[1].name[1], + rtn[1].optional and '?' or '', + tp + ) + else + line[#line+1] = ('%s%s'):format( + tp, + rtn[1].optional and '?' or '' + ) + end + else + break + end + returns[i] = table.concat(line) + end + if #returns == 0 then + return nil + end + return table.concat(returns, '\n') +end + +local function asDocFunction(source) + if not source.returns or #source.returns == 0 then + return nil + end + local returns = {} + for i, rtn in ipairs(source.returns) do + local rtnText = ('%s%s'):format( + vm.getInferType(rtn), + rtn.optional and '?' or '' + ) + if i == 1 then + returns[#returns+1] = (' -> %s'):format(rtnText) + else + returns[#returns+1] = ('% 3d. %s'):format(i, rtnText) + end + end + return table.concat(returns, '\n') +end + +return function (source) + if source.type == 'function' then + return asFunction(source) + end + if source.type == 'doc.type.function' then + return asDocFunction(source) + end +end diff --git a/script/core/hover/table.lua b/script/core/hover/table.lua new file mode 100644 index 00000000..02be5271 --- /dev/null +++ b/script/core/hover/table.lua @@ -0,0 +1,257 @@ +local vm = require 'vm' +local util = require 'utility' +local guide = require 'parser.guide' +local config = require 'config' +local lang = require 'language' + +local function getKey(src) + local key = vm.getKeyName(src) + if not key or #key <= 2 then + if not src.index then + return '[any]' + end + local class = vm.getClass(src.index) + if class then + return ('[%s]'):format(class) + end + local tp = vm.getInferType(src.index) + if tp then + return ('[%s]'):format(tp) + end + return '[any]' + end + local ktype = key:sub(1, 2) + key = key:sub(3) + if ktype == 's|' then + if key:match '^[%a_][%w_]*$' then + return key + else + return ('[%s]'):format(util.viewLiteral(key)) + end + end + return ('[%s]'):format(key) +end + +local function getFieldFast(src) + local value = guide.getObjectValue(src) or src + if not value then + return 'any' + end + if value.type == 'boolean' then + return value.type, util.viewLiteral(value[1]) + end + if value.type == 'number' + or value.type == 'integer' then + if math.tointeger(value[1]) then + if config.config.runtime.version == 'Lua 5.3' + or config.config.runtime.version == 'Lua 5.4' then + return 'integer', util.viewLiteral(value[1]) + end + end + return value.type, util.viewLiteral(value[1]) + end + if value.type == 'table' + or value.type == 'function' then + return value.type + end + if value.type == 'string' then + local literal = value[1] + if type(literal) == 'string' and #literal >= 50 then + literal = literal:sub(1, 47) .. '...' + end + return value.type, util.viewLiteral(literal) + end +end + +local function getFieldFull(src) + local tp = vm.getInferType(src) + --local class = vm.getClass(src) + local literal = vm.getInferLiteral(src) + if type(literal) == 'string' and #literal >= 50 then + literal = literal:sub(1, 47) .. '...' + end + return tp, literal +end + +local function getField(src, timeUp, mark, key) + if src.type == 'table' + or src.type == 'function' then + return nil + end + if src.parent then + if src.type == 'string' + or src.type == 'boolean' + or src.type == 'number' + or src.type == 'integer' then + if src.parent.type == 'tableindex' + or src.parent.type == 'setindex' + or src.parent.type == 'getindex' then + if src.parent.index == src then + src = src.parent + end + end + end + end + local tp, literal + tp, literal = getFieldFast(src) + if tp then + return tp, literal + end + if timeUp or mark[key] then + return nil + end + mark[key] = true + tp, literal = getFieldFull(src) + if tp then + return tp, literal + end + return nil +end + +local function buildAsHash(classes, literals) + local keys = {} + for k in pairs(classes) do + keys[#keys+1] = k + end + table.sort(keys) + local lines = {} + lines[#lines+1] = '{' + for _, key in ipairs(keys) do + local class = classes[key] + local literal = literals[key] + if literal then + lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + else + lines[#lines+1] = (' %s: %s,'):format(key, class) + end + end + lines[#lines+1] = '}' + return table.concat(lines, '\n') +end + +local function buildAsConst(classes, literals) + local keys = {} + for k in pairs(classes) do + keys[#keys+1] = k + end + table.sort(keys, function (a, b) + return tonumber(literals[a]) < tonumber(literals[b]) + end) + local lines = {} + lines[#lines+1] = '{' + for _, key in ipairs(keys) do + local class = classes[key] + local literal = literals[key] + if literal then + lines[#lines+1] = (' %s: %s = %s,'):format(key, class, literal) + else + lines[#lines+1] = (' %s: %s,'):format(key, class) + end + end + lines[#lines+1] = '}' + return table.concat(lines, '\n') +end + +local function mergeLiteral(literals) + local results = {} + local mark = {} + for _, value in ipairs(literals) do + if not mark[value] then + mark[value] = true + results[#results+1] = value + end + end + if #results == 0 then + return nil + end + table.sort(results) + return table.concat(results, '|') +end + +local function mergeTypes(types) + local results = {} + local mark = { + -- 讲道理table的keyvalue不会是nil + ['nil'] = true, + } + for _, tv in ipairs(types) do + for tp in tv:gmatch '[^|]+' do + if not mark[tp] then + mark[tp] = true + results[#results+1] = tp + end + end + end + return guide.mergeTypes(results) +end + +local function clearClasses(classes) + classes['[nil]'] = nil + classes['[any]'] = nil + classes['[string]'] = nil +end + +return function (source) + local literals = {} + local classes = {} + local clock = os.clock() + local timeUp + local mark = {} + local fields = vm.getFields(source, 'deep') + local keyCount = 0 + for _, src in ipairs(fields) do + local key = getKey(src) + if not key then + goto CONTINUE + end + if not classes[key] then + classes[key] = {} + keyCount = keyCount + 1 + end + if not literals[key] then + literals[key] = {} + end + if not TEST and os.clock() - clock > config.config.hover.fieldInfer / 1000.0 then + timeUp = true + end + local class, literal = getField(src, timeUp, mark, key) + if literal == 'nil' then + literal = nil + end + classes[key][#classes[key]+1] = class + literals[key][#literals[key]+1] = literal + if keyCount >= 1000 then + break + end + ::CONTINUE:: + end + + clearClasses(classes) + + for key, class in pairs(classes) do + literals[key] = mergeLiteral(literals[key]) + classes[key] = mergeTypes(class) + end + + if not next(classes) then + return '{}' + end + + local intValue = true + for key, class in pairs(classes) do + if class ~= 'integer' or not tonumber(literals[key]) then + intValue = false + break + end + end + local result + if intValue then + result = buildAsConst(classes, literals) + else + result = buildAsHash(classes, literals) + end + if timeUp then + result = ('\n--%s\n%s'):format(lang.script.HOVER_TABLE_TIME_UP, result) + end + return result +end diff --git a/script/core/keyword.lua b/script/core/keyword.lua new file mode 100644 index 00000000..1cbeb78d --- /dev/null +++ b/script/core/keyword.lua @@ -0,0 +1,264 @@ +local define = require 'proto.define' +local guide = require 'parser.guide' + +local keyWordMap = { + {'do', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'do .. end', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[$0 end]], + } + else + results[#results+1] = { + label = 'do .. end', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +do + $0 +end]], + } + end + return true + end, function (ast, start) + return guide.eachSourceContain(ast.ast, start, function (source) + if source.type == 'while' + or source.type == 'in' + or source.type == 'loop' then + for i = 1, #source.keyword do + if start == source.keyword[i] then + return true + end + end + end + end) + end}, + {'and'}, + {'break'}, + {'else'}, + {'elseif', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'elseif .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[$1 then]], + } + else + results[#results+1] = { + label = 'elseif .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[elseif $1 then]], + } + end + return true + end}, + {'end'}, + {'false'}, + {'for', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'for .. in', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:key, value} in ${2:pairs(${3:t})} do + $0 +end]] + } + results[#results+1] = { + label = 'for i = ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:i} = ${2:1}, ${3:10, 1} do + $0 +end]] + } + else + results[#results+1] = { + label = 'for .. in', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +for ${1:key, value} in ${2:pairs(${3:t})} do + $0 +end]] + } + results[#results+1] = { + label = 'for i = ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +for ${1:i} = ${2:1}, ${3:10, 1} do + $0 +end]] + } + end + return true + end}, + {'function', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'function ()', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +$1($2) + $0 +end]] + } + else + results[#results+1] = { + label = 'function ()', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +function $1($2) + $0 +end]] + } + end + return true + end}, + {'goto'}, + {'if', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'if .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +$1 then + $0 +end]] + } + else + results[#results+1] = { + label = 'if .. then', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +if $1 then + $0 +end]] + } + end + return true + end}, + {'in', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'in ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:pairs(${2:t})} do + $0 +end]] + } + else + results[#results+1] = { + label = 'in ..', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +in ${1:pairs(${2:t})} do + $0 +end]] + } + end + return true + end}, + {'local', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'local function', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +function $1($2) + $0 +end]] + } + else + results[#results+1] = { + label = 'local function', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +local function $1($2) + $0 +end]] + } + end + return false + end}, + {'nil'}, + {'not'}, + {'or'}, + {'repeat', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'repeat .. until', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[$0 until $1]] + } + else + results[#results+1] = { + label = 'repeat .. until', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +repeat + $0 +until $1]] + } + end + return true + end}, + {'return', function (hasSpace, results) + if not hasSpace then + results[#results+1] = { + label = 'do return end', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[do return $1end]] + } + end + return false + end}, + {'then'}, + {'true'}, + {'until'}, + {'while', function (hasSpace, results) + if hasSpace then + results[#results+1] = { + label = 'while .. do', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +${1:true} do + $0 +end]] + } + else + results[#results+1] = { + label = 'while .. do', + kind = define.CompletionItemKind.Snippet, + insertTextFormat = 2, + insertText = [[ +while ${1:true} do + $0 +end]] + } + end + return true + end}, +} + +return keyWordMap diff --git a/script/core/matchkey.lua b/script/core/matchkey.lua new file mode 100644 index 00000000..45c86eff --- /dev/null +++ b/script/core/matchkey.lua @@ -0,0 +1,33 @@ +return function (me, other, fast) + if me == other then + return true + end + if me == '' then + return true + end + if #me > #other then + return false + end + local lMe = me:lower() + local lOther = other:lower() + if lMe == lOther:sub(1, #lMe) then + return true + end + if fast and me:sub(1, 1) ~= other:sub(1, 1) then + return false + end + local chars = {} + for i = 1, #lOther do + local c = lOther:sub(i, i) + chars[c] = (chars[c] or 0) + 1 + end + for i = 1, #lMe do + local c = lMe:sub(i, i) + if chars[c] and chars[c] > 0 then + chars[c] = chars[c] - 1 + else + return false + end + end + return true +end diff --git a/script/core/reference.lua b/script/core/reference.lua new file mode 100644 index 00000000..d7d3df03 --- /dev/null +++ b/script/core/reference.lua @@ -0,0 +1,116 @@ +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm' +local findSource = require 'core.find-source' + +local function isValidFunction(source, offset) + -- 必须点在 `function` 这个单词上才能查找函数引用 + return offset >= source.start and offset < source.start + #'function' +end + +local function sortResults(results) + -- 先按照顺序排序 + table.sort(results, function (a, b) + local u1 = guide.getUri(a.target) + local u2 = guide.getUri(b.target) + if u1 == u2 then + return a.target.start < b.target.start + else + return u1 < u2 + end + end) + -- 如果2个结果处于嵌套状态,则取范围小的那个 + local lf, lu + for i = #results, 1, -1 do + local res = results[i].target + local f = res.finish + local uri = guide.getUri(res) + if lf and f > lf and uri == lu then + table.remove(results, i) + else + lu = uri + lf = f + end + end +end + +local accept = { + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['label'] = true, + ['goto'] = true, + ['field'] = true, + ['method'] = true, + ['setindex'] = true, + ['getindex'] = true, + ['tableindex'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['function'] = true, + + ['doc.type.name'] = true, + ['doc.class.name'] = true, + ['doc.extends.name'] = true, + ['doc.alias.name'] = true, +} + +return function (uri, offset) + local ast = files.getAst(uri) + if not ast then + return nil + end + + local source = findSource(ast, offset, accept) + if not source then + return nil + end + if source.type == 'function' and not isValidFunction(source, offset) and not TEST then + return nil + end + + local results = {} + for _, src in ipairs(vm.getRefs(source, 'deep')) do + local root = guide.getRoot(src) + if not root then + goto CONTINUE + end + if vm.isMetaFile(root.uri) then + goto CONTINUE + end + if ( src.type == 'doc.class.name' + or src.type == 'doc.type.name' + ) + and source.type ~= 'doc.type.name' + and source.type ~= 'doc.class.name' then + goto CONTINUE + end + if src.type == 'setfield' + or src.type == 'getfield' + or src.type == 'tablefield' then + src = src.field + elseif src.type == 'setindex' + or src.type == 'getindex' + or src.type == 'tableindex' then + src = src.index + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + src = src.method + elseif src.type == 'table' and src.parent.type ~= 'return' then + goto CONTINUE + end + results[#results+1] = { + target = src, + uri = files.getOriginUri(root.uri), + } + ::CONTINUE:: + end + + if #results == 0 then + return nil + end + + sortResults(results) + + return results +end diff --git a/script/core/rename.lua b/script/core/rename.lua new file mode 100644 index 00000000..89298bdd --- /dev/null +++ b/script/core/rename.lua @@ -0,0 +1,448 @@ +local files = require 'files' +local vm = require 'vm' +local guide = require 'parser.guide' +local proto = require 'proto' +local define = require 'proto.define' +local util = require 'utility' +local findSource = require 'core.find-source' +local ws = require 'workspace' + +local Forcing + +local function askForcing(str) + -- TODO 总是可以替换 + do return true end + if TEST then + return true + end + if Forcing ~= nil then + return Forcing + end + local version = files.globalVersion + -- TODO + local item = proto.awaitRequest('window/showMessageRequest', { + type = define.MessageType.Warning, + message = ('[%s]不是有效的标识符,是否强制替换?'):format(str), + actions = { + { + title = '强制替换', + }, + { + title = '取消', + }, + } + }) + if version ~= files.globalVersion then + Forcing = false + proto.notify('window/showMessage', { + type = define.MessageType.Warning, + message = '文件发生了变化,替换取消。' + }) + return false + end + if not item then + Forcing = false + return false + end + if item.title == '强制替换' then + Forcing = true + return true + else + Forcing = false + return false + end +end + +local function askForMultiChange(results, newname) + -- TODO 总是可以替换 + do return true end + if TEST then + return true + end + local uris = {} + for _, result in ipairs(results) do + local uri = result.uri + if not uris[uri] then + uris[uri] = 0 + uris[#uris+1] = uri + end + uris[uri] = uris[uri] + 1 + end + if #uris <= 1 then + return true + end + + local version = files.globalVersion + -- TODO + local item = proto.awaitRequest('window/showMessageRequest', { + type = define.MessageType.Warning, + message = ('将修改 %d 个文件,共 %d 处。'):format( + #uris, + #results + ), + actions = { + { + title = '继续', + }, + { + title = '放弃', + }, + } + }) + if version ~= files.globalVersion then + proto.notify('window/showMessage', { + type = define.MessageType.Warning, + message = '文件发生了变化,替换取消。' + }) + return false + end + if item and item.title == '继续' then + local fileList = {} + for _, uri in ipairs(uris) do + fileList[#fileList+1] = ('%s (%d)'):format(uri, uris[uri]) + end + + log.debug(('Renamed [%s]\r\n%s'):format(newname, table.concat(fileList, '\r\n'))) + return true + end + return false +end + +local function trim(str) + return str:match '^%s*(%S+)%s*$' +end + +local function isValidName(str) + return str:match '^[%a_][%w_]*$' +end + +local function isValidGlobal(str) + for s in str:gmatch '[^%.]*' do + if not isValidName(trim(s)) then + return false + end + end + return true +end + +local function isValidFunctionName(str) + if isValidGlobal(str) then + return true + end + local pos = str:find(':', 1, true) + if not pos then + return false + end + return isValidGlobal(trim(str:sub(1, pos-1))) + and isValidName(trim(str:sub(pos+1))) +end + +local function isFunctionGlobalName(source) + local parent = source.parent + if parent.type ~= 'setglobal' then + return false + end + local value = parent.value + if not value.type ~= 'function' then + return false + end + return value.start <= parent.start +end + +local function renameLocal(source, newname, callback) + if isValidName(newname) then + callback(source, source.start, source.finish, newname) + return + end + if askForcing(newname) then + callback(source, source.start, source.finish, newname) + end +end + +local function renameField(source, newname, callback) + if isValidName(newname) then + callback(source, source.start, source.finish, newname) + return true + end + local parent = source.parent + if parent.type == 'setfield' + or parent.type == 'getfield' then + local dot = parent.dot + local newstr = '[' .. util.viewString(newname) .. ']' + callback(source, dot.start, source.finish, newstr) + elseif parent.type == 'tablefield' then + local newstr = '[' .. util.viewString(newname) .. ']' + callback(source, source.start, source.finish, newstr) + elseif parent.type == 'getmethod' then + if not askForcing(newname) then + return false + end + callback(source, source.start, source.finish, newname) + elseif parent.type == 'setmethod' then + local uri = guide.getUri(source) + local text = files.getText(uri) + local func = parent.value + -- function mt:name () end --> mt['newname'] = function (self) end + local newstr = string.format('%s[%s] = function ' + , text:sub(parent.start, parent.node.finish) + , util.viewString(newname) + ) + callback(source, func.start, parent.finish, newstr) + local pl = text:find('(', parent.finish, true) + if pl then + if func.args then + callback(source, pl + 1, pl, 'self, ') + else + callback(source, pl + 1, pl, 'self') + end + end + end + return true +end + +local function renameGlobal(source, newname, callback) + if isValidGlobal(newname) then + callback(source, source.start, source.finish, newname) + return true + end + if isValidFunctionName(newname) then + if not isFunctionGlobalName(source) then + askForcing(newname) + end + callback(source, source.start, source.finish, newname) + return true + end + local newstr = '_ENV[' .. util.viewString(newname) .. ']' + -- function name () end --> _ENV['newname'] = function () end + if source.value and source.value.type == 'function' + and source.value.start < source.start then + callback(source, source.value.start, source.finish, newstr .. ' = function ') + return true + end + callback(source, source.start, source.finish, newstr) + return true +end + +local function ofLocal(source, newname, callback) + renameLocal(source, newname, callback) + if source.ref then + for _, ref in ipairs(source.ref) do + renameLocal(ref, newname, callback) + end + end +end + +local function ofFieldThen(key, src, newname, callback) + if vm.getKeyName(src) ~= key then + return + end + if src.type == 'tablefield' + or src.type == 'getfield' + or src.type == 'setfield' then + src = src.field + elseif src.type == 'tableindex' + or src.type == 'getindex' + or src.type == 'setindex' then + src = src.index + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + src = src.method + end + if src.type == 'string' then + local quo = src[2] + local text = util.viewString(newname, quo) + callback(src, src.start, src.finish, text) + return + elseif src.type == 'field' + or src.type == 'method' then + local suc = renameField(src, newname, callback) + if not suc then + return + end + elseif src.type == 'setglobal' + or src.type == 'getglobal' then + local suc = renameGlobal(src, newname, callback) + if not suc then + return + end + end +end + +local function ofField(source, newname, callback) + local key = guide.getKeyName(source) + local node + if source.type == 'tablefield' + or source.type == 'tableindex' then + node = source.parent + else + node = source.node + end + for _, src in ipairs(vm.getFields(node, 'deep')) do + ofFieldThen(key, src, newname, callback) + end +end + +local function ofGlobal(source, newname, callback) + local key = guide.getKeyName(source) + for _, src in ipairs(vm.getRefs(source, 'deep')) do + ofFieldThen(key, src, newname, callback) + end +end + +local function ofLabel(source, newname, callback) + if not isValidName(newname) and not askForcing(newname)then + return false + end + for _, src in ipairs(vm.getRefs(source, 'deep')) do + callback(src, src.start, src.finish, newname) + end +end + +local function rename(source, newname, callback) + if source.type == 'label' + or source.type == 'goto' then + return ofLabel(source, newname, callback) + elseif source.type == 'local' then + return ofLocal(source, newname, callback) + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + return ofLocal(source.node, newname, callback) + elseif source.type == 'field' + or source.type == 'method' + or source.type == 'index' then + return ofField(source.parent, newname, callback) + elseif source.type == 'setglobal' + or source.type == 'getglobal' then + return ofGlobal(source, newname, callback) + elseif source.type == 'string' + or source.type == 'number' + or source.type == 'boolean' then + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + return ofField(parent, newname, callback) + end + end + return +end + +local function prepareRename(source) + if source.type == 'label' + or source.type == 'goto' + or source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'field' + or source.type == 'method' + or source.type == 'tablefield' + or source.type == 'setglobal' + or source.type == 'getglobal' then + return source, source[1] + elseif source.type == 'string' + or source.type == 'number' + or source.type == 'boolean' then + local parent = source.parent + if not parent then + return nil + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + return source, source[1] + end + return nil + end + return nil +end + +local accept = { + ['label'] = true, + ['goto'] = true, + ['local'] = true, + ['setlocal'] = true, + ['getlocal'] = true, + ['field'] = true, + ['method'] = true, + ['tablefield'] = true, + ['setglobal'] = true, + ['getglobal'] = true, + ['string'] = true, + ['boolean'] = true, + ['number'] = true, +} + +local m = {} + +function m.rename(uri, pos, newname) + local ast = files.getAst(uri) + if not ast then + return nil + end + local source = findSource(ast, pos, accept) + if not source then + return nil + end + local results = {} + local mark = {} + + rename(source, newname, function (target, start, finish, text) + local turi = files.getOriginUri(guide.getUri(target)) + local uid = turi .. start + if mark[uid] then + return + end + mark[uid] = true + if files.isLibrary(turi) then + return + end + results[#results+1] = { + start = start, + finish = finish, + text = text, + uri = turi, + } + end) + + if Forcing == false then + Forcing = nil + return nil + end + + if #results == 0 then + return nil + end + + if not askForMultiChange(results, newname) then + return nil + end + + return results +end + +function m.prepareRename(uri, pos) + local ast = files.getAst(uri) + if not ast then + return nil + end + local source = findSource(ast, pos, accept) + if not source then + return + end + + local res, text = prepareRename(source) + if not res then + return nil + end + + return { + start = source.start, + finish = source.finish, + text = text, + } +end + +return m diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua new file mode 100644 index 00000000..e6b35cdd --- /dev/null +++ b/script/core/semantic-tokens.lua @@ -0,0 +1,161 @@ +local files = require 'files' +local guide = require 'parser.guide' +local await = require 'await' +local define = require 'proto.define' +local vm = require 'vm' +local util = require 'utility' + +local Care = {} +Care['setglobal'] = function (source, results) + local isLib = vm.isGlobalLibraryName(source[1]) + if not isLib then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.namespace, + modifieres = define.TokenModifiers.deprecated, + } + end +end +Care['getglobal'] = function (source, results) + local isLib = vm.isGlobalLibraryName(source[1]) + if not isLib then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.namespace, + modifieres = define.TokenModifiers.deprecated, + } + end +end +Care['tablefield'] = function (source, results) + local field = source.field + if not field then + return + end + results[#results+1] = { + start = field.start, + finish = field.finish, + type = define.TokenTypes.property, + modifieres = define.TokenModifiers.declaration, + } +end +Care['getlocal'] = function (source, results) + local loc = source.node + -- 1. 值为函数的局部变量 + local hasFunc + local node = loc.node + if node then + for _, ref in ipairs(node.ref) do + local def = ref.value + if def.type == 'function' then + hasFunc = true + break + end + end + end + if hasFunc then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.interface, + modifieres = define.TokenModifiers.declaration, + } + return + end + -- 2. 对象 + if source.parent.type == 'getmethod' + and source.parent.node == source then + return + end + -- 3. 函数的参数 + if loc.parent and loc.parent.type == 'funcargs' then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.parameter, + modifieres = define.TokenModifiers.declaration, + } + return + end + -- 4. 特殊变量 + if source[1] == '_ENV' + or source[1] == 'self' then + return + end + -- 5. 其他 + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.variable, + } +end +Care['setlocal'] = Care['getlocal'] +Care['doc.return.name'] = function (source, results) + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.parameter, + } +end + +local function buildTokens(results, text, lines) + local tokens = {} + local lastLine = 0 + local lastStartChar = 0 + for i, source in ipairs(results) do + local row, col = guide.positionOf(lines, source.start) + local start = guide.lineRange(lines, row) + local ucol = util.utf8Len(text, start, start + col - 1) + local line = row - 1 + local startChar = ucol - 1 + local deltaLine = line - lastLine + local deltaStartChar + if deltaLine == 0 then + deltaStartChar = startChar - lastStartChar + else + deltaStartChar = startChar + end + lastLine = line + lastStartChar = startChar + -- see https://microsoft.github.io/language-server-protocol/specifications/specification-3-16/#textDocument_semanticTokens + local len = i * 5 - 5 + tokens[len + 1] = deltaLine + tokens[len + 2] = deltaStartChar + tokens[len + 3] = source.finish - source.start + 1 -- length + tokens[len + 4] = source.type + tokens[len + 5] = source.modifieres or 0 + end + return tokens +end + +return function (uri, start, finish) + local ast = files.getAst(uri) + local lines = files.getLines(uri) + local text = files.getText(uri) + if not ast then + return nil + end + + local results = {} + local count = 0 + guide.eachSourceBetween(ast.ast, start, finish, function (source) + local method = Care[source.type] + if not method then + return + end + method(source, results) + count = count + 1 + if count % 100 == 0 then + await.delay() + end + end) + + table.sort(results, function (a, b) + return a.start < b.start + end) + + local tokens = buildTokens(results, text, lines) + + return tokens +end diff --git a/script/core/signature.lua b/script/core/signature.lua new file mode 100644 index 00000000..dad38924 --- /dev/null +++ b/script/core/signature.lua @@ -0,0 +1,106 @@ +local files = require 'files' +local guide = require 'parser.guide' +local vm = require 'vm' +local hoverLabel = require 'core.hover.label' +local hoverDesc = require 'core.hover.description' + +local function findNearCall(uri, ast, pos) + local text = files.getText(uri) + -- 检查 `f()$` 的情况,注意要区别于 `f($` + if text:sub(pos, pos) == ')' then + return nil + end + + local nearCall + guide.eachSourceContain(ast.ast, pos, function (src) + if src.type == 'call' + or src.type == 'table' + or src.type == 'function' then + if not nearCall or nearCall.start < src.start then + nearCall = src + end + end + end) + if not nearCall then + return nil + end + if nearCall.type ~= 'call' then + return nil + end + return nearCall +end + +local function makeOneSignature(source, oop, index) + local label = hoverLabel(source, oop) + -- 去掉返回值 + label = label:gsub('%s*->.+', '') + local params = {} + local i = 0 + for start, finish in label:gmatch '[%(%)%,]%s*().-()%s*%f[%(%)%,%[%]]' do + i = i + 1 + params[i] = { + label = {start, finish-1}, + } + end + -- 不定参数 + if index > i and i > 0 then + local lastLabel = params[i].label + local text = label:sub(lastLabel[1], lastLabel[2]) + if text == '...' then + index = i + end + end + return { + label = label, + params = params, + index = index, + description = hoverDesc(source), + } +end + +local function makeSignatures(call, pos) + local node = call.node + local oop = node.type == 'method' + or node.type == 'getmethod' + or node.type == 'setmethod' + local index + local args = call.args + if args then + for i, arg in ipairs(args) do + if arg.start <= pos and arg.finish >= pos then + index = i + break + end + end + if not index then + index = #args + 1 + end + else + index = 1 + end + local signs = {} + local defs = vm.getDefs(node, 'deep') + for _, src in ipairs(defs) do + if src.type == 'function' + or src.type == 'doc.type.function' then + signs[#signs+1] = makeOneSignature(src, oop, index) + end + end + return signs +end + +return function (uri, pos) + local ast = files.getAst(uri) + if not ast then + return nil + end + local call = findNearCall(uri, ast, pos) + if not call then + return nil + end + local signs = makeSignatures(call, pos) + if not signs or #signs == 0 then + return nil + end + return signs +end diff --git a/script/core/workspace-symbol.lua b/script/core/workspace-symbol.lua new file mode 100644 index 00000000..4fc6a854 --- /dev/null +++ b/script/core/workspace-symbol.lua @@ -0,0 +1,69 @@ +local files = require 'files' +local guide = require 'parser.guide' +local matchKey = require 'core.matchkey' +local define = require 'proto.define' +local await = require 'await' + +local function buildSource(uri, source, key, results) + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' then + local name = source[1] + if matchKey(key, name) then + results[#results+1] = { + name = name, + kind = define.SymbolKind.Variable, + uri = uri, + range = { source.start, source.finish }, + } + end + elseif source.type == 'setfield' + or source.type == 'tablefield' then + local field = source.field + local name = field[1] + if matchKey(key, name) then + results[#results+1] = { + name = name, + kind = define.SymbolKind.Field, + uri = uri, + range = { field.start, field.finish }, + } + end + elseif source.type == 'setmethod' then + local method = source.method + local name = method[1] + if matchKey(key, name) then + results[#results+1] = { + name = name, + kind = define.SymbolKind.Method, + uri = uri, + range = { method.start, method.finish }, + } + end + end +end + +local function searchFile(uri, key, results) + local ast = files.getAst(uri) + if not ast then + return + end + + guide.eachSource(ast.ast, function (source) + buildSource(uri, source, key, results) + end) +end + +return function (key) + local results = {} + + for uri in files.eachFile() do + searchFile(files.getOriginUri(uri), key, results) + if #results > 1000 then + break + end + await.delay() + end + + return results +end diff --git a/script/doctor.lua b/script/doctor.lua new file mode 100644 index 00000000..08ec69cf --- /dev/null +++ b/script/doctor.lua @@ -0,0 +1,380 @@ +local type = type +local next = next +local ipairs = ipairs +local rawget = rawget +local pcall = pcall +local getregistry = debug.getregistry +local getmetatable = debug.getmetatable +local getupvalue = debug.getupvalue +local getuservalue = debug.getuservalue +local getlocal = debug.getlocal +local getinfo = debug.getinfo +local maxinterger = math.maxinteger +local mathType = math.type +local tableConcat = table.concat +local _G = _G +local registry = getregistry() +local tableSort = table.sort + +_ENV = nil + +local m = {} + +local function getTostring(obj) + local mt = getmetatable(obj) + if not mt then + return nil + end + local toString = rawget(mt, '__tostring') + if not toString then + return nil + end + local suc, str = pcall(toString, obj) + if not suc then + return nil + end + if type(str) ~= 'string' then + return nil + end + return str +end + +local function formatName(obj) + local tp = type(obj) + if tp == 'nil' then + return 'nil:nil' + elseif tp == 'boolean' then + if obj == true then + return 'boolean:true' + else + return 'boolean:false' + end + elseif tp == 'number' then + if mathType(obj) == 'integer' then + return ('number:%d'):format(obj) + else + -- 如果浮点数可以完全表示为整数,那么就转换为整数 + local str = ('%.10f'):format(obj):gsub('%.?[0]+$', '') + if str:find('.', 1, true) then + -- 如果浮点数不能表示为整数,那么再加上它的精确表示法 + str = ('%s(%q)'):format(str, obj) + end + return 'number:' .. str + end + elseif tp == 'string' then + local str = ('%q'):format(obj) + if #str > 100 then + local new = ('%s...(len=%d)'):format(str:sub(1, 100), #str) + if #new < #str then + str = new + end + end + return 'string:' .. str + elseif tp == 'function' then + local info = getinfo(obj, 'S') + if info.what == 'c' then + return ('function:%p(C)'):format(obj) + elseif info.what == 'main' then + return ('function:%p(main)'):format(obj) + else + return ('function:%p(%s:%d-%d)'):format(obj, info.source, info.linedefined, info.lastlinedefined) + end + elseif tp == 'table' then + local id = getTostring(obj) + if not id then + if obj == _G then + id = '_G' + elseif obj == registry then + id = 'registry' + end + end + if id then + return ('table:%p(%s)'):format(obj, id) + else + return ('table:%p'):format(obj) + end + elseif tp == 'userdata' then + local id = getTostring(obj) + if id then + return ('userdata:%p(%s)'):format(obj, id) + else + return ('userdata:%p'):format(obj) + end + else + return ('%s:%p'):format(tp, obj) + end +end + +--- 内存快照 +---@return table +function m.snapshot() + local mark = {} + local find + + local function findTable(t, result) + result = result or {} + local mt = getmetatable(t) + local wk, wv + if mt then + local mode = rawget(mt, '__mode') + if type(mode) == 'string' then + if mode:find('k', 1, true) then + wk = true + end + if mode:find('v', 1, true) then + wv = true + end + end + end + for k, v in next, t do + if not wk then + local keyInfo = find(k) + if keyInfo then + result[#result+1] = { + type = 'key', + name = formatName(k), + info = keyInfo, + } + end + end + if not wv then + local valueInfo = find(v) + if valueInfo then + result[#result+1] = { + type = 'field', + name = formatName(k) .. '|' .. formatName(v), + info = valueInfo, + } + end + end + end + local MTInfo = find(getmetatable(t)) + if MTInfo then + result[#result+1] = { + type = 'metatable', + name = '', + info = MTInfo, + } + end + if #result == 0 then + return nil + end + return result + end + + local function findFunction(f, result, trd, stack) + result = result or {} + for i = 1, maxinterger do + local n, v = getupvalue(f, i) + if not n then + break + end + local valueInfo = find(v) + if valueInfo then + result[#result+1] = { + type = 'upvalue', + name = n, + info = valueInfo, + } + end + end + if trd then + for i = 1, maxinterger do + local n, l = getlocal(trd, stack, i) + if not n then + break + end + local valueInfo = find(l) + if valueInfo then + result[#result+1] = { + type = 'local', + name = n, + info = valueInfo, + } + end + end + end + if #result == 0 then + return nil + end + return result + end + + local function findUserData(u, result) + result = result or {} + for i = 1, maxinterger do + local v, b = getuservalue(u, i) + if not b then + break + end + local valueInfo = find(v) + if valueInfo then + result[#result+1] = { + type = 'uservalue', + name = formatName(i), + info = valueInfo, + } + end + end + local MTInfo = find(getmetatable(u)) + if MTInfo then + result[#result+1] = { + type = 'metatable', + name = '', + info = MTInfo, + } + end + if #result == 0 then + return nil + end + return result + end + + local function findThread(trd, result) + -- 不查找主线程,主线程一定是临时的(视为弱引用) + if trd == registry[1] then + return nil + end + result = result or {} + + for i = 1, maxinterger do + local info = getinfo(trd, i, 'Sf') + if not info then + break + end + local funcInfo = find(info.func, trd, i) + if funcInfo then + result[#result+1] = { + type = 'stack', + name = i .. '@' .. formatName(info.func), + info = funcInfo, + } + end + end + + if #result == 0 then + return nil + end + return result + end + + function find(obj, trd, stack) + if mark[obj] then + return mark[obj] + end + local tp = type(obj) + if tp == 'table' then + mark[obj] = {} + mark[obj] = findTable(obj, mark[obj]) + elseif tp == 'function' then + mark[obj] = {} + mark[obj] = findFunction(obj, mark[obj], trd, stack) + elseif tp == 'userdata' then + mark[obj] = {} + mark[obj] = findUserData(obj, mark[obj]) + elseif tp == 'thread' then + mark[obj] = {} + mark[obj] = findThread(obj, mark[obj]) + else + return nil + end + if mark[obj] then + mark[obj].object = obj + end + return mark[obj] + end + + return { + name = formatName(registry), + type = 'root', + info = find(registry), + } +end + +--- 寻找对象的引用 +---@return string +function m.catch(...) + local targets = {} + for _, target in ipairs {...} do + targets[target] = true + end + local report = m.snapshot() + local path = {} + local result = {} + local mark = {} + + local function push() + result[#result+1] = tableConcat(path, ' => ') + end + + local function search(t) + path[#path+1] = ('(%s)%s'):format(t.type, t.name) + local addTarget + if targets[t.info.object] then + targets[t.info.object] = nil + addTarget = t.info.object + push(t) + end + if not mark[t.info] then + mark[t.info] = true + for _, obj in ipairs(t.info) do + search(obj) + end + end + path[#path] = nil + if addTarget then + targets[addTarget] = true + end + end + + search(report) + + return result +end + +--- 生成一个报告 +---@return string +function m.report() + local snapshot = m.snapshot() + local cache = {} + local mark = {} + + local function scan(t) + local obj = t.info.object + local tp = type(obj) + if tp == 'table' + or tp == 'userdata' + or tp == 'function' + or tp == 'string' + or tp == 'thread' then + local point = ('%p'):format(obj) + if not cache[point] then + cache[point] = { + point = point, + count = 0, + name = formatName(obj), + } + end + cache[point].count = cache[point].count + 1 + end + if not mark[t.info] then + mark[t.info] = true + for _, child in ipairs(t.info) do + scan(child) + end + end + end + + scan(snapshot) + + local list = {} + for _, info in next, cache do + list[#list+1] = info + end + tableSort(list, function (a, b) + return a.name < b.name + end) + return list +end + +return m diff --git a/script/file-uri.lua b/script/file-uri.lua new file mode 100644 index 00000000..ba44f2e7 --- /dev/null +++ b/script/file-uri.lua @@ -0,0 +1,89 @@ +local platform = require 'bee.platform' + +local escPatt = '[^%w%-%.%_%~%/]' + +local function esc(c) + return ('%%%02X'):format(c:byte()) +end + +local function normalize(str) + return str:gsub('%%(%x%x)', function (n) + return string.char(tonumber(n, 16)) + end) +end + +local m = {} + +-- c:\my\files --> file:///c%3A/my/files +-- /usr/home --> file:///usr/home +-- \\server\share\some\path --> file://server/share/some/path + +--- path -> uri +---@param path string +---@return string uri +function m.encode(path) + local authority = '' + if platform.OS == 'Windows' then + path = path:gsub('\\', '/') + end + + if path:sub(1, 2) == '//' then + local idx = path:find('/', 3) + if idx then + authority = path:sub(3, idx) + path = path:sub(idx + 1) + if path == '' then + path = '/' + end + else + authority = path:sub(3) + path = '/' + end + end + + if path:sub(1, 1) ~= '/' then + path = '/' .. path + end + + -- lower-case windows drive letters in /C:/fff or C:/fff + local start, finish, drive = path:find '/(%u):' + if drive then + path = path:sub(1, start) .. drive:lower() .. path:sub(finish, -1) + end + + local uri = 'file://' + .. authority:gsub(escPatt, esc) + .. path:gsub(escPatt, esc) + return uri +end + +-- file:///c%3A/my/files --> c:\my\files +-- file:///usr/home --> /usr/home +-- file://server/share/some/path --> \\server\share\some\path + +--- uri -> path +---@param uri string +---@return string path +function m.decode(uri) + local scheme, authority, path = uri:match('([^:]*):?/?/?([^/]*)(.*)') + if not scheme then + return '' + end + scheme = normalize(scheme) + authority = normalize(authority) + path = normalize(path) + local value + if scheme == 'file' and #authority > 0 and #path > 1 then + value = '//' .. authority .. path + elseif path:match '/%a:' then + value = path:sub(2, 2):lower() .. path:sub(3) + else + value = path + end + if platform.OS == 'Windows' then + value = value:gsub('/', '\\') + end + return value +end + +return m diff --git a/script/files.lua b/script/files.lua new file mode 100644 index 00000000..4d34568d --- /dev/null +++ b/script/files.lua @@ -0,0 +1,438 @@ +local platform = require 'bee.platform' +local config = require 'config' +local glob = require 'glob' +local furi = require 'file-uri' +local parser = require 'parser' +local proto = require 'proto' +local lang = require 'language' +local await = require 'await' +local timer = require 'timer' + +local m = {} + +m.openMap = {} +m.libraryMap = {} +m.fileMap = {} +m.watchList = {} +m.notifyCache = {} +m.assocVersion = -1 +m.assocMatcher = nil +m.globalVersion = 0 +m.linesMap = setmetatable({}, { __mode = 'v' }) +m.astMap = setmetatable({}, { __mode = 'v' }) + +--- 打开文件 +---@param uri string +function m.open(uri) + local originUri = uri + if platform.OS == 'Windows' then + uri = uri:lower() + end + m.openMap[uri] = true + m.onWatch('open', originUri) +end + +--- 关闭文件 +---@param uri string +function m.close(uri) + local originUri = uri + if platform.OS == 'Windows' then + uri = uri:lower() + end + m.openMap[uri] = nil + m.onWatch('close', originUri) +end + +--- 是否打开 +---@param uri string +---@return boolean +function m.isOpen(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + return m.openMap[uri] == true +end + +--- 标记为库文件 +function m.setLibraryPath(uri, libraryPath) + if platform.OS == 'Windows' then + uri = uri:lower() + end + m.libraryMap[uri] = libraryPath +end + +--- 是否是库文件 +function m.isLibrary(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + return m.libraryMap[uri] ~= nil +end + +--- 获取库文件的根目录 +function m.getLibraryPath(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + return m.libraryMap[uri] +end + +function m.flushAllLibrary() + m.libraryMap = {} +end + +--- 是否存在 +---@return boolean +function m.exists(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + return m.fileMap[uri] ~= nil +end + +function m.asKey(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + return uri +end + +--- 设置文件文本 +---@param uri string +---@param text string +function m.setText(uri, text) + if not text then + return + end + local originUri = uri + if platform.OS == 'Windows' then + uri = uri:lower() + end + local create + if not m.fileMap[uri] then + m.fileMap[uri] = { + uri = originUri, + version = 0, + } + create = true + end + local file = m.fileMap[uri] + if file.text == text then + return + end + file.text = text + m.linesMap[uri] = nil + m.astMap[uri] = nil + file.cache = {} + file.cacheActiveTime = math.huge + file.version = file.version + 1 + m.globalVersion = m.globalVersion + 1 + await.close('files.version') + if create then + m.onWatch('create', originUri) + end + m.onWatch('update', originUri) +end + +--- 获取文件版本 +function m.getVersion(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + local file = m.fileMap[uri] + if not file then + return nil + end + return file.version +end + +--- 获取文件文本 +---@param uri string +---@return string text +function m.getText(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + local file = m.fileMap[uri] + if not file then + return nil + end + return file.text +end + +--- 移除文件 +---@param uri string +function m.remove(uri) + local originUri = uri + if platform.OS == 'Windows' then + uri = uri:lower() + end + local file = m.fileMap[uri] + if not file then + return + end + m.fileMap[uri] = nil + + m.globalVersion = m.globalVersion + 1 + await.close('files.version') + m.onWatch('remove', originUri) +end + +--- 移除所有文件 +function m.removeAll() + m.globalVersion = m.globalVersion + 1 + await.close('files.version') + for uri in pairs(m.fileMap) do + if not m.libraryMap[uri] then + m.fileMap[uri] = nil + m.astMap[uri] = nil + m.linesMap[uri] = nil + m.onWatch('remove', uri) + end + end + --m.notifyCache = {} +end + +--- 移除所有关闭的文件 +function m.removeAllClosed() + m.globalVersion = m.globalVersion + 1 + await.close('files.version') + for uri in pairs(m.fileMap) do + if not m.openMap[uri] + and not m.libraryMap[uri] then + m.fileMap[uri] = nil + m.astMap[uri] = nil + m.linesMap[uri] = nil + m.onWatch('remove', uri) + end + end + --m.notifyCache = {} +end + +--- 遍历文件 +function m.eachFile() + return pairs(m.fileMap) +end + +function m.compileAst(uri, text) + if not m.isOpen(uri) and #text >= config.config.workspace.preloadFileSize * 1000 then + if not m.notifyCache['preloadFileSize'] then + m.notifyCache['preloadFileSize'] = {} + m.notifyCache['skipLargeFileCount'] = 0 + end + if not m.notifyCache['preloadFileSize'][uri] then + m.notifyCache['preloadFileSize'][uri] = true + m.notifyCache['skipLargeFileCount'] = m.notifyCache['skipLargeFileCount'] + 1 + if m.notifyCache['skipLargeFileCount'] <= 3 then + local ws = require 'workspace' + proto.notify('window/showMessage', { + type = 3, + message = lang.script('WORKSPACE_SKIP_LARGE_FILE' + , ws.getRelativePath(uri) + , config.config.workspace.preloadFileSize + , #text / 1000 + ), + }) + end + end + return nil + end + local clock = os.clock() + local state, err = parser:compile(text + , 'lua' + , config.config.runtime.version + , { + special = config.config.runtime.special, + } + ) + local passed = os.clock() - clock + if passed > 0.1 then + log.warn(('Compile [%s] takes [%.3f] sec, size [%.3f] kb.'):format(uri, passed, #text / 1000)) + end + if state then + state.uri = uri + state.ast.uri = uri + if config.config.luadoc.enable then + parser:luadoc(state) + end + return state + else + log.error(err) + return nil + end +end + +--- 获取文件语法树 +---@param uri string +---@return table ast +function m.getAst(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + if uri ~= '' and not m.isLua(uri) then + return nil + end + local file = m.fileMap[uri] + if not file then + return nil + end + local ast = m.astMap[uri] + if not ast then + ast = m.compileAst(uri, file.text) + m.astMap[uri] = ast + end + file.cacheActiveTime = timer.clock() + return ast +end + +--- 获取文件行信息 +---@param uri string +---@return table lines +function m.getLines(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + local file = m.fileMap[uri] + if not file then + return nil + end + local lines = m.linesMap[uri] + if not lines then + lines = parser:lines(file.text) + m.linesMap[uri] = lines + end + return lines +end + +--- 获取原始uri +function m.getOriginUri(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + local file = m.fileMap[uri] + if not file then + return nil + end + return file.uri +end + +function m.getUri(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + return uri +end + +--- 获取文件的自定义缓存信息(在文件内容更新后自动失效) +function m.getCache(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + local file = m.fileMap[uri] + if not file then + return nil + end + file.cacheActiveTime = timer.clock() + return file.cache +end + +--- 判断文件名相等 +function m.eq(a, b) + if platform.OS == 'Windows' then + return a:lower() == b:lower() + else + return a == b + end +end + +--- 获取文件关联 +function m.getAssoc() + if m.assocVersion == config.version then + return m.assocMatcher + end + m.assocVersion = config.version + local patt = {} + for k, v in pairs(config.other.associations) do + if m.eq(v, 'lua') then + patt[#patt+1] = k + end + end + m.assocMatcher = glob.glob(patt) + if platform.OS == 'Windows' then + m.assocMatcher:setOption 'ignoreCase' + end + return m.assocMatcher +end + +--- 判断是否是Lua文件 +---@param uri string +---@return boolean +function m.isLua(uri) + local ext = uri:match '%.([^%.%/%\\]-)$' + if not ext then + return false + end + if m.eq(ext, 'lua') then + return true + end + local matcher = m.getAssoc() + local path = furi.decode(uri) + return matcher(path) +end + +--- 注册事件 +function m.watch(callback) + m.watchList[#m.watchList+1] = callback +end + +function m.onWatch(ev, ...) + for _, callback in ipairs(m.watchList) do + callback(ev, ...) + end +end + +function m.flushCache() + for uri, file in pairs(m.fileMap) do + file.cacheActiveTime = math.huge + m.linesMap[uri] = nil + m.astMap[uri] = nil + file.cache = {} + end +end + +function m.flushFileCache(uri) + if platform.OS == 'Windows' then + uri = uri:lower() + end + local file = m.fileMap[uri] + if not file then + return + end + file.cacheActiveTime = math.huge + m.linesMap[uri] = nil + m.astMap[uri] = nil + file.cache = {} +end + +local function init() + --TODO 可以清空文件缓存,之后看要不要启用吧 + --timer.loop(10, function () + -- local list = {} + -- for _, file in pairs(m.fileMap) do + -- if timer.clock() - file.cacheActiveTime > 10.0 then + -- file.cacheActiveTime = math.huge + -- file.ast = nil + -- file.cache = {} + -- list[#list+1] = file.uri + -- end + -- end + -- if #list > 0 then + -- log.info('Flush file caches:', #list, '\n', table.concat(list, '\n')) + -- collectgarbage() + -- end + --end) +end + +xpcall(init, log.error) + +return m diff --git a/script/fs-utility.lua b/script/fs-utility.lua new file mode 100644 index 00000000..42041734 --- /dev/null +++ b/script/fs-utility.lua @@ -0,0 +1,559 @@ +local fs = require 'bee.filesystem' +local platform = require 'bee.platform' + +local type = type +local ioOpen = io.open +local pcall = pcall +local pairs = pairs +local setmetatable = setmetatable +local next = next +local ipairs = ipairs +local tostring = tostring +local tableSort = table.sort + +_ENV = nil + +local m = {} +--- 读取文件 +---@param path string +function m.loadFile(path) + if type(path) ~= 'string' then + path = path:string() + end + local f, e = ioOpen(path, 'rb') + if not f then + return nil, e + end + if f:read(3) ~= '\xEF\xBB\xBF' then + f:seek("set") + end + local buf = f:read 'a' + f:close() + return buf +end + +--- 写入文件 +---@param path string +---@param content string +function m.saveFile(path, content) + if type(path) ~= 'string' then + path = path:string() + end + local f, e = ioOpen(path, "wb") + + if f then + f:write(content) + f:close() + return true + else + return false, e + end +end + +local function buildOptional(optional) + optional = optional or {} + optional.add = optional.add or {} + optional.del = optional.del or {} + optional.mod = optional.mod or {} + optional.err = optional.err or {} + return optional +end + +local function split(str, sep) + local t = {} + local current = 1 + while current <= #str do + local s, e = str:find(sep, current) + if not s then + t[#t+1] = str:sub(current) + break + end + if s > 1 then + t[#t+1] = str:sub(current, s - 1) + end + current = e + 1 + end + return t +end + +local dfs = {} +dfs.__index = dfs +dfs.type = 'dummy' +dfs.path = '' + +function m.dummyFS(t) + return setmetatable({ + files = t or {}, + }, dfs) +end + +function dfs:__tostring() + return 'dummy:' .. tostring(self.path) +end + +function dfs:__div(filename) + if type(filename) ~= 'string' then + filename = filename:string() + end + local new = m.dummyFS(self.files) + if self.path:sub(-1):match '[^/\\]' then + new.path = self.path .. '\\' .. filename + else + new.path = self.path .. filename + end + return new +end + +function dfs:_open(index) + local paths = split(self.path, '[/\\]') + local current = self.files + if not index then + index = #paths + elseif index < 0 then + index = #paths + index + 1 + end + for i = 1, index do + local path = paths[i] + if current[path] then + current = current[path] + else + return nil + end + end + return current +end + +function dfs:_filename() + return self.path:match '[^/\\]+$' +end + +function dfs:parent_path() + local new = m.dummyFS(self.files) + if self.path:find('[/\\]') then + new.path = self.path:gsub('[/\\]+[^/\\]*$', '') + else + new.path = '' + end + return new +end + +function dfs:filename() + local new = m.dummyFS(self.files) + new.path = self:_filename() + return new +end + +function dfs:string() + return self.path +end + +function dfs:list_directory() + local dir = self:_open() + if type(dir) ~= 'table' then + return function () end + end + local keys = {} + for k in pairs(dir) do + keys[#keys+1] = k + end + tableSort(keys) + local i = 0 + return function () + i = i + 1 + local k = keys[i] + if not k then + return nil + end + return self / k + end +end + +function dfs:isDirectory() + local target = self:_open() + if type(target) == 'table' then + return true + end + return false +end + +function dfs:remove() + local dir = self:_open(-2) + local filename = self:_filename() + if not filename then + return + end + dir[filename] = nil +end + +function dfs:exists() + local target = self:_open() + return target ~= nil +end + +function dfs:createDirectories(path) + if type(path) ~= 'string' then + path = path:string() + end + local paths = split(path, '[/\\]') + local current = self.files + for i = 1, #paths do + local sub = paths[i] + if current[sub] then + if type(current[sub]) ~= 'table' then + return false + end + else + current[sub] = {} + end + current = current[sub] + end + return true +end + +function dfs:saveFile(path, text) + if type(path) ~= 'string' then + path = path:string() + end + local temp = m.dummyFS(self.files) + temp.path = path + local dir = temp:_open(-2) + if not dir then + return false, '无法打开:' .. path + end + local filename = temp:_filename() + if not filename then + return false, '无法打开:' .. path + end + if type(dir[filename]) == 'table' then + return false, '无法打开:' .. path + end + dir[filename] = text +end + +local function fsAbsolute(path, optional) + if type(path) == 'string' then + local suc, res = pcall(fs.path, path) + if not suc then + optional.err[#optional.err+1] = res + return nil + end + path = res + elseif type(path) == 'table' then + return path + end + local suc, res = pcall(fs.absolute, path) + if not suc then + optional.err[#optional.err+1] = res + return nil + end + return res +end + +local function fsIsDirectory(path, optional) + if path.type == 'dummy' then + return path:isDirectory() + end + local suc, res = pcall(fs.is_directory, path) + if not suc then + optional.err[#optional.err+1] = res + return false + end + return res +end + +local function fsRemove(path, optional) + if path.type == 'dummy' then + return path:remove() + end + local suc, res = pcall(fs.remove, path) + if not suc then + optional.err[#optional.err+1] = res + end + optional.del[#optional.del+1] = path:string() +end + +local function fsExists(path, optional) + if path.type == 'dummy' then + return path:exists() + end + local suc, res = pcall(fs.exists, path) + if not suc then + optional.err[#optional.err+1] = res + return false + end + return res +end + +local function fsSave(path, text, optional) + if path.type == 'dummy' then + local dir = path:_open(-2) + if not dir then + optional.err[#optional.err+1] = '无法打开:' .. path:string() + return false + end + local filename = path:_filename() + if not filename then + optional.err[#optional.err+1] = '无法打开:' .. path:string() + return false + end + if type(dir[filename]) == 'table' then + optional.err[#optional.err+1] = '无法打开:' .. path:string() + return false + end + dir[filename] = text + else + local suc, err = m.saveFile(path, text) + if suc then + return true + end + optional.err[#optional.err+1] = err + return false + end +end + +local function fsLoad(path, optional) + if path.type == 'dummy' then + local text = path:_open() + if type(text) == 'string' then + return text + else + optional.err[#optional.err+1] = '无法打开:' .. path:string() + return nil + end + else + local text, err = m.loadFile(path) + if text then + return text + else + optional.err[#optional.err+1] = err + return nil + end + end +end + +local function fsCopy(source, target, optional) + if source.type == 'dummy' then + local sourceText = source:_open() + if not sourceText then + optional.err[#optional.err+1] = '无法打开:' .. source:string() + return false + end + return fsSave(target, sourceText, optional) + else + if target.type == 'dummy' then + local sourceText, err = m.loadFile(source) + if not sourceText then + optional.err[#optional.err+1] = err + return false + end + return fsSave(target, sourceText, optional) + else + local suc, res = pcall(fs.copy_file, source, target, true) + if not suc then + optional.err[#optional.err+1] = res + return false + end + end + end + return true +end + +local function fsCreateDirectories(path, optional) + if path.type == 'dummy' then + return path:createDirectories() + end + local suc, res = pcall(fs.create_directories, path) + if not suc then + optional.err[#optional.err+1] = res + return false + end + return true +end + +local function fileRemove(path, optional) + if optional.onRemove and optional.onRemove(path) == false then + return + end + if fsIsDirectory(path, optional) then + for child in path:list_directory() do + fileRemove(child, optional) + end + end + if fsRemove(path, optional) then + optional.del[#optional.del+1] = path:string() + end +end + +local function fileCopy(source, target, optional) + if optional.onCopy and optional.onCopy(source, target) == false then + return + end + local isDir1 = fsIsDirectory(source, optional) + local isDir2 = fsIsDirectory(target, optional) + local isExists = fsExists(target, optional) + if isDir1 then + if isDir2 or fsCreateDirectories(target, optional) then + for filePath in source:list_directory() do + local name = filePath:filename():string() + fileCopy(filePath, target / name, optional) + end + end + else + if isExists and not isDir2 then + local buf1 = fsLoad(source, optional) + local buf2 = fsLoad(target, optional) + if buf1 and buf2 then + if buf1 ~= buf2 then + if fsCopy(source, target, optional) then + optional.mod[#optional.mod+1] = target:string() + end + end + end + else + if fsCopy(source, target, optional) then + optional.add[#optional.add+1] = target:string() + end + end + end +end + +local function fileSync(source, target, optional) + if optional.onSync and optional.onSync(source, target) == false then + return + end + local isDir1 = fsIsDirectory(source, optional) + local isDir2 = fsIsDirectory(target, optional) + local isExists = fsExists(target, optional) + if isDir1 then + if isDir2 then + local fileList = m.fileList() + for filePath in target:list_directory() do + fileList[filePath] = true + end + for filePath in source:list_directory() do + local name = filePath:filename():string() + local targetPath = target / name + fileSync(filePath, targetPath, optional) + fileList[targetPath] = nil + end + for path in pairs(fileList) do + fileRemove(path, optional) + end + else + if isExists then + fileRemove(target, optional) + end + if fsCreateDirectories(target) then + for filePath in source:list_directory() do + local name = filePath:filename():string() + fileCopy(filePath, target / name, optional) + end + end + end + else + if isDir2 then + fileRemove(target, optional) + end + if isExists then + local buf1 = fsLoad(source, optional) + local buf2 = fsLoad(target, optional) + if buf1 and buf2 then + if buf1 ~= buf2 then + if fsCopy(source, target, optional) then + optional.mod[#optional.mod+1] = target:string() + end + end + end + else + if fsCopy(source, target, optional) then + optional.add[#optional.add+1] = target:string() + end + end + end +end + +--- 文件列表 +function m.fileList(optional) + optional = optional or buildOptional(optional) + local os = platform.OS + local keyMap = {} + local fileList = {} + local function computeKey(path) + path = fsAbsolute(path, optional) + if not path then + return nil + end + local key + if os == 'Windows' then + key = path:string():lower() + else + key = path:string() + end + return key + end + return setmetatable({}, { + __index = function (_, path) + local key = computeKey(path) + return fileList[key] + end, + __newindex = function (_, path, value) + local key = computeKey(path) + if not key then + return + end + if value == nil then + keyMap[key] = nil + else + keyMap[key] = path + fileList[key] = value + end + end, + __pairs = function () + local key, path + return function () + key, path = next(keyMap, key) + return path, fileList[key] + end + end, + }) +end + +--- 删除文件(夹) +function m.fileRemove(path, optional) + optional = buildOptional(optional) + path = fsAbsolute(path, optional) + + fileRemove(path, optional) + + return optional +end + +--- 复制文件(夹) +---@param source string +---@param target string +---@return table +function m.fileCopy(source, target, optional) + optional = buildOptional(optional) + source = fsAbsolute(source, optional) + target = fsAbsolute(target, optional) + + fileCopy(source, target, optional) + + return optional +end + +--- 同步文件(夹) +---@param source string +---@param target string +---@return table +function m.fileSync(source, target, optional) + optional = buildOptional(optional) + source = fsAbsolute(source, optional) + target = fsAbsolute(target, optional) + + fileSync(source, target, optional) + + return optional +end + +return m diff --git a/script/glob/gitignore.lua b/script/glob/gitignore.lua new file mode 100644 index 00000000..f98a2f31 --- /dev/null +++ b/script/glob/gitignore.lua @@ -0,0 +1,221 @@ +local m = require 'lpeglabel' +local matcher = require 'glob.matcher' + +local function prop(name, pat) + return m.Cg(m.Cc(true), name) * pat +end + +local function object(type, pat) + return m.Ct( + m.Cg(m.Cc(type), 'type') * + m.Cg(pat, 'value') + ) +end + +local function expect(p, err) + return p + m.T(err) +end + +local parser = m.P { + 'Main', + ['Sp'] = m.S(' \t')^0, + ['Slash'] = m.S('/\\')^1, + ['Main'] = m.Ct(m.V'Sp' * m.P'{' * m.V'Pattern' * (',' * expect(m.V'Pattern', 'Miss exp after ","'))^0 * m.P'}') + + m.Ct(m.V'Pattern') + + m.T'Main Failed' + , + ['Pattern'] = m.Ct(m.V'Sp' * prop('neg', m.P'!') * expect(m.V'Unit', 'Miss exp after "!"')) + + m.Ct(m.V'Unit') + , + ['NeedRoot'] = prop('root', (m.P'.' * m.V'Slash' + m.V'Slash')), + ['Unit'] = m.V'Sp' * m.V'NeedRoot'^-1 * expect(m.V'Exp', 'Miss exp') * m.V'Sp', + ['Exp'] = m.V'Sp' * (m.V'FSymbol' + object('/', m.V'Slash') + m.V'Word')^0 * m.V'Sp', + ['Word'] = object('word', m.Ct((m.V'CSymbol' + m.V'Char' - m.V'FSymbol')^1)), + ['CSymbol'] = object('*', m.P'*') + + object('?', m.P'?') + + object('[]', m.V'Range') + , + ['Char'] = object('char', (1 - m.S',{}[]*?/\\')^1), + ['FSymbol'] = object('**', m.P'**'), + ['Range'] = m.P'[' * m.Ct(m.V'RangeUnit'^0) * m.P']'^-1, + ['RangeUnit'] = m.Ct(- m.P']' * m.C(m.P(1)) * (m.P'-' * - m.P']' * m.C(m.P(1)))^-1), +} + +local mt = {} +mt.__index = mt +mt.__name = 'gitignore' + +function mt:addPattern(pat) + if type(pat) ~= 'string' then + return + end + self.pattern[#self.pattern+1] = pat + if self.options.ignoreCase then + pat = pat:lower() + end + local states, err = parser:match(pat) + if not states then + self.errors[#self.errors+1] = { + pattern = pat, + message = err + } + return + end + for _, state in ipairs(states) do + self.matcher[#self.matcher+1] = matcher(state) + end +end + +function mt:setOption(op, val) + if val == nil then + val = true + end + self.options[op] = val +end + +---@param key string | "'type'" | "'list'" +---@param func function | "function (path) end" +function mt:setInterface(key, func) + if type(func) ~= 'function' then + return + end + self.interface[key] = func +end + +function mt:callInterface(name, ...) + local func = self.interface[name] + return func(...) +end + +function mt:hasInterface(name) + return self.interface[name] ~= nil +end + +function mt:checkDirectory(catch, path, matcher) + if not self:hasInterface 'type' then + return true + end + if not matcher:isNeedDirectory() then + return true + end + if #catch < #path then + -- if path is 'a/b/c' and catch is 'a/b' + -- then the catch must be a directory + return true + else + return self:callInterface('type', path) == 'directory' + end +end + +function mt:simpleMatch(path) + for i = #self.matcher, 1, -1 do + local matcher = self.matcher[i] + local catch = matcher(path) + if catch and self:checkDirectory(catch, path, matcher) then + if matcher:isNegative() then + return false + else + return true + end + end + end + return nil +end + +function mt:finishMatch(path) + local paths = {} + for filename in path:gmatch '[^/\\]+' do + paths[#paths+1] = filename + end + for i = 1, #paths do + local newPath = table.concat(paths, '/', 1, i) + local passed = self:simpleMatch(newPath) + if passed == true then + return true + elseif passed == false then + return false + end + end + return false +end + +function mt:scan(callback) + local files = {} + if type(callback) ~= 'function' then + callback = nil + end + local list = {} + local result = self:callInterface('list', '') + if type(result) ~= 'table' then + return files + end + for _, path in ipairs(result) do + list[#list+1] = path:match '([^/\\]+)[/\\]*$' + end + while #list > 0 do + local current = list[#list] + if not current then + break + end + list[#list] = nil + if not self:simpleMatch(current) then + local fileType = self:callInterface('type', current) + if fileType == 'file' then + if callback then + callback(current) + end + files[#files+1] = current + elseif fileType == 'directory' then + local result = self:callInterface('list', current) + if type(result) == 'table' then + for _, path in ipairs(result) do + local filename = path:match '([^/\\]+)[/\\]*$' + if filename then + list[#list+1] = current .. '/' .. filename + end + end + end + end + end + end + return files +end + +function mt:__call(path) + if self.options.ignoreCase then + path = path:lower() + end + return self:finishMatch(path) +end + +return function (pattern, options, interface) + local self = setmetatable({ + pattern = {}, + options = {}, + matcher = {}, + errors = {}, + interface = {}, + }, mt) + + if type(pattern) == 'table' then + for _, pat in ipairs(pattern) do + self:addPattern(pat) + end + else + self:addPattern(pattern) + end + + if type(options) == 'table' then + for op, val in pairs(options) do + self:setOption(op, val) + end + end + + if type(interface) == 'table' then + for key, func in pairs(interface) do + self:setInterface(key, func) + end + end + + return self +end diff --git a/script/glob/glob.lua b/script/glob/glob.lua new file mode 100644 index 00000000..aa8923f3 --- /dev/null +++ b/script/glob/glob.lua @@ -0,0 +1,122 @@ +local m = require 'lpeglabel' +local matcher = require 'glob.matcher' + +local function prop(name, pat) + return m.Cg(m.Cc(true), name) * pat +end + +local function object(type, pat) + return m.Ct( + m.Cg(m.Cc(type), 'type') * + m.Cg(pat, 'value') + ) +end + +local function expect(p, err) + return p + m.T(err) +end + +local parser = m.P { + 'Main', + ['Sp'] = m.S(' \t')^0, + ['Slash'] = m.S('/\\')^1, + ['Main'] = m.Ct(m.V'Sp' * m.P'{' * m.V'Pattern' * (',' * expect(m.V'Pattern', 'Miss exp after ","'))^0 * m.P'}') + + m.Ct(m.V'Pattern') + + m.T'Main Failed' + , + ['Pattern'] = m.Ct(m.V'Sp' * prop('neg', m.P'!') * expect(m.V'Unit', 'Miss exp after "!"')) + + m.Ct(m.V'Unit') + , + ['NeedRoot'] = prop('root', (m.P'.' * m.V'Slash' + m.V'Slash')), + ['Unit'] = m.V'Sp' * m.V'NeedRoot'^-1 * expect(m.V'Exp', 'Miss exp') * m.V'Sp', + ['Exp'] = m.V'Sp' * (m.V'FSymbol' + object('/', m.V'Slash') + m.V'Word')^0 * m.V'Sp', + ['Word'] = object('word', m.Ct((m.V'CSymbol' + m.V'Char' - m.V'FSymbol')^1)), + ['CSymbol'] = object('*', m.P'*') + + object('?', m.P'?') + + object('[]', m.V'Range') + , + ['Char'] = object('char', (1 - m.S',{}[]*?/\\')^1), + ['FSymbol'] = object('**', m.P'**'), + ['RangeWord'] = 1 - m.P']', + ['Range'] = m.P'[' * m.Ct(m.V'RangeUnit'^0) * m.P']'^-1, + ['RangeUnit'] = m.Ct(m.C(m.V'RangeWord') * m.P'-' * m.C(m.V'RangeWord')) + + m.V'RangeWord', +} + +local mt = {} +mt.__index = mt +mt.__name = 'glob' + +function mt:addPattern(pat) + if type(pat) ~= 'string' then + return + end + self.pattern[#self.pattern+1] = pat + if self.options.ignoreCase then + pat = pat:lower() + end + local states, err = parser:match(pat) + if not states then + self.errors[#self.errors+1] = { + pattern = pat, + message = err + } + return + end + for _, state in ipairs(states) do + if state.neg then + self.refused[#self.refused+1] = matcher(state) + else + self.passed[#self.passed+1] = matcher(state) + end + end +end + +function mt:setOption(op, val) + if val == nil then + val = true + end + self.options[op] = val +end + +function mt:__call(path) + if self.options.ignoreCase then + path = path:lower() + end + for _, refused in ipairs(self.refused) do + if refused(path) then + return false + end + end + for _, passed in ipairs(self.passed) do + if passed(path) then + return true + end + end + return false +end + +return function (pattern, options) + local self = setmetatable({ + pattern = {}, + options = {}, + passed = {}, + refused = {}, + errors = {}, + }, mt) + + if type(pattern) == 'table' then + for _, pat in ipairs(pattern) do + self:addPattern(pat) + end + else + self:addPattern(pattern) + end + + if type(options) == 'table' then + for op, val in pairs(options) do + self:setOption(op, val) + end + end + return self +end diff --git a/script/glob/init.lua b/script/glob/init.lua new file mode 100644 index 00000000..6578a0d4 --- /dev/null +++ b/script/glob/init.lua @@ -0,0 +1,4 @@ +return { + glob = require 'glob.glob', + gitignore = require 'glob.gitignore', +} diff --git a/script/glob/matcher.lua b/script/glob/matcher.lua new file mode 100644 index 00000000..f4c2b12c --- /dev/null +++ b/script/glob/matcher.lua @@ -0,0 +1,151 @@ +local m = require 'lpeglabel' + +local Slash = m.S('/\\')^1 +local Symbol = m.S',{}[]*?/\\' +local Char = 1 - Symbol +local Path = Char^1 * Slash +local NoWord = #(m.P(-1) + Symbol) +local function whatHappened() + return m.Cmt(m.P(1)^1, function (...) + print(...) + end) +end + +local mt = {} +mt.__index = mt +mt.__name = 'matcher' + +function mt:exp(state, index) + local exp = state[index] + if not exp then + return + end + if exp.type == 'word' then + return self:word(exp, state, index + 1) + elseif exp.type == 'char' then + return self:char(exp, state, index + 1) + elseif exp.type == '**' then + return self:anyPath(exp, state, index + 1) + elseif exp.type == '*' then + return self:anyChar(exp, state, index + 1) + elseif exp.type == '?' then + return self:oneChar(exp, state, index + 1) + elseif exp.type == '[]' then + return self:range(exp, state, index + 1) + elseif exp.type == '/' then + return self:slash(exp, state, index + 1) + end +end + +function mt:word(exp, state, index) + local current = self:exp(exp.value, 1) + local after = self:exp(state, index) + if after then + return current * Slash * after + else + return current + end +end + +function mt:char(exp, state, index) + local current = m.P(exp.value) + local after = self:exp(state, index) + if after then + return current * after * NoWord + else + return current * NoWord + end +end + +function mt:anyPath(_, state, index) + local after = self:exp(state, index) + if after then + return m.P { + 'Main', + Main = after + + Path * m.V'Main' + } + else + return Path^0 + end +end + +function mt:anyChar(_, state, index) + local after = self:exp(state, index) + if after then + return m.P { + 'Main', + Main = after + + Char * m.V'Main' + } + else + return Char^0 + end +end + +function mt:oneChar(_, state, index) + local after = self:exp(state, index) + if after then + return Char * after + else + return Char + end +end + +function mt:range(exp, state, index) + local after = self:exp(state, index) + local ranges = {} + local selects = {} + for _, range in ipairs(exp.value) do + if #range == 1 then + selects[#selects+1] = range[1] + elseif #range == 2 then + ranges[#ranges+1] = range[1] .. range[2] + end + end + local current = m.S(table.concat(selects)) + m.R(table.unpack(ranges)) + if after then + return current * after + else + return current + end +end + +function mt:slash(_, state, index) + local after = self:exp(state, index) + if after then + return after + else + self.needDirectory = true + return nil + end +end + +function mt:pattern(state) + if state.root then + return m.C(self:exp(state, 1)) + else + return m.C(self:anyPath(nil, state, 1)) + end +end + +function mt:isNeedDirectory() + return self.needDirectory == true +end + +function mt:isNegative() + return self.state.neg == true +end + +function mt:__call(path) + return self.matcher:match(path) +end + +return function (state, options) + local self = setmetatable({ + options = options, + state = state, + }, mt) + self.matcher = self:pattern(state) + return self +end diff --git a/script/json-beautify.lua b/script/json-beautify.lua new file mode 100644 index 00000000..1d2a6cc0 --- /dev/null +++ b/script/json-beautify.lua @@ -0,0 +1,120 @@ +local json = require "json" +local type = type +local next = next +local error = error +local table_concat = table.concat +local table_sort = table.sort +local string_rep = string.rep +local math_type = math.type +local setmetatable = setmetatable +local getmetatable = getmetatable + +local statusMark +local statusQue +local statusDep +local statusOpt + +local defaultOpt = { + newline = "\n", + indent = " ", +} +defaultOpt.__index = defaultOpt + +local function encode_newline() + statusQue[#statusQue+1] = statusOpt.newline..string_rep(statusOpt.indent, statusDep) +end + +local encode_map = {} +for k ,v in next, json.encode_map do + encode_map[k] = v +end + +local encode_string = json.encode_map.string + +local function encode(v) + local res = encode_map[type(v)](v) + statusQue[#statusQue+1] = res +end + +function encode_map.table(t) + local first_val = next(t) + if first_val == nil then + if getmetatable(t) == json.object then + return "{}" + else + return "[]" + end + end + if statusMark[t] then + error("circular reference") + end + statusMark[t] = true + if type(first_val) == 'string' then + local key = {} + for k in next, t do + if type(k) ~= "string" then + error("invalid table: mixed or invalid key types") + end + key[#key+1] = k + end + table_sort(key) + statusQue[#statusQue+1] = "{" + statusDep = statusDep + 1 + encode_newline() + local k = key[1] + statusQue[#statusQue+1] = encode_string(k) + statusQue[#statusQue+1] = ": " + encode(t[k]) + for i = 2, #key do + local k = key[i] + statusQue[#statusQue+1] = "," + encode_newline() + statusQue[#statusQue+1] = encode_string(k) + statusQue[#statusQue+1] = ": " + encode(t[k]) + end + statusDep = statusDep - 1 + encode_newline() + statusMark[t] = nil + return "}" + else + local max = 0 + for k in next, t do + if math_type(k) ~= "integer" or k <= 0 then + error("invalid table: mixed or invalid key types") + end + if max < k then + max = k + end + end + statusQue[#statusQue+1] = "[" + statusDep = statusDep + 1 + encode_newline() + encode(t[1]) + for i = 2, max do + statusQue[#statusQue+1] = "," + encode_newline() + encode(t[i]) + end + statusDep = statusDep - 1 + encode_newline() + statusMark[t] = nil + return "]" + end +end + +local function beautify(v, option) + if type(v) == "string" then + v = json.decode(v) + end + statusMark = {} + statusQue = {} + statusDep = 0 + statusOpt = option and setmetatable(option, defaultOpt) or defaultOpt + encode(v) + return table_concat(statusQue) +end + +json.beautify = beautify + +return json diff --git a/script/json.lua b/script/json.lua new file mode 100644 index 00000000..46261d7d --- /dev/null +++ b/script/json.lua @@ -0,0 +1,450 @@ +local type = type +local next = next +local error = error +local tonumber = tonumber +local tostring = tostring +local utf8_char = utf8.char +local table_concat = table.concat +local table_sort = table.sort +local string_char = string.char +local string_byte = string.byte +local string_find = string.find +local string_match = string.match +local string_gsub = string.gsub +local string_sub = string.sub +local string_format = string.format +local math_type = math.type +local setmetatable = setmetatable +local getmetatable = getmetatable +local Inf = math.huge + +local json = {} +json.object = {} + +-- json.encode -- +local statusMark +local statusQue + +local encode_map = {} + +local encode_escape_map = { + [ "\"" ] = "\\\"", + [ "\\" ] = "\\\\", + [ "/" ] = "\\/", + [ "\b" ] = "\\b", + [ "\f" ] = "\\f", + [ "\n" ] = "\\n", + [ "\r" ] = "\\r", + [ "\t" ] = "\\t", +} + +local decode_escape_set = {} +local decode_escape_map = {} +for k, v in next, encode_escape_map do + decode_escape_map[v] = k + decode_escape_set[string_byte(v, 2)] = true +end + +for i = 0, 31 do + local c = string_char(i) + if not encode_escape_map[c] then + encode_escape_map[c] = string_format("\\u%04x", i) + end +end + +local function encode(v) + local res = encode_map[type(v)](v) + statusQue[#statusQue+1] = res +end + +encode_map["nil"] = function () + return "null" +end + +function encode_map.string(v) + return '"' .. string_gsub(v, '[\0-\31\\"]', encode_escape_map) .. '"' +end +local encode_string = encode_map.string + +local function convertreal(v) + local g = string_format('%.16g', v) + if tonumber(g) == v then + return g + end + return string_format('%.17g', v) +end + +function encode_map.number(v) + if v ~= v or v <= -Inf or v >= Inf then + error("unexpected number value '" .. tostring(v) .. "'") + end + return string_gsub(convertreal(v), ',', '.') +end + +function encode_map.boolean(v) + if v then + return "true" + else + return "false" + end +end + +function encode_map.table(t) + local first_val = next(t) + if first_val == nil then + if getmetatable(t) == json.object then + return "{}" + else + return "[]" + end + end + if statusMark[t] then + error("circular reference") + end + statusMark[t] = true + if type(first_val) == 'string' then + local key = {} + for k in next, t do + if type(k) ~= "string" then + error("invalid table: mixed or invalid key types") + end + key[#key+1] = k + end + table_sort(key) + statusQue[#statusQue+1] = "{" + local k = key[1] + statusQue[#statusQue+1] = encode_string(k) + statusQue[#statusQue+1] = ":" + encode(t[k]) + for i = 2, #key do + local k = key[i] + statusQue[#statusQue+1] = "," + statusQue[#statusQue+1] = encode_string(k) + statusQue[#statusQue+1] = ":" + encode(t[k]) + end + statusMark[t] = nil + return "}" + else + local max = 0 + for k in next, t do + if math_type(k) ~= "integer" or k <= 0 then + error("invalid table: mixed or invalid key types") + end + if max < k then + max = k + end + end + statusQue[#statusQue+1] = "[" + encode(t[1]) + for i = 2, max do + statusQue[#statusQue+1] = "," + encode(t[i]) + end + statusMark[t] = nil + return "]" + end +end + +local function encode_unexpected(v) + if v == json.null then + return "null" + else + error("unexpected type '"..type(v).."'") + end +end +encode_map[ "function" ] = encode_unexpected +encode_map[ "userdata" ] = encode_unexpected +encode_map[ "thread" ] = encode_unexpected + +function json.encode(v) + statusMark = {} + statusQue = {} + encode(v) + return table_concat(statusQue) +end + +json.encode_map = encode_map + +-- json.decode -- + +local statusBuf +local statusPos +local statusTop +local statusAry = {} +local statusRef = {} + +local function find_line() + local line = 1 + local pos = 1 + while true do + local f, _, nl1, nl2 = string_find(statusBuf, '([\n\r])([\n\r]?)', pos) + if not f then + return line, statusPos - pos + 1 + end + local newpos = f + ((nl1 == nl2 or nl2 == '') and 1 or 2) + if newpos > statusPos then + return line, statusPos - pos + 1 + end + pos = newpos + line = line + 1 + end +end + +local function decode_error(msg) + error(string_format("ERROR: %s at line %d col %d", msg, find_line())) +end + +local function get_word() + return string_match(statusBuf, "^[^ \t\r\n%]},]*", statusPos) +end + +local function next_byte() + statusPos = string_find(statusBuf, "[^ \t\r\n]", statusPos) + if statusPos then + return string_byte(statusBuf, statusPos) + end + statusPos = #statusBuf + 1 + decode_error("unexpected character '<eol>'") +end + +local function expect_byte(c) + local _, pos = string_find(statusBuf, c, statusPos) + if not pos then + decode_error(string_format("expected '%s'", string_sub(c, #c))) + end + statusPos = pos +end + +local function decode_unicode_surrogate(s1, s2) + return utf8_char(0x10000 + (tonumber(s1, 16) - 0xd800) * 0x400 + (tonumber(s2, 16) - 0xdc00)) +end + +local function decode_unicode_escape(s) + return utf8_char(tonumber(s, 16)) +end + +local function decode_string() + local has_unicode_escape = false + local has_escape = false + local i = statusPos + 1 + while true do + i = string_find(statusBuf, '["\\\0-\31]', i) + if not i then + decode_error "expected closing quote for string" + end + local x = string_byte(statusBuf, i) + if x < 32 then + statusPos = i + decode_error "control character in string" + end + if x == 34 --[[ '"' ]] then + local s = string_sub(statusBuf, statusPos + 1, i - 1) + if has_unicode_escape then + s = string_gsub(string_gsub(s + , "\\u([dD][89aAbB]%x%x)\\u([dD][c-fC-F]%x%x)", decode_unicode_surrogate) + , "\\u(%x%x%x%x)", decode_unicode_escape) + end + if has_escape then + s = string_gsub(s, "\\.", decode_escape_map) + end + statusPos = i + 1 + return s + end + --assert(x == 92 --[[ "\\" ]]) + local nx = string_byte(statusBuf, i+1) + if nx == 117 --[[ "u" ]] then + if not string_match(statusBuf, "^%x%x%x%x", i+2) then + statusPos = i + decode_error "invalid unicode escape in string" + end + has_unicode_escape = true + i = i + 6 + else + if not decode_escape_set[nx] then + statusPos = i + decode_error("invalid escape char '" .. (nx and string_char(nx) or "<eol>") .. "' in string") + end + has_escape = true + i = i + 2 + end + end +end + +local function decode_number() + local word = get_word() + if not ( + string_match(word, '^.[0-9]*$') + or string_match(word, '^.[0-9]*%.[0-9]+$') + or string_match(word, '^.[0-9]*[Ee][+-]?[0-9]+$') + or string_match(word, '^.[0-9]*%.[0-9]+[Ee][+-]?[0-9]+$') + ) then + decode_error("invalid number '" .. word .. "'") + end + statusPos = statusPos + #word + return tonumber(word) +end + +local function decode_number_negative() + local word = get_word() + if not ( + string_match(word, '^.[1-9][0-9]*$') + or string_match(word, '^.[1-9][0-9]*%.[0-9]+$') + or string_match(word, '^.[1-9][0-9]*[Ee][+-]?[0-9]+$') + or string_match(word, '^.[1-9][0-9]*%.[0-9]+[Ee][+-]?[0-9]+$') + or word == "-0" + or string_match(word, '^.0%.[0-9]+$') + or string_match(word, '^.0[Ee][+-]?[0-9]+$') + or string_match(word, '^.0%.[0-9]+[Ee][+-]?[0-9]+$') + ) then + decode_error("invalid number '" .. word .. "'") + end + statusPos = statusPos + #word + return tonumber(word) +end + +local function decode_number_zero() + local word = get_word() + if not ( + #word == 1 + or string_match(word, '^.%.[0-9]+$') + or string_match(word, '^.[Ee][+-]?[0-9]+$') + or string_match(word, '^.%.[0-9]+[Ee][+-]?[0-9]+$') + ) then + decode_error("invalid number '" .. word .. "'") + end + statusPos = statusPos + #word + return tonumber(word) +end + +local function decode_true() + if string_sub(statusBuf, statusPos, statusPos+3) ~= "true" then + decode_error("invalid literal '" .. get_word() .. "'") + end + statusPos = statusPos + 4 + return true +end + +local function decode_false() + if string_sub(statusBuf, statusPos, statusPos+4) ~= "false" then + decode_error("invalid literal '" .. get_word() .. "'") + end + statusPos = statusPos + 5 + return false +end + +local function decode_null() + if string_sub(statusBuf, statusPos, statusPos+3) ~= "null" then + decode_error("invalid literal '" .. get_word() .. "'") + end + statusPos = statusPos + 4 + return json.null +end + +local function decode_array() + statusPos = statusPos + 1 + local res = {} + if next_byte() == 93 --[[ "]" ]] then + statusPos = statusPos + 1 + return res + end + statusTop = statusTop + 1 + statusAry[statusTop] = true + statusRef[statusTop] = res + return res +end + +local function decode_object() + statusPos = statusPos + 1 + local res = {} + if next_byte() == 125 --[[ "}" ]] then + statusPos = statusPos + 1 + return setmetatable(res, json.object) + end + statusTop = statusTop + 1 + statusAry[statusTop] = false + statusRef[statusTop] = res + return res +end + +local decode_uncompleted_map = { + [ string_byte '"' ] = decode_string, + [ string_byte "0" ] = decode_number_zero, + [ string_byte "1" ] = decode_number, + [ string_byte "2" ] = decode_number, + [ string_byte "3" ] = decode_number, + [ string_byte "4" ] = decode_number, + [ string_byte "5" ] = decode_number, + [ string_byte "6" ] = decode_number, + [ string_byte "7" ] = decode_number, + [ string_byte "8" ] = decode_number, + [ string_byte "9" ] = decode_number, + [ string_byte "-" ] = decode_number_negative, + [ string_byte "t" ] = decode_true, + [ string_byte "f" ] = decode_false, + [ string_byte "n" ] = decode_null, + [ string_byte "[" ] = decode_array, + [ string_byte "{" ] = decode_object, +} +local function unexpected_character() + decode_error("unexpected character '" .. string_sub(statusBuf, statusPos, statusPos) .. "'") +end + +local decode_map = {} +for i = 0, 255 do + decode_map[i] = decode_uncompleted_map[i] or unexpected_character +end + +local function decode() + return decode_map[next_byte()]() +end + +local function decode_item() + local top = statusTop + local ref = statusRef[top] + if statusAry[top] then + ref[#ref+1] = decode() + else + expect_byte '^[ \t\r\n]*"' + local key = decode_string() + expect_byte '^[ \t\r\n]*:' + statusPos = statusPos + 1 + ref[key] = decode() + end + if top == statusTop then + repeat + local chr = next_byte(); statusPos = statusPos + 1 + if chr == 44 --[[ "," ]] then + return + end + if statusAry[statusTop] then + if chr ~= 93 --[[ "]" ]] then decode_error "expected ']' or ','" end + else + if chr ~= 125 --[[ "}" ]] then decode_error "expected '}' or ','" end + end + statusTop = statusTop - 1 + until statusTop == 0 + end +end + +function json.decode(str) + if type(str) ~= "string" then + error("expected argument of type string, got " .. type(str)) + end + statusBuf = str + statusPos = 1 + statusTop = 0 + local res = decode() + while statusTop > 0 do + decode_item() + end + if string_find(statusBuf, "[^ \t\r\n]", statusPos) then + decode_error "trailing garbage" + end + return res +end + +-- Generate a lightuserdata +json.null = debug.upvalueid(decode, 1) + +return json diff --git a/script/jsonrpc.lua b/script/jsonrpc.lua new file mode 100644 index 00000000..17e6e73d --- /dev/null +++ b/script/jsonrpc.lua @@ -0,0 +1,64 @@ +local json = require 'json' +local pcall = pcall +local tonumber = tonumber +local util = require 'utility' +local log = require 'brave.log' + +---@class jsonrpc +local m = {} +m.type = 'jsonrpc' + +function m.encode(pack) + pack.jsonrpc = '2.0' + local content = json.encode(pack) + local buf = ('Content-Length: %d\r\n\r\n%s'):format(#content, content) + return buf +end + +local function readProtoHead(reader) + local head = {} + while true do + local line = reader 'L' + if line == nil then + -- 说明管道已经关闭了 + return nil, 'Disconnected!' + end + if line == '\r\n' then + break + end + local k, v = line:match '^([^:]+)%s*%:%s*(.+)\r\n$' + if not k then + return nil, 'Proto header error: ' .. line + end + if k == 'Content-Length' then + v = tonumber(v) + end + head[k] = v + end + return head +end + +function m.decode(reader, errHandle) + local head, err = readProtoHead(reader, errHandle) + if not head then + return nil, err + end + local len = head['Content-Length'] + if not len then + return nil, 'Proto header error: ' .. util.dump(head) + end + local content = reader(len) + if not content then + return nil, 'Proto read error' + end + local null = json.null + json.null = nil + local suc, res = pcall(json.decode, content) + json.null = null + if not suc then + return nil, 'Proto parse error: ' .. res + end + return res +end + +return m diff --git a/script/language.lua b/script/language.lua new file mode 100644 index 00000000..6077e4c6 --- /dev/null +++ b/script/language.lua @@ -0,0 +1,140 @@ +local fs = require 'bee.filesystem' +local lni = require 'lni' +local util = require 'utility' + +local function supportLanguage() + local list = {} + for path in (ROOT / 'locale'):list_directory() do + if fs.is_directory(path) then + list[#list+1] = path:filename():string():lower() + end + end + return list +end + +local function osLanguage() + return LANG:lower() +end + +local function getLanguage(id) + local support = supportLanguage() + -- 检查是否支持语言 + if support[id] then + return id + end + -- 根据语言的前2个字母来找近似语言 + for _, lang in ipairs(support) do + if lang:sub(1, 2) == id:sub(1, 2) then + return lang + end + end + -- 使用英文 + return 'enUS' +end + +local function loadFileByLanguage(name, language) + local path = ROOT / 'locale' / language / (name .. '.lni') + local buf = util.loadFile(path:string()) + if not buf then + return {} + end + local suc, tbl = xpcall(lni, log.error, buf, path:string()) + if not suc then + return {} + end + return tbl +end + +local function formatAsArray(str, ...) + local index = 0 + local args = {...} + return str:gsub('%{(.-)%}', function (pat) + local id, fmt + local pos = pat:find(':', 1, true) + if pos then + id = pat:sub(1, pos-1) + fmt = pat:sub(pos+1) + else + id = pat + fmt = 's' + end + id = tonumber(id) + if not id then + index = index + 1 + id = index + end + return ('%'..fmt):format(args[id]) + end) +end + +local function formatAsTable(str, ...) + local args = ... + return str:gsub('%{(.-)%}', function (pat) + local id, fmt + local pos = pat:find(':', 1, true) + if pos then + id = pat:sub(1, pos-1) + fmt = pat:sub(pos+1) + else + id = pat + fmt = 's' + end + if not id then + return + end + return ('%'..fmt):format(args[id]) + end) +end + +local function loadLang(name, language) + local tbl = loadFileByLanguage(name, 'en-US') + if language ~= 'en-US' then + local other = loadFileByLanguage(name, language) + for k, v in pairs(other) do + tbl[k] = v + end + end + return setmetatable(tbl, { + __index = function (self, key) + self[key] = key + return key + end, + __call = function (self, key, ...) + local str = self[key] + if not ... then + return str + end + local suc, res + if type(...) == 'table' then + suc, res = pcall(formatAsTable, str, ...) + else + suc, res = pcall(formatAsArray, str, ...) + end + if suc then + return res + else + -- 这里不能使用翻译,以免死循环 + log.warn(('[%s][%s-%s] formated error: %s'):format( + language, name, key, str + )) + return str + end + end, + }) +end + +local function init() + local id = osLanguage() + local language = getLanguage(id) + log.info(('VSC language: %s'):format(id)) + log.info(('LS language: %s'):format(language)) + return setmetatable({ id = language }, { + __index = function (self, name) + local tbl = loadLang(name, language) + self[name] = tbl + return tbl + end, + }) +end + +return init() diff --git a/script/library.lua b/script/library.lua new file mode 100644 index 00000000..5a48499b --- /dev/null +++ b/script/library.lua @@ -0,0 +1,205 @@ +local lni = require 'lni' +local fs = require 'bee.filesystem' +local config = require 'config' +local util = require 'utility' +local lang = require 'language' +local client = require 'provider.client' + +local m = {} + +local function getDocFormater() + local version = config.config.runtime.version + if client.client() == 'vscode' then + if version == 'Lua 5.1' then + return 'HOVER_NATIVE_DOCUMENT_LUA51' + elseif version == 'Lua 5.2' then + return 'HOVER_NATIVE_DOCUMENT_LUA52' + elseif version == 'Lua 5.3' then + return 'HOVER_NATIVE_DOCUMENT_LUA53' + elseif version == 'Lua 5.4' then + return 'HOVER_NATIVE_DOCUMENT_LUA54' + elseif version == 'LuaJIT' then + return 'HOVER_NATIVE_DOCUMENT_LUAJIT' + end + else + if version == 'Lua 5.1' then + return 'HOVER_DOCUMENT_LUA51' + elseif version == 'Lua 5.2' then + return 'HOVER_DOCUMENT_LUA52' + elseif version == 'Lua 5.3' then + return 'HOVER_DOCUMENT_LUA53' + elseif version == 'Lua 5.4' then + return 'HOVER_DOCUMENT_LUA54' + elseif version == 'LuaJIT' then + return 'HOVER_DOCUMENT_LUAJIT' + end + end +end + +local function convertLink(text) + local fmt = getDocFormater() + return text:gsub('%$([%.%w]+)', function (name) + if fmt then + return ('[%s](%s)'):format(name, lang.script(fmt, 'pdf-' .. name)) + else + return ('`%s`'):format(name) + end + end):gsub('§([%.%w]+)', function (name) + if fmt then + return ('[§%s](%s)'):format(name, lang.script(fmt, name)) + else + return ('`%s`'):format(name) + end + end) +end + +local function createViewDocument(name) + local fmt = getDocFormater() + if not fmt then + return nil + end + return ('[%s](%s)'):format(lang.script.HOVER_VIEW_DOCUMENTS, lang.script(fmt, 'pdf-' .. name)) +end + +local function compileSingleMetaDoc(script, metaLang) + local middleBuf = {} + local compileBuf = {} + + local last = 1 + for start, lua, finish in script:gmatch '()%-%-%-%#([^\n\r]*)()' do + middleBuf[#middleBuf+1] = ('PUSH [===[%s]===]'):format(script:sub(last, start - 1)) + middleBuf[#middleBuf+1] = lua + last = finish + end + middleBuf[#middleBuf+1] = ('PUSH [===[%s]===]'):format(script:sub(last)) + local middleScript = table.concat(middleBuf, '\n') + local version, jit + if config.config.runtime.version == 'LuaJIT' then + version = 5.1 + jit = true + else + version = tonumber(config.config.runtime.version:sub(-3)) + jit = false + end + + local env = setmetatable({ + VERSION = version, + JIT = jit, + PUSH = function (text) + compileBuf[#compileBuf+1] = text + end, + DES = function (name) + local des = metaLang[name] + if not des then + des = ('Miss locale <%s>'):format(name) + end + compileBuf[#compileBuf+1] = '---\n' + for line in util.eachLine(des) do + compileBuf[#compileBuf+1] = '---' + compileBuf[#compileBuf+1] = convertLink(line) + compileBuf[#compileBuf+1] = '\n' + end + local viewDocument = createViewDocument(name) + if viewDocument then + compileBuf[#compileBuf+1] = '---\n---' + compileBuf[#compileBuf+1] = viewDocument + compileBuf[#compileBuf+1] = '\n' + end + compileBuf[#compileBuf+1] = '---\n' + end, + DESENUM = function (name) + local des = metaLang[name] + if not des then + des = ('Miss locale <%s>'):format(name) + end + compileBuf[#compileBuf+1] = convertLink(des) + compileBuf[#compileBuf+1] = '\n' + end, + ALIVE = function (str) + local isAlive + for piece in str:gmatch '[^%,]+' do + if piece:sub(1, 1) == '>' then + local alive = tonumber(piece:sub(2)) + if not alive or version >= alive then + isAlive = true + break + end + elseif piece:sub(1, 1) == '<' then + local alive = tonumber(piece:sub(2)) + if not alive or version <= alive then + isAlive = true + break + end + else + local alive = tonumber(piece) + if not alive or version == alive then + isAlive = true + break + end + end + end + if not isAlive then + compileBuf[#compileBuf+1] = '---@deprecated\n' + end + end, + }, { __index = _ENV }) + + util.saveFile((ROOT / 'log' / 'middleScript.lua'):string(), middleScript) + + assert(load(middleScript, middleScript, 't', env))() + return table.concat(compileBuf) +end + +local function loadMetaLocale(langID, result) + result = result or {} + local path = (ROOT / 'locale' / langID / 'meta.lni'):string() + local lniContent = util.loadFile(path) + if lniContent then + xpcall(lni, log.error, lniContent, path, {result}) + end + return result +end + +local function compileMetaDoc() + local langID = lang.id + local version = config.config.runtime.version + local metapath = ROOT / 'meta' / config.config.runtime.meta:gsub('%$%{(.-)%}', { + version = version, + language = langID, + }) + if fs.exists(metapath) then + --return + end + + local metaLang = loadMetaLocale('en-US') + if langID ~= 'en-US' then + loadMetaLocale(langID, metaLang) + end + --log.debug('metaLang:', util.dump(metaLang)) + + m.metaPath = metapath:string() + m.metaPaths = {} + fs.create_directory(metapath) + local templateDir = ROOT / 'meta' / 'template' + for fullpath in templateDir:list_directory() do + local filename = fullpath:filename() + local metaDoc = compileSingleMetaDoc(util.loadFile(fullpath:string()), metaLang) + local filepath = metapath / filename + util.saveFile(filepath:string(), metaDoc) + m.metaPaths[#m.metaPaths+1] = filepath:string() + end +end + +local function initFromMetaDoc() + compileMetaDoc() +end + +local function init() + initFromMetaDoc() +end + +function m.init() + init() +end + +return m diff --git a/script/log.lua b/script/log.lua new file mode 100644 index 00000000..169fed7f --- /dev/null +++ b/script/log.lua @@ -0,0 +1,142 @@ +local fs = require 'bee.filesystem' + +local osTime = os.time +local osClock = os.clock +local osDate = os.date +local ioOpen = io.open +local tablePack = table.pack +local tableConcat = table.concat +local tostring = tostring +local debugTraceBack = debug.traceback +local mathModf = math.modf +local debugGetInfo = debug.getinfo +local ioStdErr = io.stderr + +_ENV = nil + +local m = {} + +m.file = nil +m.startTime = osTime() - osClock() +m.size = 0 +m.maxSize = 100 * 1024 * 1024 + +local function trimSrc(src) + src = src:sub(m.prefixLen + 3, -5) + src = src:gsub('^[/\\]+', '') + src = src:gsub('[\\/]+', '.') + return src +end + +local function init_log_file() + if not m.file then + m.file = ioOpen(m.path, 'w') + if not m.file then + return + end + m.file:write('') + m.file:close() + m.file = ioOpen(m.path, 'ab') + if not m.file then + return + end + m.file:setvbuf 'no' + end +end + +local function pushLog(level, ...) + if not m.path then + return + end + local t = tablePack(...) + for i = 1, t.n do + t[i] = tostring(t[i]) + end + local str = tableConcat(t, '\t', 1, t.n) + if level == 'error' then + str = str .. '\n' .. debugTraceBack(nil, 3) + end + local info = debugGetInfo(3, 'Sl') + return m.raw(0, level, str, info.source, info.currentline, osClock()) +end + +function m.info(...) + pushLog('info', ...) +end + +function m.debug(...) + pushLog('debug', ...) +end + +function m.trace(...) + pushLog('trace', ...) +end + +function m.warn(...) + pushLog('warn', ...) +end + +function m.error(...) + return pushLog('error', ...) +end + +function m.raw(thd, level, msg, source, currentline, clock) + if level == 'error' then + ioStdErr:write(msg .. '\n') + end + if m.size > m.maxSize then + return + end + init_log_file() + if not m.file then + return '' + end + local sec, ms = mathModf(m.startTime + clock) + local timestr = osDate('%H:%M:%S', sec) + local agl = '' + if #level < 5 then + agl = (' '):rep(5 - #level) + end + local buf + if currentline == -1 then + buf = ('[%s.%03.f][%s]%s[#%d]: %s\n'):format(timestr, ms * 1000, level, agl, thd, msg) + else + buf = ('[%s.%03.f][%s]%s[#%d:%s:%s]: %s\n'):format(timestr, ms * 1000, level, agl, thd, trimSrc(source), currentline, msg) + end + m.size = m.size + #buf + if m.size > m.maxSize then + m.file:write(buf:sub(1, m.size - m.maxSize)) + m.file:write('[REACH MAX SIZE]') + else + m.file:write(buf) + end + return buf +end + +function m.init(root, path) + local lastBuf + if m.file then + m.file:close() + m.file = nil + local file = ioOpen(m.path, 'rb') + if file then + lastBuf = file:read(m.maxSize) + file:close() + end + end + m.path = path:string() + m.prefixLen = #root:string() + m.size = 0 + if not fs.exists(path:parent_path()) then + fs.create_directories(path:parent_path()) + end + if lastBuf then + init_log_file() + if m.file then + m.file:write(lastBuf) + m.size = m.size + #lastBuf + end + end +end + +return m diff --git a/script/parser/ast.lua b/script/parser/ast.lua new file mode 100644 index 00000000..d8614eae --- /dev/null +++ b/script/parser/ast.lua @@ -0,0 +1,1751 @@ +local tonumber = tonumber +local stringChar = string.char +local utf8Char = utf8.char +local tableUnpack = table.unpack +local mathType = math.type +local tableRemove = table.remove +local pairs = pairs +local tableSort = table.sort + +_ENV = nil + +local State +local PushError +local PushDiag +local PushComment + +-- goto 单独处理 +local RESERVED = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +local VersionOp = { + ['&'] = {'Lua 5.3', 'Lua 5.4'}, + ['~'] = {'Lua 5.3', 'Lua 5.4'}, + ['|'] = {'Lua 5.3', 'Lua 5.4'}, + ['<<'] = {'Lua 5.3', 'Lua 5.4'}, + ['>>'] = {'Lua 5.3', 'Lua 5.4'}, + ['//'] = {'Lua 5.3', 'Lua 5.4'}, +} + +local function checkOpVersion(op) + local versions = VersionOp[op.type] + if not versions then + return + end + for i = 1, #versions do + if versions[i] == State.version then + return + end + end + PushError { + type = 'UNSUPPORT_SYMBOL', + start = op.start, + finish = op.finish, + version = versions, + info = { + version = State.version, + } + } +end + +local function checkMissEnd(start) + if not State.MissEndErr then + return + end + local err = State.MissEndErr + State.MissEndErr = nil + local _, finish = State.lua:find('[%w_]+', start) + if not finish then + return + end + err.info.related = { + { + start = start, + finish = finish, + } + } + PushError { + type = 'MISS_END', + start = start, + finish = finish, + } +end + +local function getSelect(vararg, index) + return { + type = 'select', + start = vararg.start, + finish = vararg.finish, + vararg = vararg, + index = index, + } +end + +local function getValue(values, i) + if not values then + return nil, nil + end + local value = values[i] + if not value then + local last = values[#values] + if not last then + return nil, nil + end + if last.type == 'call' or last.type == 'varargs' then + return getSelect(last, i - #values + 1) + end + return nil, nil + end + if value.type == 'call' or value.type == 'varargs' then + value = getSelect(value, 1) + end + return value +end + +local function createLocal(key, effect, value, attrs) + if not key then + return nil + end + key.type = 'local' + key.effect = effect + key.value = value + key.attrs = attrs + if value then + key.range = value.finish + end + return key +end + +local function createCall(args, start, finish) + if args then + args.type = 'callargs' + args.start = start + args.finish = finish + end + return { + type = 'call', + start = start, + finish = finish, + args = args, + } +end + +local function packList(start, list, finish) + local lastFinish = start + local wantName = true + local count = 0 + for i = 1, #list do + local ast = list[i] + if ast.type == ',' then + if wantName or i == #list then + PushError { + type = 'UNEXPECT_SYMBOL', + start = ast.start, + finish = ast.finish, + info = { + symbol = ',', + } + } + end + wantName = true + else + if not wantName then + PushError { + type = 'MISS_SYMBOL', + start = lastFinish, + finish = ast.start - 1, + info = { + symbol = ',', + } + } + end + wantName = false + count = count + 1 + list[count] = list[i] + end + lastFinish = ast.finish + 1 + end + for i = count + 1, #list do + list[i] = nil + end + list.type = 'list' + list.start = start + list.finish = finish - 1 + return list +end + +local BinaryLevel = { + ['or'] = 1, + ['and'] = 2, + ['<='] = 3, + ['>='] = 3, + ['<'] = 3, + ['>'] = 3, + ['~='] = 3, + ['=='] = 3, + ['|'] = 4, + ['~'] = 5, + ['&'] = 6, + ['<<'] = 7, + ['>>'] = 7, + ['..'] = 8, + ['+'] = 9, + ['-'] = 9, + ['*'] = 10, + ['//'] = 10, + ['/'] = 10, + ['%'] = 10, + ['^'] = 11, +} + +local BinaryForward = { + [01] = true, + [02] = true, + [03] = true, + [04] = true, + [05] = true, + [06] = true, + [07] = true, + [08] = false, + [09] = true, + [10] = true, + [11] = false, +} + +local Defs = { + Nil = function (pos) + return { + type = 'nil', + start = pos, + finish = pos + 2, + } + end, + True = function (pos) + return { + type = 'boolean', + start = pos, + finish = pos + 3, + [1] = true, + } + end, + False = function (pos) + return { + type = 'boolean', + start = pos, + finish = pos + 4, + [1] = false, + } + end, + ShortComment = function (start, text, finish) + PushComment { + start = start, + finish = finish - 1, + text = text, + } + end, + LongComment = function (beforeEq, afterEq, str, missPos) + if missPos then + local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']' + local s, _, w = str:find('(%][%=]*%])[%c%s]*$') + if s then + PushError { + type = 'ERR_LCOMMENT_END', + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + info = { + symbol = endSymbol, + }, + fix = { + title = 'FIX_LCOMMENT_END', + { + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + text = endSymbol, + } + }, + } + end + PushError { + type = 'MISS_SYMBOL', + start = missPos, + finish = missPos, + info = { + symbol = endSymbol, + }, + fix = { + title = 'ADD_LCOMMENT_END', + { + start = missPos, + finish = missPos, + text = endSymbol, + } + }, + } + end + end, + CLongComment = function (start1, finish1, start2, finish2) + PushError { + type = 'ERR_C_LONG_COMMENT', + start = start1, + finish = finish2 - 1, + fix = { + title = 'FIX_C_LONG_COMMENT', + { + start = start1, + finish = finish1 - 1, + text = '--[[', + }, + { + start = start2, + finish = finish2 - 1, + text = '--]]' + }, + } + } + end, + CCommentPrefix = function (start, finish) + PushError { + type = 'ERR_COMMENT_PREFIX', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_COMMENT_PREFIX', + { + start = start, + finish = finish - 1, + text = '--', + }, + } + } + end, + String = function (start, quote, str, finish) + return { + type = 'string', + start = start, + finish = finish - 1, + [1] = str, + [2] = quote, + } + end, + LongString = function (beforeEq, afterEq, str, missPos) + if missPos then + local endSymbol = ']' .. ('='):rep(afterEq-beforeEq) .. ']' + local s, _, w = str:find('(%][%=]*%])[%c%s]*$') + if s then + PushError { + type = 'ERR_LSTRING_END', + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + info = { + symbol = endSymbol, + }, + fix = { + title = 'FIX_LSTRING_END', + { + start = missPos - #str + s - 1, + finish = missPos - #str + s + #w - 2, + text = endSymbol, + } + }, + } + end + PushError { + type = 'MISS_SYMBOL', + start = missPos, + finish = missPos, + info = { + symbol = endSymbol, + }, + fix = { + title = 'ADD_LSTRING_END', + { + start = missPos, + finish = missPos, + text = endSymbol, + } + }, + } + end + return '[' .. ('='):rep(afterEq-beforeEq) .. '[', str + end, + Char10 = function (char) + char = tonumber(char) + if not char or char < 0 or char > 255 then + return '' + end + return stringChar(char) + end, + Char16 = function (pos, char) + if State.version == 'Lua 5.1' then + PushError { + type = 'ERR_ESC', + start = pos-1, + finish = pos, + version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return char + end + return stringChar(tonumber(char, 16)) + end, + CharUtf8 = function (pos, char) + if State.version ~= 'Lua 5.3' + and State.version ~= 'Lua 5.4' + and State.version ~= 'LuaJIT' + then + PushError { + type = 'ERR_ESC', + start = pos-3, + finish = pos-2, + version = {'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return char + end + if #char == 0 then + PushError { + type = 'UTF8_SMALL', + start = pos-3, + finish = pos, + } + return '' + end + local v = tonumber(char, 16) + if not v then + for i = 1, #char do + if not tonumber(char:sub(i, i), 16) then + PushError { + type = 'MUST_X16', + start = pos + i - 1, + finish = pos + i - 1, + } + end + end + return '' + end + if State.version == 'Lua 5.4' then + if v < 0 or v > 0x7FFFFFFF then + PushError { + type = 'UTF8_MAX', + start = pos-3, + finish = pos+#char, + info = { + min = '00000000', + max = '7FFFFFFF', + } + } + end + else + if v < 0 or v > 0x10FFFF then + PushError { + type = 'UTF8_MAX', + start = pos-3, + finish = pos+#char, + version = v <= 0x7FFFFFFF and 'Lua 5.4' or nil, + info = { + min = '000000', + max = '10FFFF', + } + } + end + end + if v >= 0 and v <= 0x10FFFF then + return utf8Char(v) + end + return '' + end, + Number = function (start, number, finish) + local n = tonumber(number) + if n then + State.LastNumber = { + type = 'number', + start = start, + finish = finish - 1, + [1] = n, + } + return State.LastNumber + else + PushError { + type = 'MALFORMED_NUMBER', + start = start, + finish = finish - 1, + } + State.LastNumber = { + type = 'number', + start = start, + finish = finish - 1, + [1] = 0, + } + return State.LastNumber + end + end, + FFINumber = function (start, symbol) + local lastNumber = State.LastNumber + if mathType(lastNumber[1]) == 'float' then + PushError { + type = 'UNKNOWN_SYMBOL', + start = start, + finish = start + #symbol - 1, + info = { + symbol = symbol, + } + } + lastNumber[1] = 0 + return + end + if State.version ~= 'LuaJIT' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = start + #symbol - 1, + version = 'LuaJIT', + info = { + version = State.version, + } + } + lastNumber[1] = 0 + end + end, + ImaginaryNumber = function (start, symbol) + local lastNumber = State.LastNumber + if State.version ~= 'LuaJIT' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = start + #symbol - 1, + version = 'LuaJIT', + info = { + version = State.version, + } + } + end + lastNumber[1] = 0 + end, + Name = function (start, str, finish) + local isKeyWord + if RESERVED[str] then + isKeyWord = true + elseif str == 'goto' then + if State.version ~= 'Lua 5.1' and State.version ~= 'LuaJIT' then + isKeyWord = true + end + end + if isKeyWord then + PushError { + type = 'KEYWORD', + start = start, + finish = finish - 1, + } + end + return { + type = 'name', + start = start, + finish = finish - 1, + [1] = str, + } + end, + GetField = function (dot, field) + local obj = { + type = 'getfield', + field = field, + dot = dot, + start = dot.start, + finish = (field or dot).finish, + } + if field then + field.type = 'field' + field.parent = obj + end + return obj + end, + GetIndex = function (start, index, finish) + local obj = { + type = 'getindex', + start = start, + finish = finish - 1, + index = index, + } + if index then + index.parent = obj + end + return obj + end, + GetMethod = function (colon, method) + local obj = { + type = 'getmethod', + method = method, + colon = colon, + start = colon.start, + finish = (method or colon).finish, + } + if method then + method.type = 'method' + method.parent = obj + end + return obj + end, + Single = function (unit) + unit.type = 'getname' + return unit + end, + Simple = function (units) + local last = units[1] + for i = 2, #units do + local current = units[i] + current.node = last + current.start = last.start + last.next = current + last = units[i] + end + return last + end, + SimpleCall = function (call) + if call.type ~= 'call' and call.type ~= 'getmethod' then + PushError { + type = 'EXP_IN_ACTION', + start = call.start, + finish = call.finish, + } + end + return call + end, + BinaryOp = function (start, op) + return { + type = op, + start = start, + finish = start + #op - 1, + } + end, + UnaryOp = function (start, op) + return { + type = op, + start = start, + finish = start + #op - 1, + } + end, + Unary = function (first, ...) + if not ... then + return nil + end + local list = {first, ...} + local e = list[#list] + for i = #list - 1, 1, -1 do + local op = list[i] + checkOpVersion(op) + e = { + type = 'unary', + op = op, + start = op.start, + finish = e.finish, + [1] = e, + } + end + return e + end, + SubBinary = function (op, symb) + if symb then + return op, symb + end + PushError { + type = 'MISS_EXP', + start = op.start, + finish = op.finish, + } + end, + Binary = function (first, op, second, ...) + if not first then + return second + end + if not op then + return first + end + if not ... then + checkOpVersion(op) + return { + type = 'binary', + op = op, + start = first.start, + finish = second.finish, + [1] = first, + [2] = second, + } + end + local list = {first, op, second, ...} + local ops = {} + for i = 2, #list, 2 do + ops[#ops+1] = i + end + tableSort(ops, function (a, b) + local op1 = list[a] + local op2 = list[b] + local lv1 = BinaryLevel[op1.type] + local lv2 = BinaryLevel[op2.type] + if lv1 == lv2 then + local forward = BinaryForward[lv1] + if forward then + return op1.start > op2.start + else + return op1.start < op2.start + end + else + return lv1 < lv2 + end + end) + local final + for i = #ops, 1, -1 do + local n = ops[i] + local op = list[n] + local left = list[n-1] + local right = list[n+1] + local exp = { + type = 'binary', + op = op, + start = left.start, + finish = right and right.finish or op.finish, + [1] = left, + [2] = right, + } + local leftIndex, rightIndex + if list[left] then + leftIndex = list[left[1]] + else + leftIndex = n - 1 + end + if list[right] then + rightIndex = list[right[2]] + else + rightIndex = n + 1 + end + + list[leftIndex] = exp + list[rightIndex] = exp + list[left] = leftIndex + list[right] = rightIndex + list[exp] = n + final = exp + + checkOpVersion(op) + end + return final + end, + Paren = function (start, exp, finish) + if exp and exp.type == 'paren' then + exp.start = start + exp.finish = finish - 1 + return exp + end + return { + type = 'paren', + start = start, + finish = finish - 1, + exp = exp + } + end, + VarArgs = function (dots) + dots.type = 'varargs' + return dots + end, + PackLoopArgs = function (start, list, finish) + local list = packList(start, list, finish) + if #list == 0 then + PushError { + type = 'MISS_LOOP_MIN', + start = finish, + finish = finish, + } + elseif #list == 1 then + PushError { + type = 'MISS_LOOP_MAX', + start = finish, + finish = finish, + } + end + return list + end, + PackInNameList = function (start, list, finish) + local list = packList(start, list, finish) + if #list == 0 then + PushError { + type = 'MISS_NAME', + start = start, + finish = finish, + } + end + return list + end, + PackInExpList = function (start, list, finish) + local list = packList(start, list, finish) + if #list == 0 then + PushError { + type = 'MISS_EXP', + start = start, + finish = finish, + } + end + return list + end, + PackExpList = function (start, list, finish) + local list = packList(start, list, finish) + return list + end, + PackNameList = function (start, list, finish) + local list = packList(start, list, finish) + return list + end, + Call = function (start, args, finish) + return createCall(args, start, finish-1) + end, + COMMA = function (start) + return { + type = ',', + start = start, + finish = start, + } + end, + SEMICOLON = function (start) + return { + type = ';', + start = start, + finish = start, + } + end, + DOTS = function (start) + return { + type = '...', + start = start, + finish = start + 2, + } + end, + COLON = function (start) + return { + type = ':', + start = start, + finish = start, + } + end, + DOT = function (start) + return { + type = '.', + start = start, + finish = start, + } + end, + Function = function (functionStart, functionFinish, args, actions, endStart, endFinish) + actions.type = 'function' + actions.start = functionStart + actions.finish = endFinish - 1 + actions.args = args + actions.keyword= { + functionStart, functionFinish - 1, + endStart, endFinish - 1, + } + checkMissEnd(functionStart) + return actions + end, + NamedFunction = function (functionStart, functionFinish, name, args, actions, endStart, endFinish) + actions.type = 'function' + actions.start = functionStart + actions.finish = endFinish - 1 + actions.args = args + actions.keyword= { + functionStart, functionFinish - 1, + endStart, endFinish - 1, + } + checkMissEnd(functionStart) + if not name then + return + end + if name.type == 'getname' then + name.type = 'setname' + name.value = actions + elseif name.type == 'getfield' then + name.type = 'setfield' + name.value = actions + elseif name.type == 'getmethod' then + name.type = 'setmethod' + name.value = actions + end + name.range = actions.finish + name.vstart = functionStart + return name + end, + LocalFunction = function (start, functionStart, functionFinish, name, args, actions, endStart, endFinish) + actions.type = 'function' + actions.start = start + actions.finish = endFinish - 1 + actions.args = args + actions.keyword= { + functionStart, functionFinish - 1, + endStart, endFinish - 1, + } + checkMissEnd(start) + + if not name then + return + end + + if name.type ~= 'getname' then + PushError { + type = 'UNEXPECT_LFUNC_NAME', + start = name.start, + finish = name.finish, + } + return + end + + local loc = createLocal(name, name.start, actions) + loc.localfunction = true + loc.vstart = functionStart + + return loc + end, + Table = function (start, tbl, finish) + tbl.type = 'table' + tbl.start = start + tbl.finish = finish - 1 + local wantField = true + local lastStart = start + 1 + local fieldCount = 0 + for i = 1, #tbl do + local field = tbl[i] + if field.type == ',' or field.type == ';' then + if wantField then + PushError { + type = 'MISS_EXP', + start = lastStart, + finish = field.start - 1, + } + end + wantField = true + lastStart = field.finish + 1 + else + if not wantField then + PushError { + type = 'MISS_SEP_IN_TABLE', + start = lastStart, + finish = field.start - 1, + } + end + wantField = false + lastStart = field.finish + 1 + fieldCount = fieldCount + 1 + tbl[fieldCount] = field + end + end + for i = fieldCount + 1, #tbl do + tbl[i] = nil + end + return tbl + end, + NewField = function (start, field, value, finish) + local obj = { + type = 'tablefield', + start = start, + finish = finish-1, + field = field, + value = value, + } + if field then + field.type = 'field' + field.parent = obj + end + return obj + end, + NewIndex = function (start, index, value, finish) + local obj = { + type = 'tableindex', + start = start, + finish = finish-1, + index = index, + value = value, + } + if index then + index.parent = obj + end + return obj + end, + FuncArgs = function (start, args, finish) + args.type = 'funcargs' + args.start = start + args.finish = finish - 1 + local lastStart = start + 1 + local wantName = true + local argCount = 0 + for i = 1, #args do + local arg = args[i] + local argAst = arg + if argAst.type == ',' then + if wantName then + PushError { + type = 'MISS_NAME', + start = lastStart, + finish = argAst.start-1, + } + end + wantName = true + else + if not wantName then + PushError { + type = 'MISS_SYMBOL', + start = lastStart-1, + finish = argAst.start-1, + info = { + symbol = ',', + } + } + end + wantName = false + argCount = argCount + 1 + + if argAst.type == '...' then + args[argCount] = arg + if i < #args then + local a = args[i+1] + local b = args[#args] + PushError { + type = 'ARGS_AFTER_DOTS', + start = a.start, + finish = b.finish, + } + end + break + else + args[argCount] = createLocal(arg, arg.start) + end + end + lastStart = argAst.finish + 1 + end + for i = argCount + 1, #args do + args[i] = nil + end + if wantName and argCount > 0 then + PushError { + type = 'MISS_NAME', + start = lastStart, + finish = finish - 1, + } + end + return args + end, + Set = function (start, keys, values, finish) + for i = 1, #keys do + local key = keys[i] + if key.type == 'getname' then + key.type = 'setname' + key.value = getValue(values, i) + elseif key.type == 'getfield' then + key.type = 'setfield' + key.value = getValue(values, i) + elseif key.type == 'getindex' then + key.type = 'setindex' + key.value = getValue(values, i) + end + if key.value then + key.range = key.value.finish + end + end + if values then + for i = #keys+1, #values do + local value = values[i] + PushDiag('redundant-value', { + start = value.start, + finish = value.finish, + max = #keys, + passed = #values, + }) + end + end + return tableUnpack(keys) + end, + LocalAttr = function (attrs) + if #attrs == 0 then + return nil + end + for i = 1, #attrs do + local attr = attrs[i] + local attrAst = attr + attrAst.type = 'localattr' + if State.version ~= 'Lua 5.4' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = attrAst.start, + finish = attrAst.finish, + version = 'Lua 5.4', + info = { + version = State.version, + } + } + elseif attrAst[1] ~= 'const' and attrAst[1] ~= 'close' then + PushError { + type = 'UNKNOWN_TAG', + start = attrAst.start, + finish = attrAst.finish, + info = { + tag = attrAst[1], + } + } + elseif i > 1 then + PushError { + type = 'MULTI_TAG', + start = attrAst.start, + finish = attrAst.finish, + info = { + tag = attrAst[1], + } + } + end + end + attrs.start = attrs[1].start + attrs.finish = attrs[#attrs].finish + return attrs + end, + LocalName = function (name, attrs) + if not name then + return name + end + name.attrs = attrs + return name + end, + Local = function (start, keys, values, finish) + for i = 1, #keys do + local key = keys[i] + local attrs = key.attrs + key.attrs = nil + local value = getValue(values, i) + createLocal(key, finish, value, attrs) + end + if values then + for i = #keys+1, #values do + local value = values[i] + PushDiag('redundant-value', { + start = value.start, + finish = value.finish, + max = #keys, + passed = #values, + }) + end + end + return tableUnpack(keys) + end, + Do = function (start, actions, endA, endB) + actions.type = 'do' + actions.start = start + actions.finish = endB - 1 + actions.keyword= { + start, start + #'do' - 1, + endA , endB - 1, + } + checkMissEnd(start) + return actions + end, + Break = function (start, finish) + return { + type = 'break', + start = start, + finish = finish - 1, + } + end, + Return = function (start, exps, finish) + exps.type = 'return' + exps.start = start + exps.finish = finish - 1 + return exps + end, + Label = function (start, name, finish) + if State.version == 'Lua 5.1' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = finish - 1, + version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return + end + if not name then + return nil + end + name.type = 'label' + return name + end, + GoTo = function (start, name, finish) + if State.version == 'Lua 5.1' then + PushError { + type = 'UNSUPPORT_SYMBOL', + start = start, + finish = finish - 1, + version = {'Lua 5.2', 'Lua 5.3', 'Lua 5.4', 'LuaJIT'}, + info = { + version = State.version, + } + } + return + end + if not name then + return nil + end + name.type = 'goto' + return name + end, + IfBlock = function (ifStart, ifFinish, exp, thenStart, thenFinish, actions, finish) + actions.type = 'ifblock' + actions.start = ifStart + actions.finish = finish - 1 + actions.filter = exp + actions.keyword= { + ifStart, ifFinish - 1, + thenStart, thenFinish - 1, + } + return actions + end, + ElseIfBlock = function (elseifStart, elseifFinish, exp, thenStart, thenFinish, actions, finish) + actions.type = 'elseifblock' + actions.start = elseifStart + actions.finish = finish - 1 + actions.filter = exp + actions.keyword= { + elseifStart, elseifFinish - 1, + thenStart, thenFinish - 1, + } + return actions + end, + ElseBlock = function (elseStart, elseFinish, actions, finish) + actions.type = 'elseblock' + actions.start = elseStart + actions.finish = finish - 1 + actions.keyword= { + elseStart, elseFinish - 1, + } + return actions + end, + If = function (start, blocks, endStart, endFinish) + blocks.type = 'if' + blocks.start = start + blocks.finish = endFinish - 1 + local hasElse + for i = 1, #blocks do + local block = blocks[i] + if i == 1 and block.type ~= 'ifblock' then + PushError { + type = 'MISS_SYMBOL', + start = block.start, + finish = block.start, + info = { + symbol = 'if', + } + } + end + if hasElse then + PushError { + type = 'BLOCK_AFTER_ELSE', + start = block.start, + finish = block.finish, + } + end + if block.type == 'elseblock' then + hasElse = true + end + end + checkMissEnd(start) + return blocks + end, + Loop = function (forA, forB, arg, steps, doA, doB, blockStart, block, endA, endB) + local loc = createLocal(arg, blockStart, steps[1]) + block.type = 'loop' + block.start = forA + block.finish = endB - 1 + block.loc = loc + block.max = steps[2] + block.step = steps[3] + block.keyword= { + forA, forB - 1, + doA , doB - 1, + endA, endB - 1, + } + checkMissEnd(forA) + return block + end, + In = function (forA, forB, keys, inA, inB, exp, doA, doB, blockStart, block, endA, endB) + local func = tableRemove(exp, 1) + block.type = 'in' + block.start = forA + block.finish = endB - 1 + block.keys = keys + block.keyword= { + forA, forB - 1, + inA , inB - 1, + doA , doB - 1, + endA, endB - 1, + } + + local values + if func then + local call = createCall(exp, func.finish + 1, exp.finish) + call.node = func + call.start = func.start + func.next = call + values = { call } + keys.range = call.finish + end + for i = 1, #keys do + local loc = keys[i] + if values then + createLocal(loc, blockStart, getValue(values, i)) + else + createLocal(loc, blockStart) + end + end + checkMissEnd(forA) + return block + end, + While = function (whileA, whileB, filter, doA, doB, block, endA, endB) + block.type = 'while' + block.start = whileA + block.finish = endB - 1 + block.filter = filter + block.keyword= { + whileA, whileB - 1, + doA , doB - 1, + endA , endB - 1, + } + checkMissEnd(whileA) + return block + end, + Repeat = function (repeatA, repeatB, block, untilA, untilB, filter, finish) + block.type = 'repeat' + block.start = repeatA + block.finish = finish + block.filter = filter + block.keyword= { + repeatA, repeatB - 1, + untilA , untilB - 1, + } + return block + end, + Lua = function (start, actions, finish) + actions.type = 'main' + actions.start = start + actions.finish = finish - 1 + return actions + end, + + -- 捕获错误 + UnknownSymbol = function (start, symbol) + PushError { + type = 'UNKNOWN_SYMBOL', + start = start, + finish = start + #symbol - 1, + info = { + symbol = symbol, + } + } + return + end, + UnknownAction = function (start, symbol) + PushError { + type = 'UNKNOWN_SYMBOL', + start = start, + finish = start + #symbol - 1, + info = { + symbol = symbol, + } + } + end, + DirtyName = function (pos) + PushError { + type = 'MISS_NAME', + start = pos, + finish = pos, + } + return nil + end, + DirtyExp = function (pos) + PushError { + type = 'MISS_EXP', + start = pos, + finish = pos, + } + return nil + end, + MissExp = function (pos) + PushError { + type = 'MISS_EXP', + start = pos, + finish = pos, + } + end, + MissExponent = function (start, finish) + PushError { + type = 'MISS_EXPONENT', + start = start, + finish = finish - 1, + } + end, + MissQuote1 = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '"' + } + } + end, + MissQuote2 = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = "'" + } + } + end, + MissEscX = function (pos) + PushError { + type = 'MISS_ESC_X', + start = pos-2, + finish = pos+1, + } + end, + MissTL = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '{', + } + } + end, + MissTR = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '}', + } + } + end, + MissBR = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ']', + } + } + end, + MissPL = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '(', + } + } + end, + MissPR = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ')', + } + } + end, + ErrEsc = function (pos) + PushError { + type = 'ERR_ESC', + start = pos-1, + finish = pos, + } + end, + MustX16 = function (pos, str) + PushError { + type = 'MUST_X16', + start = pos, + finish = pos + #str - 1, + } + end, + MissAssign = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '=', + } + } + end, + MissTableSep = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ',' + } + } + end, + MissField = function (pos) + PushError { + type = 'MISS_FIELD', + start = pos, + finish = pos, + } + end, + MissMethod = function (pos) + PushError { + type = 'MISS_METHOD', + start = pos, + finish = pos, + } + end, + MissLabel = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = '::', + } + } + end, + MissEnd = function (pos) + State.MissEndErr = PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'end', + } + } + return pos, pos + end, + MissDo = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'do', + } + } + return pos, pos + end, + MissComma = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = ',', + } + } + end, + MissIn = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'in', + } + } + return pos, pos + end, + MissUntil = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'until', + } + } + return pos, pos + end, + MissThen = function (pos) + PushError { + type = 'MISS_SYMBOL', + start = pos, + finish = pos, + info = { + symbol = 'then', + } + } + return pos, pos + end, + MissName = function (pos) + PushError { + type = 'MISS_NAME', + start = pos, + finish = pos, + } + end, + ExpInAction = function (start, exp, finish) + PushError { + type = 'EXP_IN_ACTION', + start = start, + finish = finish - 1, + } + -- 当exp为nil时,不能返回任何值,否则会产生带洞的actionlist + if exp then + return exp + else + return + end + end, + MissIf = function (start, block) + PushError { + type = 'MISS_SYMBOL', + start = start, + finish = start, + info = { + symbol = 'if', + } + } + return block + end, + MissGT = function (start) + PushError { + type = 'MISS_SYMBOL', + start = start, + finish = start, + info = { + symbol = '>' + } + } + end, + ErrAssign = function (start, finish) + PushError { + type = 'ERR_ASSIGN_AS_EQ', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_ASSIGN_AS_EQ', + { + start = start, + finish = finish - 1, + text = '=', + } + } + } + end, + ErrEQ = function (start, finish) + PushError { + type = 'ERR_EQ_AS_ASSIGN', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_EQ_AS_ASSIGN', + { + start = start, + finish = finish - 1, + text = '==', + } + } + } + return '==' + end, + ErrUEQ = function (start, finish) + PushError { + type = 'ERR_UEQ', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_UEQ', + { + start = start, + finish = finish - 1, + text = '~=', + } + } + } + return '==' + end, + ErrThen = function (start, finish) + PushError { + type = 'ERR_THEN_AS_DO', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_THEN_AS_DO', + { + start = start, + finish = finish - 1, + text = 'then', + } + } + } + return start, finish + end, + ErrDo = function (start, finish) + PushError { + type = 'ERR_DO_AS_THEN', + start = start, + finish = finish - 1, + fix = { + title = 'FIX_DO_AS_THEN', + { + start = start, + finish = finish - 1, + text = 'do', + } + } + } + return start, finish + end, +} + +local function init(state) + State = state + PushError = state.pushError + PushDiag = state.pushDiag + PushComment = state.pushComment +end + +local function close() + State = nil + PushError = function (...) end + PushDiag = function (...) end + PushComment = function (...) end +end + +return { + defs = Defs, + init = init, + close = close, +} diff --git a/script/parser/calcline.lua b/script/parser/calcline.lua new file mode 100644 index 00000000..2e944167 --- /dev/null +++ b/script/parser/calcline.lua @@ -0,0 +1,94 @@ +local m = require 'lpeglabel' +local util = require 'utility' + +local row +local fl +local NL = (m.P'\r\n' + m.S'\r\n') * m.Cp() / function (pos) + row = row + 1 + fl = pos +end +local ROWCOL = (NL + m.P(1))^0 +local function rowcol(str, n) + row = 1 + fl = 1 + ROWCOL:match(str:sub(1, n)) + local col = n - fl + 1 + return row, col +end + +local function rowcol_utf8(str, n) + row = 1 + fl = 1 + ROWCOL:match(str:sub(1, n)) + return row, util.utf8Len(str, fl, n) +end + +local function position(str, _row, _col) + local cur = 1 + local row = 1 + while true do + if row == _row then + return cur + _col - 1 + elseif row > _row then + return cur - 1 + end + local pos = str:find('[\r\n]', cur) + if not pos then + return #str + end + row = row + 1 + if str:sub(pos, pos+1) == '\r\n' then + cur = pos + 2 + else + cur = pos + 1 + end + end +end + +local function position_utf8(str, _row, _col) + local cur = 1 + local row = 1 + while true do + if row == _row then + return utf8.offset(str, _col, cur) + elseif row > _row then + return cur - 1 + end + local pos = str:find('[\r\n]', cur) + if not pos then + return #str + end + row = row + 1 + if str:sub(pos, pos+1) == '\r\n' then + cur = pos + 2 + else + cur = pos + 1 + end + end +end + +local NL = m.P'\r\n' + m.S'\r\n' + +local function line(str, row) + local count = 0 + local res + local LINE = m.Cmt((1 - NL)^0, function (_, _, c) + count = count + 1 + if count == row then + res = c + return false + end + return true + end) + local MATCH = (LINE * NL)^0 * LINE + MATCH:match(str) + return res +end + +return { + rowcol = rowcol, + rowcol_utf8 = rowcol_utf8, + position = position, + position_utf8 = position_utf8, + line = line, +} diff --git a/script/parser/compile.lua b/script/parser/compile.lua new file mode 100644 index 00000000..2c7172e8 --- /dev/null +++ b/script/parser/compile.lua @@ -0,0 +1,561 @@ +local guide = require 'parser.guide' +local type = type + +local specials = { + ['_G'] = true, + ['rawset'] = true, + ['rawget'] = true, + ['setmetatable'] = true, + ['require'] = true, + ['dofile'] = true, + ['loadfile'] = true, + ['pcall'] = true, + ['xpcall'] = true, +} + +_ENV = nil + +local LocalLimit = 200 +local pushError, Compile, CompileBlock, Block, GoToTag, ENVMode, Compiled, LocalCount, Version, Root, Options + +local function addRef(node, obj) + if not node.ref then + node.ref = {} + end + node.ref[#node.ref+1] = obj + obj.node = node +end + +local function addSpecial(name, obj) + if not Root.specials then + Root.specials = {} + end + if not Root.specials[name] then + Root.specials[name] = {} + end + Root.specials[name][#Root.specials[name]+1] = obj + obj.special = name +end + +local vmMap = { + ['getname'] = function (obj) + local loc = guide.getLocal(obj, obj[1], obj.start) + if loc then + obj.type = 'getlocal' + obj.loc = loc + addRef(loc, obj) + if loc.special then + addSpecial(loc.special, obj) + end + else + obj.type = 'getglobal' + local node = guide.getLocal(obj, ENVMode, obj.start) + if node then + addRef(node, obj) + end + local name = obj[1] + if specials[name] then + addSpecial(name, obj) + elseif Options and Options.special then + local asName = Options.special[name] + if specials[asName] then + addSpecial(asName, obj) + end + end + end + return obj + end, + ['getfield'] = function (obj) + Compile(obj.node, obj) + end, + ['call'] = function (obj) + Compile(obj.node, obj) + Compile(obj.args, obj) + end, + ['callargs'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['binary'] = function (obj) + Compile(obj[1], obj) + Compile(obj[2], obj) + end, + ['unary'] = function (obj) + Compile(obj[1], obj) + end, + ['varargs'] = function (obj) + local func = guide.getParentFunction(obj) + if func then + local index, vararg = guide.getFunctionVarArgs(func) + if not index then + pushError { + type = 'UNEXPECT_DOTS', + start = obj.start, + finish = obj.finish, + } + end + if vararg then + if not vararg.ref then + vararg.ref = {} + end + vararg.ref[#vararg.ref+1] = obj + end + end + end, + ['paren'] = function (obj) + Compile(obj.exp, obj) + end, + ['getindex'] = function (obj) + Compile(obj.node, obj) + Compile(obj.index, obj) + end, + ['setindex'] = function (obj) + Compile(obj.node, obj) + Compile(obj.index, obj) + Compile(obj.value, obj) + end, + ['getmethod'] = function (obj) + Compile(obj.node, obj) + Compile(obj.method, obj) + end, + ['setmethod'] = function (obj) + Compile(obj.node, obj) + Compile(obj.method, obj) + local value = obj.value + value.localself = { + type = 'local', + start = 0, + finish = 0, + method = obj, + effect = obj.finish, + tag = 'self', + [1] = 'self', + } + Compile(value, obj) + end, + ['function'] = function (obj) + local lastBlock = Block + local LastLocalCount = LocalCount + Block = obj + LocalCount = 0 + if obj.localself then + Compile(obj.localself, obj) + obj.localself = nil + end + Compile(obj.args, obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + Block = lastBlock + LocalCount = LastLocalCount + end, + ['funcargs'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['table'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['tablefield'] = function (obj) + Compile(obj.value, obj) + end, + ['tableindex'] = function (obj) + Compile(obj.index, obj) + Compile(obj.value, obj) + end, + ['index'] = function (obj) + Compile(obj.index, obj) + end, + ['select'] = function (obj) + local vararg = obj.vararg + if vararg.parent then + if not vararg.extParent then + vararg.extParent = {} + end + vararg.extParent[#vararg.extParent+1] = obj + else + Compile(vararg, obj) + end + end, + ['setname'] = function (obj) + Compile(obj.value, obj) + local loc = guide.getLocal(obj, obj[1], obj.start) + if loc then + obj.type = 'setlocal' + obj.loc = loc + addRef(loc, obj) + if loc.attrs then + local const + for i = 1, #loc.attrs do + local attr = loc.attrs[i][1] + if attr == 'const' + or attr == 'close' then + const = true + break + end + end + if const then + pushError { + type = 'SET_CONST', + start = obj.start, + finish = obj.finish, + } + end + end + else + obj.type = 'setglobal' + local node = guide.getLocal(obj, ENVMode, obj.start) + if node then + addRef(node, obj) + end + local name = obj[1] + if specials[name] then + addSpecial(name, obj) + elseif Options and Options.special then + local asName = Options.special[name] + if specials[asName] then + addSpecial(asName, obj) + end + end + end + end, + ['local'] = function (obj) + local attrs = obj.attrs + if attrs then + for i = 1, #attrs do + Compile(attrs[i], obj) + end + end + if Block then + if not Block.locals then + Block.locals = {} + end + Block.locals[#Block.locals+1] = obj + LocalCount = LocalCount + 1 + if LocalCount > LocalLimit then + pushError { + type = 'LOCAL_LIMIT', + start = obj.start, + finish = obj.finish, + } + end + end + if obj.localfunction then + obj.localfunction = nil + end + Compile(obj.value, obj) + if obj.value and obj.value.special then + addSpecial(obj.value.special, obj) + end + end, + ['setfield'] = function (obj) + Compile(obj.node, obj) + Compile(obj.value, obj) + end, + ['do'] = function (obj) + local lastBlock = Block + Block = obj + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['return'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + if Block and Block[#Block] ~= obj then + pushError { + type = 'ACTION_AFTER_RETURN', + start = obj.start, + finish = obj.finish, + } + end + local func = guide.getParentFunction(obj) + if func then + if not func.returns then + func.returns = {} + end + func.returns[#func.returns+1] = obj + end + end, + ['label'] = function (obj) + local block = guide.getBlock(obj) + if block then + if not block.labels then + block.labels = {} + end + local name = obj[1] + local label = guide.getLabel(block, name) + if label then + if Version == 'Lua 5.4' + or block == guide.getBlock(label) then + pushError { + type = 'REDEFINED_LABEL', + start = obj.start, + finish = obj.finish, + relative = { + { + label.start, + label.finish, + } + } + } + end + end + block.labels[name] = obj + end + end, + ['goto'] = function (obj) + GoToTag[#GoToTag+1] = obj + end, + ['if'] = function (obj) + for i = 1, #obj do + Compile(obj[i], obj) + end + end, + ['ifblock'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.filter, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['elseifblock'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.filter, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['elseblock'] = function (obj) + local lastBlock = Block + Block = obj + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['loop'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.loc, obj) + Compile(obj.max, obj) + Compile(obj.step, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['in'] = function (obj) + local lastBlock = Block + Block = obj + local keys = obj.keys + for i = 1, #keys do + Compile(keys[i], obj) + end + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['while'] = function (obj) + local lastBlock = Block + Block = obj + Compile(obj.filter, obj) + CompileBlock(obj, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['repeat'] = function (obj) + local lastBlock = Block + Block = obj + CompileBlock(obj, obj) + Compile(obj.filter, obj) + if Block.locals then + LocalCount = LocalCount - #Block.locals + end + Block = lastBlock + end, + ['break'] = function (obj) + local block = guide.getBreakBlock(obj) + if block then + if not block.breaks then + block.breaks = {} + end + block.breaks[#block.breaks+1] = obj + else + pushError { + type = 'BREAK_OUTSIDE', + start = obj.start, + finish = obj.finish, + } + end + end, + ['main'] = function (obj) + Block = obj + Compile({ + type = 'local', + start = 0, + finish = 0, + effect = 0, + tag = '_ENV', + special= '_G', + [1] = ENVMode, + }, obj) + --- _ENV 是上值,不计入局部变量计数 + LocalCount = 0 + CompileBlock(obj, obj) + Block = nil + end, +} + +function CompileBlock(obj, parent) + for i = 1, #obj do + local act = obj[i] + local f = vmMap[act.type] + if f then + act.parent = parent + f(act) + end + end +end + +function Compile(obj, parent) + if not obj then + return nil + end + if Compiled[obj] then + return + end + Compiled[obj] = true + obj.parent = parent + local f = vmMap[obj.type] + if not f then + return + end + f(obj) +end + +local function compileGoTo(obj) + local name = obj[1] + local label = guide.getLabel(obj, name) + if not label then + pushError { + type = 'NO_VISIBLE_LABEL', + start = obj.start, + finish = obj.finish, + info = { + label = name, + } + } + return + end + if not label.ref then + label.ref = {} + end + label.ref[#label.ref+1] = obj + obj.node = label + + -- 如果有局部变量在 goto 与 label 之间声明, + -- 并在 label 之后使用,则算作语法错误 + + -- 如果 label 在 goto 之前声明,那么不会有中间声明的局部变量 + if obj.start > label.start then + return + end + + local block = guide.getBlock(obj) + local locals = block and block.locals + if not locals then + return + end + + for i = 1, #locals do + local loc = locals[i] + -- 检查局部变量声明位置为 goto 与 label 之间 + if loc.start < obj.start or loc.finish > label.finish then + goto CONTINUE + end + -- 检查局部变量的使用位置在 label 之后 + local refs = loc.ref + if not refs then + goto CONTINUE + end + for j = 1, #refs do + local ref = refs[j] + if ref.finish > label.finish then + pushError { + type = 'JUMP_LOCAL_SCOPE', + start = obj.start, + finish = obj.finish, + info = { + loc = loc[1], + }, + relative = { + { + start = label.start, + finish = label.finish, + }, + { + start = loc.start, + finish = loc.finish, + } + }, + } + return + end + end + ::CONTINUE:: + end +end + +local function PostCompile() + for i = 1, #GoToTag do + compileGoTo(GoToTag[i]) + end +end + +return function (self, lua, mode, version, options) + local state, err = self:parse(lua, mode, version) + if not state then + return nil, err + end + pushError = state.pushError + if version == 'Lua 5.1' or version == 'LuaJIT' then + ENVMode = '@fenv' + else + ENVMode = '_ENV' + end + Compiled = {} + GoToTag = {} + LocalCount = 0 + Version = version + Root = state.ast + Root.state = state + Options = options + state.ENVMode = ENVMode + if type(state.ast) == 'table' then + Compile(state.ast) + end + PostCompile() + Compiled = nil + GoToTag = nil + return state +end diff --git a/script/parser/grammar.lua b/script/parser/grammar.lua new file mode 100644 index 00000000..06dae246 --- /dev/null +++ b/script/parser/grammar.lua @@ -0,0 +1,538 @@ +local re = require 'parser.relabel' +local m = require 'lpeglabel' +local ast = require 'parser.ast' + +local scriptBuf = '' +local compiled = {} +local defs = ast.defs + +-- goto 可以作为名字,合法性之后处理 +local RESERVED = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +defs.nl = (m.P'\r\n' + m.S'\r\n') +defs.s = m.S' \t' +defs.S = - defs.s +defs.ea = '\a' +defs.eb = '\b' +defs.ef = '\f' +defs.en = '\n' +defs.er = '\r' +defs.et = '\t' +defs.ev = '\v' +defs['nil'] = m.Cp() / function () return nil end +defs['false'] = m.Cp() / function () return false end +defs.NotReserved = function (_, _, str) + if RESERVED[str] then + return false + end + return true +end +defs.Reserved = function (_, _, str) + if RESERVED[str] then + return true + end + return false +end +defs.None = function () end +defs.np = m.Cp() / function (n) return n+1 end + +m.setmaxstack(1000) + +local eof = re.compile '!. / %{SYNTAX_ERROR}' + +local function grammar(tag) + return function (script) + scriptBuf = script .. '\r\n' .. scriptBuf + compiled[tag] = re.compile(scriptBuf, defs) * eof + end +end + +local function errorpos(pos, err) + return { + type = 'UNKNOWN', + start = pos or 0, + finish = pos or 0, + err = err, + } +end + +grammar 'Comment' [[ +Comment <- LongComment + / '--' ShortComment +LongComment <- ('--[' {} {:eq: '='* :} {} '[' + {(!CommentClose .)*} + (CommentClose / {})) + -> LongComment + / ( + {} '/*' {} + (!'*/' .)* + {} '*/' {} + ) + -> CLongComment +CommentClose <- ']' =eq ']' +ShortComment <- ({} {(!%nl .)*} {}) + -> ShortComment +]] + +grammar 'Sp' [[ +Sp <- (Comment / %nl / %s)* +Sps <- (Comment / %nl / %s)+ +]] + +grammar 'Common' [[ +Word <- [a-zA-Z0-9_] +Cut <- !Word +X16 <- [a-fA-F0-9] +Rest <- (!%nl .)* + +AND <- Sp {'and'} Cut +BREAK <- Sp 'break' Cut +FALSE <- Sp 'false' Cut +GOTO <- Sp 'goto' Cut +LOCAL <- Sp 'local' Cut +NIL <- Sp 'nil' Cut +NOT <- Sp 'not' Cut +OR <- Sp {'or'} Cut +RETURN <- Sp 'return' Cut +TRUE <- Sp 'true' Cut + +DO <- Sp {} 'do' {} Cut + / Sp({} 'then' {} Cut) -> ErrDo +IF <- Sp {} 'if' {} Cut +ELSE <- Sp {} 'else' {} Cut +ELSEIF <- Sp {} 'elseif' {} Cut +END <- Sp {} 'end' {} Cut +FOR <- Sp {} 'for' {} Cut +FUNCTION <- Sp {} 'function' {} Cut +IN <- Sp {} 'in' {} Cut +REPEAT <- Sp {} 'repeat' {} Cut +THEN <- Sp {} 'then' {} Cut + / Sp({} 'do' {} Cut) -> ErrThen +UNTIL <- Sp {} 'until' {} Cut +WHILE <- Sp {} 'while' {} Cut + + +Esc <- '\' -> '' + EChar +EChar <- 'a' -> ea + / 'b' -> eb + / 'f' -> ef + / 'n' -> en + / 'r' -> er + / 't' -> et + / 'v' -> ev + / '\' + / '"' + / "'" + / %nl + / ('z' (%nl / %s)*) -> '' + / ({} 'x' {X16 X16}) -> Char16 + / ([0-9] [0-9]? [0-9]?) -> Char10 + / ('u{' {} {Word*} '}') -> CharUtf8 + -- 错误处理 + / 'x' {} -> MissEscX + / 'u' !'{' {} -> MissTL + / 'u{' Word* !'}' {} -> MissTR + / {} -> ErrEsc + +BOR <- Sp {'|'} +BXOR <- Sp {'~'} !'=' +BAND <- Sp {'&'} +Bshift <- Sp {BshiftList} +BshiftList <- '<<' + / '>>' +Concat <- Sp {'..'} +Adds <- Sp {AddsList} +AddsList <- '+' + / '-' +Muls <- Sp {MulsList} +MulsList <- '*' + / '//' + / '/' + / '%' +Unary <- Sp {} {UnaryList} +UnaryList <- NOT + / '#' + / '-' + / '~' !'=' +POWER <- Sp {'^'} + +BinaryOp <-( Sp {} {'or'} Cut + / Sp {} {'and'} Cut + / Sp {} {'<=' / '>=' / '<'!'<' / '>'!'>' / '~=' / '=='} + / Sp {} ({} '=' {}) -> ErrEQ + / Sp {} ({} '!=' {}) -> ErrUEQ + / Sp {} {'|'} + / Sp {} {'~'} + / Sp {} {'&'} + / Sp {} {'<<' / '>>'} + / Sp {} {'..'} !'.' + / Sp {} {'+' / '-'} + / Sp {} {'*' / '//' / '/' / '%'} + / Sp {} {'^'} + )-> BinaryOp +UnaryOp <-( Sp {} {'not' Cut / '#' / '~' !'=' / '-' !'-'} + )-> UnaryOp + +PL <- Sp '(' +PR <- Sp ')' +BL <- Sp '[' !'[' !'=' +BR <- Sp ']' +TL <- Sp '{' +TR <- Sp '}' +COMMA <- Sp ({} ',') + -> COMMA +SEMICOLON <- Sp ({} ';') + -> SEMICOLON +DOTS <- Sp ({} '...') + -> DOTS +DOT <- Sp ({} '.' !'.') + -> DOT +COLON <- Sp ({} ':' !':') + -> COLON +LABEL <- Sp '::' +ASSIGN <- Sp '=' !'=' +AssignOrEQ <- Sp ({} '==' {}) + -> ErrAssign + / Sp '=' + +DirtyBR <- BR / {} -> MissBR +DirtyTR <- TR / {} -> MissTR +DirtyPR <- PR / {} -> MissPR +DirtyLabel <- LABEL / {} -> MissLabel +NeedEnd <- END / {} -> MissEnd +NeedDo <- DO / {} -> MissDo +NeedAssign <- ASSIGN / {} -> MissAssign +NeedComma <- COMMA / {} -> MissComma +NeedIn <- IN / {} -> MissIn +NeedUntil <- UNTIL / {} -> MissUntil +NeedThen <- THEN / {} -> MissThen +]] + +grammar 'Nil' [[ +Nil <- Sp ({} -> Nil) NIL +]] + +grammar 'Boolean' [[ +Boolean <- Sp ({} -> True) TRUE + / Sp ({} -> False) FALSE +]] + +grammar 'String' [[ +String <- Sp ({} StringDef {}) + -> String +StringDef <- {'"'} + {~(Esc / !%nl !'"' .)*~} -> 1 + ('"' / {} -> MissQuote1) + / {"'"} + {~(Esc / !%nl !"'" .)*~} -> 1 + ("'" / {} -> MissQuote2) + / ('[' {} {:eq: '='* :} {} '[' %nl? + {(!StringClose .)*} -> 1 + (StringClose / {})) + -> LongString +StringClose <- ']' =eq ']' +]] + +grammar 'Number' [[ +Number <- Sp ({} {NumberDef} {}) -> Number + NumberSuffix? + ErrNumber? +NumberDef <- Number16 / Number10 +NumberSuffix<- ({} {[uU]? [lL] [lL]}) -> FFINumber + / ({} {[iI]}) -> ImaginaryNumber +ErrNumber <- ({} {([0-9a-zA-Z] / '.')+}) -> UnknownSymbol + +Number10 <- Float10 Float10Exp? + / Integer10 Float10? Float10Exp? +Integer10 <- [0-9]+ ('.' [0-9]*)? +Float10 <- '.' [0-9]+ +Float10Exp <- [eE] [+-]? [0-9]+ + / ({} [eE] [+-]? {}) -> MissExponent + +Number16 <- '0' [xX] Float16 Float16Exp? + / '0' [xX] Integer16 Float16? Float16Exp? +Integer16 <- X16+ ('.' X16*)? + / ({} {Word*}) -> MustX16 +Float16 <- '.' X16+ + / '.' ({} {Word*}) -> MustX16 +Float16Exp <- [pP] [+-]? [0-9]+ + / ({} [pP] [+-]? {}) -> MissExponent +]] + +grammar 'Name' [[ +Name <- Sp ({} NameBody {}) + -> Name +NameBody <- {[a-zA-Z_] [a-zA-Z0-9_]*} +FreeName <- Sp ({} {NameBody=>NotReserved} {}) + -> Name +KeyWord <- Sp NameBody=>Reserved +MustName <- Name / DirtyName +DirtyName <- {} -> DirtyName +]] + +grammar 'Exp' [[ +Exp <- (UnUnit BinUnit*) + -> Binary +BinUnit <- (BinaryOp UnUnit?) + -> SubBinary +UnUnit <- ExpUnit + / (UnaryOp+ (ExpUnit / MissExp)) + -> Unary +ExpUnit <- Nil + / Boolean + / String + / Number + / Dots + / Table + / Function + / Simple + +Simple <- {| Prefix (Sp Suffix)* |} + -> Simple +Prefix <- Sp ({} PL DirtyExp DirtyPR {}) + -> Paren + / Single +Single <- FreeName + -> Single +Suffix <- SuffixWithoutCall + / ({} PL SuffixCall DirtyPR {}) + -> Call +SuffixCall <- Sp ({} {| (COMMA / Exp)+ |} {}) + -> PackExpList + / %nil +SuffixWithoutCall + <- (DOT (Name / MissField)) + -> GetField + / ({} BL DirtyExp DirtyBR {}) + -> GetIndex + / (COLON (Name / MissMethod) NeedCall) + -> GetMethod + / ({} {| Table |} {}) + -> Call + / ({} {| String |} {}) + -> Call +NeedCall <- (!(Sp CallStart) {} -> MissPL)? +MissField <- {} -> MissField +MissMethod <- {} -> MissMethod +CallStart <- PL + / TL + / '"' + / "'" + / '[' '='* '[' + +DirtyExp <- Exp + / {} -> DirtyExp +MaybeExp <- Exp / MissExp +MissExp <- {} -> MissExp +ExpList <- Sp {| MaybeExp (Sp ',' MaybeExp)* |} + +Dots <- DOTS + -> VarArgs + +Table <- Sp ({} TL {| TableField* |} DirtyTR {}) + -> Table +TableField <- COMMA + / SEMICOLON + / NewIndex + / NewField + / Exp +Index <- BL DirtyExp DirtyBR +NewIndex <- Sp ({} Index NeedAssign DirtyExp {}) + -> NewIndex +NewField <- Sp ({} MustName ASSIGN DirtyExp {}) + -> NewField + +Function <- FunctionBody + -> Function +FuncArgs <- Sp ({} PL {| FuncArg+ |} DirtyPR {}) + -> FuncArgs + / PL DirtyPR %nil +FuncArgsMiss<- {} -> MissPL DirtyPR %nil +FuncArg <- DOTS + / Name + / COMMA +FunctionBody<- FUNCTION FuncArgs + {| (!END Action)* |} + NeedEnd + / FUNCTION FuncArgsMiss + {| %nil |} + NeedEnd + +-- 纯占位,修改了 `relabel.lua` 使重复定义不抛错 +Action <- !END . +]] + +grammar 'Action' [[ +Action <- Sp (CrtAction / UnkAction) +CrtAction <- Semicolon + / Do + / Break + / Return + / Label + / GoTo + / If + / For + / While + / Repeat + / NamedFunction + / LocalFunction + / Local + / Set + / Call + / ExpInAction +UnkAction <- ({} {Word+}) + -> UnknownAction + / ({} '//' {} (LongComment / ShortComment)) + -> CCommentPrefix + / ({} {. (!Sps !CrtAction .)*}) + -> UnknownAction +ExpInAction <- Sp ({} Exp {}) + -> ExpInAction + +Semicolon <- Sp ';' +SimpleList <- {| Simple (Sp ',' Simple)* |} + +Do <- Sp ({} + 'do' Cut + {| (!END Action)* |} + NeedEnd) + -> Do + +Break <- Sp ({} BREAK {}) + -> Break + +Return <- Sp ({} RETURN ReturnExpList {}) + -> Return +ReturnExpList + <- Sp {| Exp (Sp ',' MaybeExp)* |} + / Sp {| !Exp !',' |} + / ExpList + +Label <- Sp ({} LABEL MustName DirtyLabel {}) + -> Label + +GoTo <- Sp ({} GOTO MustName {}) + -> GoTo + +If <- Sp ({} {| IfHead IfBody* |} NeedEnd) + -> If + +IfHead <- Sp (IfPart {}) -> IfBlock + / Sp (ElseIfPart {}) -> ElseIfBlock + / Sp (ElsePart {}) -> ElseBlock +IfBody <- Sp (ElseIfPart {}) -> ElseIfBlock + / Sp (ElsePart {}) -> ElseBlock +IfPart <- IF DirtyExp NeedThen + {| (!ELSEIF !ELSE !END Action)* |} +ElseIfPart <- ELSEIF DirtyExp NeedThen + {| (!ELSEIF !ELSE !END Action)* |} +ElsePart <- ELSE + {| (!ELSEIF !ELSE !END Action)* |} + +For <- Loop / In + +Loop <- LoopBody + -> Loop +LoopBody <- FOR LoopArgs NeedDo + {} {| (!END Action)* |} + NeedEnd +LoopArgs <- MustName AssignOrEQ + ({} {| (COMMA / !DO !END Exp)* |} {}) + -> PackLoopArgs + +In <- InBody + -> In +InBody <- FOR InNameList NeedIn InExpList NeedDo + {} {| (!END Action)* |} + NeedEnd +InNameList <- ({} {| (COMMA / !IN !DO !END Name)* |} {}) + -> PackInNameList +InExpList <- ({} {| (COMMA / !DO !DO !END Exp)* |} {}) + -> PackInExpList + +While <- WhileBody + -> While +WhileBody <- WHILE DirtyExp NeedDo + {| (!END Action)* |} + NeedEnd + +Repeat <- (RepeatBody {}) + -> Repeat +RepeatBody <- REPEAT + {| (!UNTIL Action)* |} + NeedUntil DirtyExp + +LocalAttr <- {| (Sp '<' Sp MustName Sp LocalAttrEnd)+ |} + -> LocalAttr +LocalAttrEnd<- '>' / {} -> MissGT +Local <- Sp ({} LOCAL LocalNameList ((AssignOrEQ ExpList) / %nil) {}) + -> Local +Set <- Sp ({} SimpleList AssignOrEQ ExpList {}) + -> Set +LocalNameList + <- {| LocalName (Sp ',' LocalName)* |} +LocalName <- (MustName LocalAttr?) + -> LocalName + +Call <- Simple + -> SimpleCall + +LocalFunction + <- Sp ({} LOCAL FunctionNamedBody) + -> LocalFunction + +NamedFunction + <- FunctionNamedBody + -> NamedFunction +FunctionNamedBody + <- FUNCTION FuncName FuncArgs + {| (!END Action)* |} + NeedEnd + / FUNCTION FuncName FuncArgsMiss + {| %nil |} + NeedEnd +FuncName <- {| Single (Sp SuffixWithoutCall)* |} + -> Simple + / {} -> MissName %nil +]] + +grammar 'Lua' [[ +Lua <- Head? + ({} {| Action* |} {}) -> Lua + Sp +Head <- '#' (!%nl .)* +]] + +return function (self, lua, mode) + local gram = compiled[mode] or compiled['Lua'] + local r, _, pos = gram:match(lua) + if not r then + local err = errorpos(pos) + return nil, err + end + + return r +end diff --git a/script/parser/guide.lua b/script/parser/guide.lua new file mode 100644 index 00000000..6ef239f1 --- /dev/null +++ b/script/parser/guide.lua @@ -0,0 +1,3884 @@ +local util = require 'utility' +local error = error +local type = type +local next = next +local tostring = tostring +local print = print +local ipairs = ipairs +local tableInsert = table.insert +local tableUnpack = table.unpack +local tableRemove = table.remove +local tableMove = table.move +local tableSort = table.sort +local tableConcat = table.concat +local mathType = math.type +local pairs = pairs +local setmetatable = setmetatable +local assert = assert +local select = select +local osClock = os.clock +local DEVELOP = _G.DEVELOP +local log = log +local _G = _G + +local function logWarn(...) + log.warn(...) +end + +_ENV = nil + +local m = {} + +local blockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, + ['do'] = true, + ['function'] = true, + ['ifblock'] = true, + ['elseblock'] = true, + ['elseifblock'] = true, + ['main'] = true, +} + +local breakBlockTypes = { + ['while'] = true, + ['in'] = true, + ['loop'] = true, + ['repeat'] = true, +} + +m.childMap = { + ['main'] = {'#', 'docs'}, + ['repeat'] = {'#', 'filter'}, + ['while'] = {'filter', '#'}, + ['in'] = {'keys', '#'}, + ['loop'] = {'loc', 'max', 'step', '#'}, + ['if'] = {'#'}, + ['ifblock'] = {'filter', '#'}, + ['elseifblock'] = {'filter', '#'}, + ['elseblock'] = {'#'}, + ['setfield'] = {'node', 'field', 'value'}, + ['setglobal'] = {'value'}, + ['local'] = {'attrs', 'value'}, + ['setlocal'] = {'value'}, + ['return'] = {'#'}, + ['do'] = {'#'}, + ['select'] = {'vararg'}, + ['table'] = {'#'}, + ['tableindex'] = {'index', 'value'}, + ['tablefield'] = {'field', 'value'}, + ['function'] = {'args', '#'}, + ['funcargs'] = {'#'}, + ['setmethod'] = {'node', 'method', 'value'}, + ['getmethod'] = {'node', 'method'}, + ['setindex'] = {'node', 'index', 'value'}, + ['getindex'] = {'node', 'index'}, + ['paren'] = {'exp'}, + ['call'] = {'node', 'args'}, + ['callargs'] = {'#'}, + ['getfield'] = {'node', 'field'}, + ['list'] = {'#'}, + ['binary'] = {1, 2}, + ['unary'] = {1}, + + ['doc'] = {'#'}, + ['doc.class'] = {'class', 'extends'}, + ['doc.type'] = {'#types', '#enums', 'name'}, + ['doc.alias'] = {'alias', 'extends'}, + ['doc.param'] = {'param', 'extends'}, + ['doc.return'] = {'#returns'}, + ['doc.field'] = {'field', 'extends'}, + ['doc.generic'] = {'#generics'}, + ['doc.generic.object'] = {'generic', 'extends'}, + ['doc.vararg'] = {'vararg'}, + ['doc.type.table'] = {'key', 'value'}, + ['doc.type.function'] = {'#args', '#returns'}, + ['doc.overload'] = {'overload'}, +} + +m.actionMap = { + ['main'] = {'#'}, + ['repeat'] = {'#'}, + ['while'] = {'#'}, + ['in'] = {'#'}, + ['loop'] = {'#'}, + ['if'] = {'#'}, + ['ifblock'] = {'#'}, + ['elseifblock'] = {'#'}, + ['elseblock'] = {'#'}, + ['do'] = {'#'}, + ['function'] = {'#'}, + ['funcargs'] = {'#'}, +} + +local TypeSort = { + ['boolean'] = 1, + ['string'] = 2, + ['integer'] = 3, + ['number'] = 4, + ['table'] = 5, + ['function'] = 6, + ['nil'] = 999, +} + +local NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) + +--- 是否是字面量 +function m.isLiteral(obj) + local tp = obj.type + return tp == 'nil' + or tp == 'boolean' + or tp == 'string' + or tp == 'number' + or tp == 'table' +end + +--- 获取字面量 +function m.getLiteral(obj) + local tp = obj.type + if tp == 'boolean' then + return obj[1] + elseif tp == 'string' then + return obj[1] + elseif tp == 'number' then + return obj[1] + end + return nil +end + +--- 寻找父函数 +function m.getParentFunction(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + break + end + local tp = obj.type + if tp == 'function' or tp == 'main' then + return obj + end + end + return nil +end + +--- 寻找所在区块 +function m.getBlock(obj) + for _ = 1, 1000 do + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + obj = obj.parent + end + error('guide.getBlock overstack') +end + +--- 寻找所在父区块 +function m.getParentBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if blockTypes[tp] then + return obj + end + end + error('guide.getParentBlock overstack') +end + +--- 寻找所在可break的父区块 +function m.getBreakBlock(obj) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + local tp = obj.type + if breakBlockTypes[tp] then + return obj + end + if tp == 'function' then + return nil + end + end + error('guide.getBreakBlock overstack') +end + +--- 寻找doc的主体 +function m.getDocState(obj) + for _ = 1, 1000 do + local parent = obj.parent + if not parent then + return obj + end + if parent.type == 'doc' then + return obj + end + obj = parent + end + error('guide.getDocState overstack') +end + +--- 寻找所在父类型 +function m.getParentType(obj, want) + for _ = 1, 1000 do + obj = obj.parent + if not obj then + return nil + end + if want == obj.type then + return obj + end + end + error('guide.getParentType overstack') +end + +--- 寻找根区块 +function m.getRoot(obj) + for _ = 1, 1000 do + if obj.type == 'main' then + return obj + end + local parent = obj.parent + if not parent then + return nil + end + obj = parent + end + error('guide.getRoot overstack') +end + +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = m.getRoot(obj) + if root then + return root.uri + end + return '' +end + +function m.getENV(source, start) + if not start then + start = 1 + end + return m.getLocal(source, '_ENV', start) + or m.getLocal(source, '@fenv', start) +end + +--- 寻找函数的不定参数,返回不定参在第几个参数上,以及该参数对象。 +--- 如果函数是主函数,则返回`0, nil`。 +---@return table +---@return integer +function m.getFunctionVarArgs(func) + if func.type == 'main' then + return 0, nil + end + if func.type ~= 'function' then + return nil, nil + end + local args = func.args + if not args then + return nil, nil + end + for i = 1, #args do + local arg = args[i] + if arg.type == '...' then + return i, arg + end + end + return nil, nil +end + +--- 获取指定区块中可见的局部变量 +---@param block table +---@param name string {comment = '变量名'} +---@param pos integer {comment = '可见位置'} +function m.getLocal(block, name, pos) + block = m.getBlock(block) + for _ = 1, 1000 do + if not block then + return nil + end + local locals = block.locals + local res + if not locals then + goto CONTINUE + end + for i = 1, #locals do + local loc = locals[i] + if loc.effect > pos then + break + end + if loc[1] == name then + if not res or res.effect < loc.effect then + res = loc + end + end + end + if res then + return res, res + end + ::CONTINUE:: + block = m.getParentBlock(block) + end + error('guide.getLocal overstack') +end + +--- 获取指定区块中所有的可见局部变量名称 +function m.getVisibleLocals(block, pos) + local result = {} + m.eachSourceContain(m.getRoot(block), pos, function (source) + local locals = source.locals + if locals then + for i = 1, #locals do + local loc = locals[i] + local name = loc[1] + if loc.effect <= pos then + result[name] = loc + end + end + end + end) + return result +end + +--- 获取指定区块中可见的标签 +---@param block table +---@param name string {comment = '标签名'} +function m.getLabel(block, name) + block = m.getBlock(block) + for _ = 1, 1000 do + if not block then + return nil + end + local labels = block.labels + if labels then + local label = labels[name] + if label then + return label + end + end + if block.type == 'function' then + return nil + end + block = m.getParentBlock(block) + end + error('guide.getLocal overstack') +end + +function m.getStartFinish(source) + local start = source.start + local finish = source.finish + if not start then + local first = source[1] + if not first then + return nil, nil + end + local last = source[#source] + start = first.start + finish = last.finish + end + return start, finish +end + +function m.getRange(source) + local start = source.vstart or source.start + local finish = source.range or source.finish + if not start then + local first = source[1] + if not first then + return nil, nil + end + local last = source[#source] + start = first.vstart or first.start + finish = last.range or last.finish + end + return start, finish +end + +--- 判断source是否包含offset +function m.isContain(source, offset) + local start, finish = m.getStartFinish(source) + if not start then + return false + end + return start <= offset and finish >= offset - 1 +end + +--- 判断offset在source的影响范围内 +--- +--- 主要针对赋值等语句时,key包含value +function m.isInRange(source, offset) + local start, finish = m.getRange(source) + if not start then + return false + end + return start <= offset and finish >= offset - 1 +end + +function m.isBetween(source, tStart, tFinish) + local start, finish = m.getStartFinish(source) + if not start then + return false + end + return start <= tFinish and finish >= tStart - 1 +end + +function m.isBetweenRange(source, tStart, tFinish) + local start, finish = m.getRange(source) + if not start then + return false + end + return start <= tFinish and finish >= tStart - 1 +end + +--- 添加child +function m.addChilds(list, obj, map) + local keys = map[obj.type] + if keys then + for i = 1, #keys do + local key = keys[i] + if key == '#' then + for i = 1, #obj do + list[#list+1] = obj[i] + end + elseif obj[key] then + list[#list+1] = obj[key] + elseif type(key) == 'string' + and key:sub(1, 1) == '#' then + key = key:sub(2) + for i = 1, #obj[key] do + list[#list+1] = obj[key][i] + end + end + end + end +end + +--- 遍历所有包含offset的source +function m.eachSourceContain(ast, offset, callback) + local list = { ast } + local mark = {} + while true do + local len = #list + if len == 0 then + return + end + local obj = list[len] + list[len] = nil + if not mark[obj] then + mark[obj] = true + if m.isInRange(obj, offset) then + if m.isContain(obj, offset) then + local res = callback(obj) + if res ~= nil then + return res + end + end + m.addChilds(list, obj, m.childMap) + end + end + end +end + +--- 遍历所有在某个范围内的source +function m.eachSourceBetween(ast, start, finish, callback) + local list = { ast } + local mark = {} + while true do + local len = #list + if len == 0 then + return + end + local obj = list[len] + list[len] = nil + if not mark[obj] then + mark[obj] = true + if m.isBetweenRange(obj, start, finish) then + if m.isBetween(obj, start, finish) then + local res = callback(obj) + if res ~= nil then + return res + end + end + m.addChilds(list, obj, m.childMap) + end + end + end +end + +--- 遍历所有指定类型的source +function m.eachSourceType(ast, type, callback) + local cache = ast.typeCache + if not cache then + cache = {} + ast.typeCache = cache + m.eachSource(ast, function (source) + local tp = source.type + if not tp then + return + end + local myCache = cache[tp] + if not myCache then + myCache = {} + cache[tp] = myCache + end + myCache[#myCache+1] = source + end) + end + local myCache = cache[type] + if not myCache then + return + end + for i = 1, #myCache do + callback(myCache[i]) + end +end + +--- 遍历所有的source +function m.eachSource(ast, callback) + local list = { ast } + local mark = {} + local index = 1 + while true do + local obj = list[index] + if not obj then + return + end + list[index] = false + index = index + 1 + if not mark[obj] then + mark[obj] = true + callback(obj) + m.addChilds(list, obj, m.childMap) + end + end +end + +--- 获取指定的 special +function m.eachSpecialOf(ast, name, callback) + local root = m.getRoot(ast) + if not root.specials then + return + end + local specials = root.specials[name] + if not specials then + return + end + for i = 1, #specials do + callback(specials[i]) + end +end + +--- 获取偏移对应的坐标 +---@param lines table +---@return integer {name = 'row'} +---@return integer {name = 'col'} +function m.positionOf(lines, offset) + if offset < 1 then + return 0, 0 + end + local lastLine = lines[#lines] + if offset > lastLine.finish then + return #lines, lastLine.finish - lastLine.start + 1 + end + local min = 1 + local max = #lines + for _ = 1, 100 do + if max <= min then + local line = lines[min] + return min, offset - line.start + 1 + end + local row = (max - min) // 2 + min + local line = lines[row] + if offset < line.start then + max = row - 1 + elseif offset > line.finish then + min = row + 1 + else + return row, offset - line.start + 1 + end + end + error('Stack overflow!') +end + +--- 获取坐标对应的偏移 +---@param lines table +---@param row integer +---@param col integer +---@return integer {name = 'offset'} +function m.offsetOf(lines, row, col) + if row < 1 then + return 0 + end + if row > #lines then + local lastLine = lines[#lines] + return lastLine.finish + end + local line = lines[row] + local len = line.finish - line.start + 1 + if col < 0 then + return line.start + elseif col > len then + return line.finish + else + return line.start + col - 1 + end +end + +function m.lineContent(lines, text, row, ignoreNL) + local line = lines[row] + if not line then + return '' + end + if ignoreNL then + return text:sub(line.start, line.range) + else + return text:sub(line.start, line.finish) + end +end + +function m.lineRange(lines, row, ignoreNL) + local line = lines[row] + if not line then + return 0, 0 + end + if ignoreNL then + return line.start, line.range + else + return line.start, line.finish + end +end + +function m.getNameOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'string' then + return obj[1] + end + return nil +end + +function m.getName(obj) + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return obj[1] + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return obj[1] + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + return obj.field and obj.field[1] + elseif tp == 'getmethod' + or tp == 'setmethod' then + return obj.method and obj.method[1] + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getNameOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' then + return obj[1] + elseif tp == 'doc.class' then + return obj.class[1] + elseif tp == 'doc.alias' then + return obj.alias[1] + elseif tp == 'doc.field' then + return obj.field[1] + end + return m.getNameOfLiteral(obj) +end + +function m.getKeyNameOfLiteral(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'field' + or tp == 'method' then + return 's|' .. obj[1] + elseif tp == 'string' then + local s = obj[1] + if s then + return 's|' .. s + else + return 's' + end + elseif tp == 'number' then + local n = obj[1] + if n then + return ('n|%s'):format(util.viewLiteral(obj[1])) + else + return 'n' + end + elseif tp == 'boolean' then + local b = obj[1] + if b then + return 'b|' .. tostring(b) + else + return 'b' + end + end + return nil +end + +function m.getKeyName(obj) + if not obj then + return nil + end + local tp = obj.type + if tp == 'getglobal' + or tp == 'setglobal' then + return 's|' .. obj[1] + elseif tp == 'local' + or tp == 'getlocal' + or tp == 'setlocal' then + return 'l|' .. obj[1] + elseif tp == 'getfield' + or tp == 'setfield' + or tp == 'tablefield' then + if obj.field then + return 's|' .. obj.field[1] + end + elseif tp == 'getmethod' + or tp == 'setmethod' then + if obj.method then + return 's|' .. obj.method[1] + end + elseif tp == 'getindex' + or tp == 'setindex' + or tp == 'tableindex' then + return m.getKeyNameOfLiteral(obj.index) + elseif tp == 'field' + or tp == 'method' then + return 's|' .. obj[1] + elseif tp == 'doc.class' then + return 's|' .. obj.class[1] + elseif tp == 'doc.alias' then + return 's|' .. obj.alias[1] + elseif tp == 'doc.field' then + return 's|' .. obj.field[1] + end + return m.getKeyNameOfLiteral(obj) +end + +function m.getSimpleName(obj) + if obj.type == 'call' then + local key = obj.args and obj.args[2] + return m.getKeyName(key) + elseif obj.type == 'table' then + return ('t|%p'):format(obj) + elseif obj.type == 'select' then + return ('v|%p'):format(obj) + elseif obj.type == 'string' then + return ('z|%p'):format(obj) + elseif obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' then + return ('c|%s'):format(obj[1]) + elseif obj.type == 'doc.class' then + return ('c|%s'):format(obj.class[1]) + end + return m.getKeyName(obj) +end + +--- 测试 a 到 b 的路径(不经过函数,不考虑 goto), +--- 每个路径是一个 block 。 +--- +--- 如果 a 在 b 的前面,返回 `"before"` 加上 2个`list<block>` +--- +--- 如果 a 在 b 的后面,返回 `"after"` 加上 2个`list<block>` +--- +--- 否则返回 `false` +--- +--- 返回的2个 `list` 分别为基准block到达 a 与 b 的路径。 +---@param a table +---@param b table +---@return string|boolean mode +---@return table pathA? +---@return table pathB? +function m.getPath(a, b, sameFunction) + --- 首先测试双方在同一个函数内 + if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then + return false + end + local mode + local objA + local objB + if a.finish < b.start then + mode = 'before' + objA = a + objB = b + elseif a.start > b.finish then + mode = 'after' + objA = b + objB = a + else + return 'equal', {}, {} + end + local pathA = {} + local pathB = {} + for _ = 1, 1000 do + objA = m.getParentBlock(objA) + pathA[#pathA+1] = objA + if (not sameFunction and objA.type == 'function') or objA.type == 'main' then + break + end + end + for _ = 1, 1000 do + objB = m.getParentBlock(objB) + pathB[#pathB+1] = objB + if (not sameFunction and objA.type == 'function') or objB.type == 'main' then + break + end + end + -- pathA: {1, 2, 3, 4, 5} + -- pathB: {5, 6, 2, 3} + local top = #pathB + local start + for i = #pathA, 1, -1 do + local currentBlock = pathA[i] + if currentBlock == pathB[top] then + start = i + break + end + end + if not start then + return nil + end + -- pathA: { 1, 2, 3} + -- pathB: {5, 6, 2, 3} + local extra = 0 + local align = top - start + for i = start, 1, -1 do + local currentA = pathA[i] + local currentB = pathB[i+align] + if currentA ~= currentB then + extra = i + break + end + end + -- pathA: {1} + local resultA = {} + for i = extra, 1, -1 do + resultA[#resultA+1] = pathA[i] + end + -- pathB: {5, 6} + local resultB = {} + for i = extra + align, 1, -1 do + resultB[#resultB+1] = pathB[i] + end + return mode, resultA, resultB +end + +-- 根据语法,单步搜索定义 +local function stepRefOfLocal(loc, mode) + local results = {} + if loc.start ~= 0 then + results[#results+1] = loc + end + local refs = loc.ref + if not refs then + return results + end + for i = 1, #refs do + local ref = refs[i] + if ref.start == 0 then + goto CONTINUE + end + if mode == 'def' then + if ref.type == 'local' + or ref.type == 'setlocal' then + results[#results+1] = ref + end + else + if ref.type == 'local' + or ref.type == 'setlocal' + or ref.type == 'getlocal' then + results[#results+1] = ref + end + end + ::CONTINUE:: + end + return results +end + +local function stepRefOfLabel(label, mode) + local results = { label } + if not label or mode == 'def' then + return results + end + local refs = label.ref + for i = 1, #refs do + local ref = refs[i] + results[#results+1] = ref + end + return results +end + +local function stepRefOfDocType(status, obj, mode) + local results = {} + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' + or obj.type == 'doc.alias.name' + or obj.type == 'doc.extends.name' then + local name = obj[1] + if not name or not status.interface.docType then + return results + end + local docs = status.interface.docType(name) + for i = 1, #docs do + local doc = docs[i] + if mode == 'def' then + if doc.type == 'doc.class.name' + or doc.type == 'doc.alias.name' then + results[#results+1] = doc + end + else + results[#results+1] = doc + end + end + else + results[#results+1] = obj + end + return results +end + +function m.getStepRef(status, obj, mode) + if obj.type == 'getlocal' + or obj.type == 'setlocal' then + return stepRefOfLocal(obj.node, mode) + end + if obj.type == 'local' then + return stepRefOfLocal(obj, mode) + end + if obj.type == 'label' then + return stepRefOfLabel(obj, mode) + end + if obj.type == 'goto' then + return stepRefOfLabel(obj.node, mode) + end + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' + or obj.type == 'doc.extends.name' + or obj.type == 'doc.alias.name' then + return stepRefOfDocType(status, obj, mode) + end + return nil +end + +-- 根据语法,单步搜索field +local function stepFieldOfLocal(loc) + local results = {} + local refs = loc.ref + for i = 1, #refs do + local ref = refs[i] + if ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + elseif ref.type == 'getlocal' then + local nxt = ref.next + if nxt then + if nxt.type == 'setfield' + or nxt.type == 'getfield' + or nxt.type == 'setmethod' + or nxt.type == 'getmethod' + or nxt.type == 'setindex' + or nxt.type == 'getindex' then + results[#results+1] = nxt + end + end + end + end + return results +end +local function stepFieldOfTable(tbl) + local result = {} + for i = 1, #tbl do + result[i] = tbl[i] + end + return result +end +function m.getStepField(obj) + if obj.type == 'getlocal' + or obj.type == 'setlocal' then + return stepFieldOfLocal(obj.node) + end + if obj.type == 'local' then + return stepFieldOfLocal(obj) + end + if obj.type == 'table' then + return stepFieldOfTable(obj) + end +end + +local function convertSimpleList(list) + local simple = {} + for i = #list, 1, -1 do + local c = list[i] + if c.type == 'getglobal' + or c.type == 'setglobal' then + if c.special == '_G' then + simple.mode = 'global' + goto CONTINUE + end + local loc = c.node + if loc.special == '_G' then + simple.mode = 'global' + if not simple.node then + simple.node = c + end + else + simple.mode = 'local' + simple[#simple+1] = m.getSimpleName(loc) + if not simple.node then + simple.node = loc + end + end + elseif c.type == 'getlocal' + or c.type == 'setlocal' then + if c.special == '_G' then + simple.mode = 'global' + goto CONTINUE + end + simple.mode = 'local' + if not simple.node then + simple.node = c.node + end + elseif c.type == 'local' then + simple.mode = 'local' + if not simple.node then + simple.node = c + end + else + if not simple.node then + simple.node = c + end + end + simple[#simple+1] = m.getSimpleName(c) + ::CONTINUE:: + end + if simple.mode == 'global' and #simple == 0 then + simple[1] = 's|_G' + simple.node = list[#list] + end + return simple +end + +-- 搜索 `a.b.c` 的等价表达式 +local function buildSimpleList(obj, max) + local list = {} + local cur = obj + local limit = max and (max + 1) or 11 + for i = 1, max or limit do + if i == limit then + return nil + end + while cur.type == 'paren' do + cur = cur.exp + if not cur then + return nil + end + end + if cur.type == 'setfield' + or cur.type == 'getfield' + or cur.type == 'setmethod' + or cur.type == 'getmethod' + or cur.type == 'setindex' + or cur.type == 'getindex' then + list[i] = cur + cur = cur.node + elseif cur.type == 'tablefield' + or cur.type == 'tableindex' then + list[i] = cur + cur = cur.parent.parent + if cur.type == 'return' then + list[i+1] = list[i].parent + break + end + elseif cur.type == 'getlocal' + or cur.type == 'setlocal' + or cur.type == 'local' then + list[i] = cur + break + elseif cur.type == 'setglobal' + or cur.type == 'getglobal' then + list[i] = cur + break + elseif cur.type == 'select' + or cur.type == 'table' then + list[i] = cur + break + elseif cur.type == 'string' then + list[i] = cur + break + elseif cur.type == 'doc.class.name' + or cur.type == 'doc.type.name' + or cur.type == 'doc.class' then + list[i] = cur + break + elseif cur.type == 'function' + or cur.type == 'main' then + break + else + return nil + end + end + return convertSimpleList(list) +end + +function m.getSimple(obj, max) + local simpleList + if obj.type == 'getfield' + or obj.type == 'setfield' + or obj.type == 'getmethod' + or obj.type == 'setmethod' + or obj.type == 'getindex' + or obj.type == 'setindex' + or obj.type == 'local' + or obj.type == 'getlocal' + or obj.type == 'setlocal' + or obj.type == 'setglobal' + or obj.type == 'getglobal' + or obj.type == 'tablefield' + or obj.type == 'tableindex' + or obj.type == 'select' + or obj.type == 'table' + or obj.type == 'string' + or obj.type == 'doc.class.name' + or obj.type == 'doc.class' + or obj.type == 'doc.type.name' then + simpleList = buildSimpleList(obj, max) + elseif obj.type == 'field' + or obj.type == 'method' then + simpleList = buildSimpleList(obj.parent, max) + end + return simpleList +end + +function m.status(parentStatus, interface) + local status = { + cache = parentStatus and parentStatus.cache or { + count = 0, + }, + depth = parentStatus and (parentStatus.depth + 1) or 1, + interface = parentStatus and parentStatus.interface or {}, + locks = parentStatus and parentStatus.locks or {}, + deep = parentStatus and parentStatus.deep, + results = {}, + } + status.lock = status.locks[status.depth] or {} + status.locks[status.depth] = status.lock + if interface then + for k, v in pairs(interface) do + status.interface[k] = v + end + end + local searchDepth = status.interface.getSearchDepth and status.interface.getSearchDepth() or 0 + if status.depth >= searchDepth then + status.deep = false + end + return status +end + +function m.copyStatusResults(a, b) + local ra = a.results + local rb = b.results + for i = 1, #rb do + ra[#ra+1] = rb[i] + end +end + +function m.isGlobal(source) + if source.type == 'setglobal' + or source.type == 'getglobal' then + if source.node and source.node.tag == '_ENV' then + return true + end + end + if source.type == 'field' then + source = source.parent + end + if source.type == 'getfield' + or source.type == 'setfield' then + local node = source.node + if node and node.special == '_G' then + return true + end + end + return false +end + +function m.isDoc(source) + return source.type:sub(1, 4) == 'doc.' +end + +--- 根据函数的调用参数,获取:调用,参数索引 +function m.getCallAndArgIndex(callarg) + local callargs = callarg.parent + if not callargs or callargs.type ~= 'callargs' then + return nil + end + local index + for i = 1, #callargs do + if callargs[i] == callarg then + index = i + break + end + end + local call = callargs.parent + return call, index +end + +--- 根据函数调用的返回值,获取:调用的函数,参数列表,自己是第几个返回值 +function m.getCallValue(source) + local value = m.getObjectValue(source) or source + if not value then + return + end + local call, index + if value.type == 'call' then + call = value + index = 1 + elseif value.type == 'select' then + call = value.vararg + index = value.index + if call.type ~= 'call' then + return + end + else + return + end + return call.node, call.args, index +end + +function m.getNextRef(ref) + local nextRef = ref.next + if nextRef then + if nextRef.type == 'setfield' + or nextRef.type == 'getfield' + or nextRef.type == 'setmethod' + or nextRef.type == 'getmethod' + or nextRef.type == 'setindex' + or nextRef.type == 'getindex' then + return nextRef + end + end + -- 穿透 rawget 与 rawset + local call, index = m.getCallAndArgIndex(ref) + if call then + if call.node.special == 'rawset' and index == 1 then + return call + end + if call.node.special == 'rawget' and index == 1 then + return call + end + end + -- doc.type.array + if ref.type == 'doc.type' then + local arrays = {} + for _, typeUnit in ipairs(ref.types) do + if typeUnit.type == 'doc.type.array' then + arrays[#arrays+1] = typeUnit.node + end + end + -- 返回一个 dummy + -- TODO 用弱表维护唯一性? + return { + type = 'doc.type', + start = ref.start, + finish = ref.finish, + types = arrays, + parent = ref.parent, + array = true, + enums = {}, + resumes = {}, + } + end + + return nil +end + +function m.checkSameSimpleInValueOfTable(status, value, start, queue) + if value.type ~= 'table' then + return + end + for i = 1, #value do + local field = value[i] + queue[#queue+1] = { + obj = field, + start = start + 1, + } + end +end + +function m.searchFields(status, obj, key) + local simple = m.getSimple(obj) + if not simple then + return + end + simple[#simple+1] = key and ('s|' .. key) or '*' + m.searchSameFields(status, simple, 'field') + m.cleanResults(status.results) +end + +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args[3] + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +function m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + local newStatus = m.status(status) + m.searchFields(newStatus, mt, '__index') + local refsStatus = m.status(status) + for i = 1, #newStatus.results do + local indexValue = m.getObjectValue(newStatus.results[i]) + if indexValue then + m.searchRefs(refsStatus, indexValue, 'ref') + end + end + for i = 1, #refsStatus.results do + local obj = refsStatus.results[i] + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end +end +function m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue) + if not func or func.special ~= 'setmetatable' then + return + end + local call = func.parent + local args = call.args + local obj = args[1] + local mt = args[2] + if obj then + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + if mt then + m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + end +end + +function m.checkSameSimpleInValueOfCallMetaTable(status, call, start, queue) + if call.type == 'call' then + m.checkSameSimpleInValueOfSetMetaTable(status, call.node, start, queue) + end +end + +function m.checkSameSimpleInSpecialBranch(status, obj, start, queue) + if not status.interface.index then + return + end + local results = status.interface.index(obj) + if not results then + return + end + for _, res in ipairs(results) do + queue[#queue+1] = { + obj = res, + start = start + 1, + } + end +end + +function m.checkSameSimpleByDocType(status, doc) + if status.cache.searchingBindedDoc then + return + end + if doc.type ~= 'doc.type' then + return + end + local results = {} + for _, piece in ipairs(doc.types) do + local pieceResult = stepRefOfDocType(status, piece, 'def') + for _, res in ipairs(pieceResult) do + results[#results+1] = res + end + end + return results +end + +function m.checkSameSimpleByBindDocs(status, obj, start, queue, mode) + if not obj.bindDocs then + return + end + if status.cache.searchingBindedDoc then + return + end + local skipInfer = false + local results = {} + for _, doc in ipairs(obj.bindDocs) do + if doc.type == 'doc.class' then + results[#results+1] = doc + elseif doc.type == 'doc.type' then + results[#results+1] = doc + elseif doc.type == 'doc.param' then + -- function (x) 的情况 + if obj.type == 'local' + and m.getName(obj) == doc.param[1] then + if obj.parent.type == 'funcargs' + or obj.parent.type == 'in' + or obj.parent.type == 'loop' then + results[#results+1] = doc.extends + end + end + elseif doc.type == 'doc.field' then + results[#results+1] = doc + end + end + for _, res in ipairs(results) do + if res.type == 'doc.class' + or res.type == 'doc.type' then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + skipInfer = true + end + if res.type == 'doc.type.function' then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + elseif res.type == 'doc.field' then + queue[#queue+1] = { + obj = res, + start = start + 1, + } + end + end + return skipInfer +end + +function m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + if status.cache.searchingBindedDoc then + return + end + if not obj.bindSources then + return + end + status.cache.searchingBindedDoc = true + local mark = {} + local newStatus = m.status(status) + for _, ref in ipairs(obj.bindSources) do + if not mark[ref] then + mark[ref] = true + m.searchRefs(newStatus, ref, mode) + end + end + status.cache.searchingBindedDoc = nil + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSameSimpleByDoc(status, obj, start, queue, mode) + if obj.type == 'doc.class.name' + or obj.type == 'doc.type.name' then + obj = m.getDocState(obj) + end + if obj.type == 'doc.class' then + local classStart + for _, doc in ipairs(obj.bindGroup) do + if doc == obj then + classStart = true + elseif doc.type == 'doc.class' then + classStart = false + end + if classStart and doc.type == 'doc.field' then + queue[#queue+1] = { + obj = doc, + start = start + 1, + } + end + end + m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + if mode == 'ref' then + local pieceResult = stepRefOfDocType(status, obj.class, 'ref') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end + return true + elseif obj.type == 'doc.type' then + for _, piece in ipairs(obj.types) do + local pieceResult = stepRefOfDocType(status, piece, 'def') + for _, res in ipairs(pieceResult) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end + if mode == 'ref' then + m.checkSameSimpleOfRefByDocSource(status, obj, start, queue, mode) + end + return true + elseif obj.type == 'doc.field' then + if mode ~= 'field' then + return m.checkSameSimpleByDoc(status, obj.extends, start, queue, mode) + end + end +end + +function m.checkSameSimpleInArg1OfSetMetaTable(status, obj, start, queue) + local args = obj.parent + if not args or args.type ~= 'callargs' then + return + end + if args[1] ~= obj then + return + end + local mt = args[2] + if mt then + if m.checkValueMark(status, obj, mt) then + return + end + m.checkSameSimpleInValueInMetaTable(status, mt, start, queue) + end +end + +function m.searchSameMethodCrossSelf(ref, mark) + local selfNode + if ref.tag == 'self' then + selfNode = ref + else + if ref.type == 'getlocal' + or ref.type == 'setlocal' then + local node = ref.node + if node.tag == 'self' then + selfNode = node + end + end + end + if selfNode then + if mark[selfNode] then + return nil + end + mark[selfNode] = true + return selfNode.method.node + end +end + +function m.searchSameMethod(ref, mark) + if mark['method'] then + return nil + end + local nxt = ref.next + if not nxt then + return nil + end + if nxt.type == 'setmethod' then + mark['method'] = true + return ref + end + return nil +end + +function m.searchSameFieldsCrossMethod(status, ref, start, queue) + local mark = status.cache.crossMethodMark + if not mark then + mark = {} + status.cache.crossMethodMark = mark + end + local method = m.searchSameMethod(ref, mark) + or m.searchSameMethodCrossSelf(ref, mark) + if not method then + return + end + local methodStatus = m.status(status) + m.searchRefs(methodStatus, method, 'ref') + for _, md in ipairs(methodStatus.results) do + queue[#queue+1] = { + obj = md, + start = start, + force = true, + } + local nxt = md.next + if not nxt then + goto CONTINUE + end + if nxt.type == 'setmethod' then + local func = nxt.value + if not func then + goto CONTINUE + end + local selfNode = func.locals and func.locals[1] + if not selfNode or not selfNode.ref then + goto CONTINUE + end + if mark[selfNode] then + goto CONTINUE + end + mark[selfNode] = true + for _, selfRef in ipairs(selfNode.ref) do + queue[#queue+1] = { + obj = selfRef, + start = start, + force = true, + } + end + end + ::CONTINUE:: + end +end + +local function checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, source, index, call) + if not source or source.type ~= 'function' then + return + end + if not source.bindDocs then + return + end + local returns = {} + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returns[#returns+1] = rtn + end + end + end + local rtn = returns[index] + if not rtn then + return + end + local types = m.checkSameSimpleByDocType(status, rtn) + if not types then + return + end + for _, res in ipairs(types) do + results[#results+1] = res + end + return true +end + +local function checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, source, index) + if not source.bindDocs then + return + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' then + for _, typeUnit in ipairs(doc.types) do + if typeUnit.type == 'doc.type.function' then + local rtn = typeUnit.returns[index] + if rtn then + local types = m.checkSameSimpleByDocType(status, rtn) + if types then + for _, res in ipairs(types) do + results[#results+1] = res + end + return true + end + end + end + end + end + end +end + +function m.checkSameSimpleInCallInSameFile(status, func, args, index) + local newStatus = m.status(status) + m.searchRefs(newStatus, func, 'def') + local results = {} + for _, def in ipairs(newStatus.results) do + local hasDocReturn = checkSameSimpleAndMergeDocTypeFunctionReturns(status, results, def, index) + or checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, def, index) + if not hasDocReturn then + local value = m.getObjectValue(def) or def + if value.type == 'function' then + local returns = value.returns + if returns then + for _, ret in ipairs(returns) do + local exp = ret[index] + if exp then + results[#results+1] = exp + end + end + end + end + end + end + return results +end + +function m.checkSameSimpleInCall(status, ref, start, queue, mode) + local func, args, index = m.getCallValue(ref) + if not func then + return + end + if m.checkCallMark(status, func.parent, true) then + return + end + status.cache.crossCallCount = status.cache.crossCallCount or 0 + if status.cache.crossCallCount >= 5 then + return + end + status.cache.crossCallCount = status.cache.crossCallCount + 1 + -- 检查赋值是 semetatable() 的情况 + m.checkSameSimpleInValueOfSetMetaTable(status, func, start, queue) + -- 检查赋值是 func() 的情况 + local objs = m.checkSameSimpleInCallInSameFile(status, func, args, index) + if status.interface.call then + local cobjs = status.interface.call(func, args, index) + if cobjs then + for _, obj in ipairs(cobjs) do + if not m.checkReturnMark(status, obj) then + objs[#objs+1] = obj + end + end + end + end + m.cleanResults(objs) + local newStatus = m.status(status) + for _, obj in ipairs(objs) do + m.searchRefs(newStatus, obj, mode) + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + status.cache.crossCallCount = status.cache.crossCallCount - 1 + for _, obj in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end +end + +local function searchRawset(ref, results) + if m.getKeyName(ref) ~= 's|rawset' then + return + end + local call = ref.parent + if call.type ~= 'call' or call.node ~= ref then + return + end + if not call.args then + return + end + local arg1 = call.args[1] + if arg1.special ~= '_G' then + -- 不会吧不会吧,不会真的有人写成 `rawset(_G._G._G, 'xxx', value)` 吧 + return + end + results[#results+1] = call +end + +local function searchG(ref, results) + while ref and m.getKeyName(ref) == 's|_G' do + results[#results+1] = ref + ref = ref.next + end + if ref then + results[#results+1] = ref + searchRawset(ref, results) + end +end + +local function searchEnvRef(ref, results) + if ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + searchG(ref, results) + elseif ref.type == 'getlocal' then + results[#results+1] = ref.next + searchG(ref.next, results) + end +end + +function m.findGlobals(ast) + local root = m.getRoot(ast) + local results = {} + local env = m.getENV(root) + if env.ref then + for _, ref in ipairs(env.ref) do + searchEnvRef(ref, results) + end + end + return results +end + +function m.findGlobalsOfName(ast, name) + local root = m.getRoot(ast) + local results = {} + local globals = m.findGlobals(root) + for _, global in ipairs(globals) do + if m.getKeyName(global) == name then + results[#results+1] = global + end + end + return results +end + +function m.checkSameSimpleInGlobal(status, name, source, start, queue) + if not name then + return + end + local objs + if status.interface.global then + objs = status.interface.global(name) + else + objs = m.findGlobalsOfName(source, name) + end + if objs then + for _, obj in ipairs(objs) do + queue[#queue+1] = { + obj = obj, + start = start, + force = true, + } + end + end +end + +function m.checkValueMark(status, a, b) + if not status.cache.valueMark then + status.cache.valueMark = {} + end + if status.cache.valueMark[a] + or status.cache.valueMark[b] then + return true + end + status.cache.valueMark[a] = true + status.cache.valueMark[b] = true + return false +end + +function m.checkCallMark(status, a, mark) + if not status.cache.callMark then + status.cache.callMark = {} + end + if mark then + status.cache.callMark[a] = mark + else + return status.cache.callMark[a] + end + return false +end + +function m.checkReturnMark(status, a, mark) + if not status.cache.returnMark then + status.cache.returnMark = {} + end + if mark then + status.cache.returnMark[a] = mark + else + return status.cache.returnMark[a] + end + return false +end + +function m.searchSameFieldsInValue(status, ref, start, queue, mode) + local value = m.getObjectValue(ref) + if not value then + return + end + if m.checkValueMark(status, ref, value) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, value, mode) + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + queue[#queue+1] = { + obj = value, + start = start, + force = true, + } + -- 检查形如 a = f() 的分支情况 + m.checkSameSimpleInCall(status, value, start, queue, mode) +end + +function m.checkSameSimpleAsTableField(status, ref, start, queue) + if not status.deep then + --return + end + local parent = ref.parent + if not parent or parent.type ~= 'tablefield' then + return + end + if m.checkValueMark(status, parent, ref) then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, parent.field, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSearchLevel(status) + status.cache.back = status.cache.back or 0 + if status.cache.back >= (status.interface.searchLevel or 0) then + -- TODO 限制向前搜索的次数 + --return true + end + status.cache.back = status.cache.back + 1 + return false +end + +function m.checkSameSimpleAsReturn(status, ref, start, queue) + if not status.deep then + return + end + if not ref.parent or ref.parent.type ~= 'return' then + return + end + if ref.parent.parent.type ~= 'main' then + return + end + if m.checkSearchLevel(status) then + return + end + local newStatus = m.status(status) + m.searchRefsAsFunctionReturn(newStatus, ref, 'ref') + for _, res in ipairs(newStatus.results) do + if not m.checkCallMark(status, res) then + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end + end +end + +function m.checkSameSimpleAsSetValue(status, ref, start, queue) + if ref.type == 'select' then + return + end + local parent = ref.parent + if not parent then + return + end + if m.getObjectValue(parent) ~= ref then + return + end + if m.checkValueMark(status, ref, parent) then + return + end + if m.checkSearchLevel(status) then + return + end + local obj + if parent.type == 'local' + or parent.type == 'setglobal' + or parent.type == 'setlocal' then + obj = parent + elseif parent.type == 'setfield' then + obj = parent.field + elseif parent.type == 'setmethod' then + obj = parent.method + end + if not obj then + return + end + local newStatus = m.status(status) + m.searchRefs(newStatus, obj, 'ref') + for _, res in ipairs(newStatus.results) do + queue[#queue+1] = { + obj = res, + start = start, + force = true, + } + end +end + +function m.checkSameSimpleInString(status, ref, start, queue, mode) + -- 特殊处理 ('xxx').xxx 的形式 + if ref.type ~= 'string' then + return + end + if not status.interface.docType then + return + end + if status.cache.searchingBindedDoc then + return + end + local newStatus = m.status(status) + local docs = status.interface.docType('string*') + local mark = {} + for i = 1, #docs do + local doc = docs[i] + m.searchFields(newStatus, doc) + end + for _, res in ipairs(newStatus.results) do + if mark[res] then + goto CONTINUE + end + mark[res] = true + queue[#queue+1] = { + obj = res, + start = start + 1, + } + ::CONTINUE:: + end + return true +end + +function m.pushResult(status, mode, ref, simple) + local results = status.results + if mode == 'def' then + if ref.type == 'setglobal' + or ref.type == 'setlocal' + or ref.type == 'local' then + results[#results+1] = ref + elseif ref.type == 'setfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' then + results[#results+1] = ref + end + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + if m.getParentFunction(ref) ~= m.getParentFunction(simple.node) then + results[#results+1] = ref + end + end + elseif mode == 'ref' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' + or ref.type == 'local' + or ref.type == 'setlocal' + or ref.type == 'getlocal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + if ref.parent and ref.parent.type == 'return' then + results[#results+1] = ref + end + elseif mode == 'field' then + if ref.type == 'setfield' + or ref.type == 'getfield' + or ref.type == 'tablefield' then + results[#results+1] = ref + elseif ref.type == 'setmethod' + or ref.type == 'getmethod' then + results[#results+1] = ref + elseif ref.type == 'setindex' + or ref.type == 'getindex' + or ref.type == 'tableindex' then + results[#results+1] = ref + elseif ref.type == 'setglobal' + or ref.type == 'getglobal' then + results[#results+1] = ref + elseif ref.type == 'function' then + results[#results+1] = ref + elseif ref.type == 'table' then + results[#results+1] = ref + elseif ref.type == 'call' then + if ref.node.special == 'rawset' + or ref.node.special == 'rawget' then + results[#results+1] = ref + end + elseif ref.type == 'doc.type.function' + or ref.type == 'doc.class.name' + or ref.type == 'doc.field' then + results[#results+1] = ref + end + end +end + +function m.checkSameSimpleName(ref, sm) + if sm == '*' then + return true + end + if m.getSimpleName(ref) == sm then + return true + end + if ref.type == 'doc.type' + and ref.array == true then + return true + end + return false +end + +function m.checkSameSimple(status, simple, data, mode, queue) + local ref = data.obj + local start = data.start + local force = data.force + if start > #simple then + return + end + for i = start, #simple do + local sm = simple[i] + if not force and not m.checkSameSimpleName(ref, sm) then + return + end + force = false + local cmode = mode + if i < #simple then + cmode = 'ref' + end + -- 检查 doc + local skipInfer = m.checkSameSimpleByBindDocs(status, ref, i, queue, cmode) + or m.checkSameSimpleByDoc(status, ref, i, queue, cmode) + if not skipInfer then + -- 穿透 self:func 与 mt:func + m.searchSameFieldsCrossMethod(status, ref, i, queue) + -- 穿透赋值 + m.searchSameFieldsInValue(status, ref, i, queue, cmode) + -- 检查自己是字面量表的情况 + m.checkSameSimpleInValueOfTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 第一个参数的情况 + m.checkSameSimpleInArg1OfSetMetaTable(status, ref, i, queue) + -- 检查自己作为 setmetatable 调用的情况 + m.checkSameSimpleInValueOfCallMetaTable(status, ref, i, queue) + -- 检查自己是特殊变量的分支的情况 + m.checkSameSimpleInSpecialBranch(status, ref, i, queue) + -- 检查自己是字面量字符串的分支情况 + m.checkSameSimpleInString(status, ref, i, queue, cmode) + if cmode == 'ref' then + -- 检查形如 { a = f } 的情况 + m.checkSameSimpleAsTableField(status, ref, i, queue) + -- 检查形如 return m 的情况 + m.checkSameSimpleAsReturn(status, ref, i, queue) + -- 检查形如 a = f 的情况 + m.checkSameSimpleAsSetValue(status, ref, i, queue) + end + end + if i == #simple then + break + end + ref = m.getNextRef(ref) + if not ref then + return + end + end + m.pushResult(status, mode, ref, simple) + local value = m.getObjectValue(ref) + if value then + m.pushResult(status, mode, value, simple) + end +end + +function m.searchSameFields(status, simple, mode) + local queue = {} + if simple.mode == 'global' then + -- 全局变量开头 + m.checkSameSimpleInGlobal(status, simple[1], simple.node, 1, queue) + elseif simple.mode == 'local' then + -- 局部变量开头 + queue[1] = { + obj = simple.node, + start = 1, + } + local refs = simple.node.ref + if refs then + for i = 1, #refs do + queue[#queue+1] = { + obj = refs[i], + start = 1, + } + end + end + else + queue[1] = { + obj = simple.node, + start = 1, + } + end + local max = 0 + for i = 1, 1e6 do + local data = queue[i] + if not data then + return + end + if not status.lock[data.obj] then + status.lock[data.obj] = true + max = max + 1 + status.cache.count = status.cache.count + 1 + m.checkSameSimple(status, simple, data, mode, queue) + if max >= 10000 then + logWarn('Queue too large!') + break + end + end + end +end + +function m.getCallerInSameFile(status, func) + -- 搜索所有所在函数的调用者 + local funcRefs = m.status(status) + m.searchRefOfValue(funcRefs, func) + + local calls = {} + if #funcRefs.results == 0 then + return calls + end + for _, res in ipairs(funcRefs.results) do + local call = res.parent + if call.type == 'call' then + calls[#calls+1] = call + end + end + return calls +end + +function m.getCallerCrossFiles(status, main) + if status.interface.link then + return status.interface.link(main.uri) + end + return {} +end + +function m.searchRefsAsFunctionReturn(status, obj, mode) + if mode == 'def' then + return + end + if m.checkReturnMark(status, obj, true) then + return + end + status.results[#status.results+1] = obj + -- 搜索所在函数 + local currentFunc = m.getParentFunction(obj) + local rtn = obj.parent + if rtn.type ~= 'return' then + return + end + -- 看看他是第几个返回值 + local index + for i = 1, #rtn do + if obj == rtn[i] then + index = i + break + end + end + if not index then + return + end + local calls + if currentFunc.type == 'main' then + calls = m.getCallerCrossFiles(status, currentFunc) + else + calls = m.getCallerInSameFile(status, currentFunc) + end + -- 搜索调用者的返回值 + if #calls == 0 then + return + end + local selects = {} + for i = 1, #calls do + local parent = calls[i].parent + if parent.type == 'select' and parent.index == index then + selects[#selects+1] = parent.parent + end + local extParent = calls[i].extParent + if extParent then + for j = 1, #extParent do + local ext = extParent[j] + if ext.type == 'select' and ext.index == index then + selects[#selects+1] = ext.parent + end + end + end + end + -- 搜索调用者的引用 + for i = 1, #selects do + m.searchRefs(status, selects[i], 'ref') + end +end + +function m.searchRefsAsFunctionSet(status, obj, mode) + local parent = obj.parent + if not parent then + return + end + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' + or parent.type == 'setfield' + or parent.type == 'setmethod' + or parent.type == 'tablefield' then + m.searchRefs(status, parent, mode) + elseif parent.type == 'setindex' + or parent.type == 'tableindex' then + if parent.index == obj then + m.searchRefs(status, parent, mode) + end + end +end + +function m.searchRefsAsFunction(status, obj, mode) + if obj.type ~= 'function' + and obj.type ~= 'table' then + return + end + m.searchRefsAsFunctionSet(status, obj, mode) + -- 检查自己作为返回函数时的引用 + m.searchRefsAsFunctionReturn(status, obj, mode) +end + +function m.cleanResults(results) + local mark = {} + for i = #results, 1, -1 do + local res = results[i] + if res.tag == 'self' + or mark[res] then + results[i] = results[#results] + results[#results] = nil + else + mark[res] = true + end + end +end + +--function m.getRefCache(status, obj, mode) +-- local cache = status.interface.cache and status.interface.cache() +-- if not cache then +-- return +-- end +-- if m.isGlobal(obj) then +-- obj = m.getKeyName(obj) +-- end +-- if not cache[mode] then +-- cache[mode] = {} +-- end +-- local sourceCache = cache[mode][obj] +-- if sourceCache then +-- return sourceCache +-- end +-- sourceCache = {} +-- cache[mode][obj] = sourceCache +-- return nil, function (results) +-- for i = 1, #results do +-- sourceCache[i] = results[i] +-- end +-- end +--end + +function m.getRefCache(status, obj, mode) + local cache, globalCache + if status.depth == 1 + and status.deep then + globalCache = status.interface.cache and status.interface.cache() or {} + end + cache = status.cache.refCache or {} + status.cache.refCache = cache + if m.isGlobal(obj) then + obj = m.getKeyName(obj) + end + if not cache[mode] then + cache[mode] = {} + end + if globalCache and not globalCache[mode] then + globalCache[mode] = {} + end + local sourceCache = globalCache and globalCache[mode][obj] or cache[mode][obj] + if sourceCache then + return sourceCache + end + sourceCache = {} + cache[mode][obj] = sourceCache + if globalCache then + globalCache[mode][obj] = sourceCache + end + return nil, function (results) + for i = 1, #results do + sourceCache[i] = results[i] + end + end +end + +function m.searchRefs(status, obj, mode) + local cache, makeCache = m.getRefCache(status, obj, mode) + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + + -- 检查单步引用 + local res = m.getStepRef(status, obj, mode) + if res then + for i = 1, #res do + status.results[#status.results+1] = res[i] + end + end + -- 检查simple + if status.depth <= 100 then + local simple = m.getSimple(obj) + if simple then + m.searchSameFields(status, simple, mode) + end + else + if m.debugMode then + error('status.depth overflow') + elseif DEVELOP then + --log.warn(debug.traceback('status.depth overflow')) + logWarn('status.depth overflow') + end + end + + m.cleanResults(status.results) + + if makeCache then + makeCache(status.results) + end +end + +function m.searchRefOfValue(status, obj) + local var = obj.parent + if var.type == 'local' + or var.type == 'set' then + return m.searchRefs(status, var, 'ref') + end +end + +function m.allocInfer(o) + if type(o.type) == 'table' then + local infers = {} + for i = 1, #o.type do + infers[i] = { + type = o.type[i], + value = o.value, + source = o.source, + } + end + return infers + else + return { + [1] = o, + } + end +end + +function m.mergeTypes(types) + local results = {} + local mark = {} + local hasAny + -- 这里把 any 去掉 + for i = 1, #types do + local tp = types[i] + if tp == 'any' then + hasAny = true + end + if not mark[tp] and tp ~= 'any' then + mark[tp] = true + results[#results+1] = tp + end + end + if #results == 0 then + return 'any' + end + -- 只有显性的 nil 与 any 时,取 any + if #results == 1 then + if results[1] == 'nil' and hasAny then + return 'any' + else + return results[1] + end + end + -- 同时包含 number 与 integer 时,去掉 integer + if mark['number'] and mark['integer'] then + for i = 1, #results do + if results[i] == 'integer' then + tableRemove(results, i) + break + end + end + end + tableSort(results, function (a, b) + local sa = TypeSort[a] or 100 + local sb = TypeSort[b] or 100 + return sa < sb + end) + return tableConcat(results, '|') +end + +function m.viewInferType(infers) + if not infers then + return 'any' + end + local mark = {} + local types = {} + local hasDoc + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' then + if infer.type ~= 'any' then + hasDoc = true + break + end + end + end + if hasDoc then + for i = 1, #infers do + local infer = infers[i] + local src = infer.source + if src.type == 'doc.class' + or src.type == 'doc.class.name' + or src.type == 'doc.type.name' + or src.type == 'doc.type.array' + or src.type == 'doc.type.generic' + or src.type == 'doc.type.enum' + or src.type == 'doc.resume' then + local tp = infer.type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + end + end + else + for i = 1, #infers do + local tp = infers[i].type or 'any' + if not mark[tp] then + types[#types+1] = tp + end + mark[tp] = true + end + end + return m.mergeTypes(types) +end + +function m.checkTrue(status, source) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + -- 当前认为的结果 + local current + for _, infer in ipairs(newStatus.results) do + -- 新的结果 + local new + if infer.type == 'nil' then + new = false + elseif infer.type == 'boolean' then + if infer.value == true then + new = true + elseif infer.value == false then + new = false + end + end + if new ~= nil then + if current == nil then + current = new + else + -- 如果2个结果完全相反,则返回 nil 表示不确定 + if new ~= current then + return nil + end + end + end + end + return current +end + +--- 获取特定类型的字面量值 +function m.getInferLiteral(status, source, type) + local newStatus = m.status(status) + m.searchInfer(newStatus, source) + for _, infer in ipairs(newStatus.results) do + if infer.value ~= nil then + if type == nil or infer.type == type then + return infer.value + end + end + end + return nil +end + +--- 是否包含某种类型 +function m.hasType(status, source, type) + m.searchInfer(status, source) + for _, infer in ipairs(status.results) do + if infer.type == type then + return true + end + end + return false +end + +function m.isSameValue(status, a, b) + local statusA = m.status(status) + m.searchInfer(statusA, a) + local statusB = m.status(status) + m.searchInfer(statusB, b) + local infers = {} + for _, infer in ipairs(statusA.results) do + local literal = infer.value + if literal then + infers[literal] = false + end + end + for _, infer in ipairs(statusB.results) do + local literal = infer.value + if literal then + if infers[literal] == nil then + return false + end + infers[literal] = true + end + end + for k, v in pairs(infers) do + if v == false then + return false + end + end + return true +end + +function m.inferCheckLiteralTableWithDocVararg(status, source) + if #source ~= 1 then + return + end + local vararg = source[1] + if vararg.type ~= 'varargs' then + return + end + local results = m.getVarargDocType(status, source) + status.results[#status.results+1] = { + type = m.viewInferType(results) .. '[]', + source = source, + } + return true +end + +function m.inferCheckLiteral(status, source) + if source.type == 'string' then + status.results = m.allocInfer { + type = 'string', + value = source[1], + source = source, + } + return true + elseif source.type == 'nil' then + status.results = m.allocInfer { + type = 'nil', + value = NIL, + source = source, + } + return true + elseif source.type == 'boolean' then + status.results = m.allocInfer { + type = 'boolean', + value = source[1], + source = source, + } + return true + elseif source.type == 'number' then + if mathType(source[1]) == 'integer' then + status.results = m.allocInfer { + type = 'integer', + value = source[1], + source = source, + } + return true + else + status.results = m.allocInfer { + type = 'number', + value = source[1], + source = source, + } + return true + end + elseif source.type == 'integer' then + status.results = m.allocInfer { + type = 'integer', + source = source, + } + return true + elseif source.type == 'table' then + if m.inferCheckLiteralTableWithDocVararg(status, source) then + return true + end + status.results = m.allocInfer { + type = 'table', + source = source, + } + return true + elseif source.type == 'function' then + status.results = m.allocInfer { + type = 'function', + source = source, + } + return true + elseif source.type == '...' then + status.results = m.allocInfer { + type = '...', + source = source, + } + return true + end +end + +local function getDocAliasExtends(status, name) + if not status.interface.docType then + return nil + end + for _, doc in ipairs(status.interface.docType(name)) do + if doc.type == 'doc.alias.name' then + return m.viewInferType(m.getDocTypeNames(status, doc.parent.extends)) + end + end + return nil +end + +local function getDocTypeUnitName(status, unit, genericCallback) + local typeName + if unit.type == 'doc.type.name' then + typeName = getDocAliasExtends(status, unit[1]) or unit[1] + elseif unit.type == 'doc.type.function' then + typeName = 'function' + elseif unit.type == 'doc.type.array' then + typeName = getDocTypeUnitName(status, unit.node, genericCallback) .. '[]' + elseif unit.type == 'doc.type.generic' then + typeName = ('%s<%s, %s>'):format( + getDocTypeUnitName(status, unit.node, genericCallback), + m.viewInferType(m.getDocTypeNames(status, unit.key, genericCallback)), + m.viewInferType(m.getDocTypeNames(status, unit.value, genericCallback)) + ) + end + if unit.typeGeneric then + if genericCallback then + typeName = genericCallback(typeName, unit) + or ('<%s>'):format(typeName) + else + typeName = ('<%s>'):format(typeName) + end + end + return typeName +end + +function m.getDocTypeNames(status, doc, genericCallback) + local results = {} + if not doc then + return results + end + for _, unit in ipairs(doc.types) do + local typeName = getDocTypeUnitName(status, unit, genericCallback) + results[#results+1] = { + type = typeName, + source = unit, + } + end + for _, enum in ipairs(doc.enums) do + results[#results+1] = { + type = enum[1], + source = enum, + } + end + for _, resume in ipairs(doc.resumes) do + if not resume.additional then + results[#results+1] = { + type = resume[1], + source = resume, + } + end + end + return results +end + +function m.inferCheckDoc(status, source) + if source.type == 'doc.class.name' then + status.results[#status.results+1] = { + type = source[1], + source = source, + } + return true + end + if source.type == 'doc.class' then + status.results[#status.results+1] = { + type = source.class[1], + source = source, + } + return true + end + if source.type == 'doc.type' then + local results = m.getDocTypeNames(status, source) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + if source.type == 'doc.field' then + local results = m.getDocTypeNames(status, source.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end +end + +function m.getVarargDocType(status, source) + local func = m.getParentFunction(source) + if not func then + return + end + if not func.args then + return + end + for _, arg in ipairs(func.args) do + if arg.type == '...' then + if arg.bindDocs then + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.vararg' then + return m.getDocTypeNames(status, doc.vararg) + end + end + end + end + end +end + +function m.inferCheckUpDocOfVararg(status, source) + if not source.vararg then + return + end + local results = m.getVarargDocType(status, source) + if not results then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +function m.inferCheckUpDoc(status, source) + if m.inferCheckUpDocOfVararg(status, source) then + return true + end + local parent = source.parent + if parent then + if parent.type == 'local' + or parent.type == 'setlocal' + or parent.type == 'setglobal' then + source = parent + end + if parent.type == 'setfield' + or parent.type == 'tablefield' then + if parent.field == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setmethod' then + if parent.method == source + or parent.value == source then + source = parent + end + end + if parent.type == 'setindex' + or parent.type == 'tableindex' then + if parent.index == source + or parent.value == source then + source = parent + end + end + end + local binds = source.bindDocs + if not binds then + return + end + status.results = {} + for _, doc in ipairs(binds) do + if doc.type == 'doc.class' then + status.results[#status.results+1] = { + type = doc.class[1], + source = doc, + } + -- ---@class Class + -- local x = { field = 1 } + -- 这种情况下,将字面量表接受为Class的定义 + if source.value and source.value.type == 'table' then + status.results[#status.results+1] = { + type = source.value.type, + source = source.value, + } + end + return true + elseif doc.type == 'doc.type' then + local results = m.getDocTypeNames(status, doc) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + elseif doc.type == 'doc.param' then + -- function (x) 的情况 + if source.type == 'local' + and m.getName(source) == doc.param[1] then + if source.parent.type == 'funcargs' + or source.parent.type == 'in' + or source.parent.type == 'loop' then + local results = m.getDocTypeNames(status, doc.extends) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true + end + end + end + end +end + +function m.inferCheckFieldDoc(status, source) + -- 检查 string[] 的情况 + if source.type == 'getindex' then + local node = source.node + if not node then + return + end + local newStatus = m.status(status) + m.searchInfer(newStatus, node) + local ok + for _, infer in ipairs(newStatus.results) do + local src = infer.source + if src.type == 'doc.type.array' then + ok = true + status.results[#status.results+1] = { + type = infer.type:gsub('%[%]$', ''), + source = src.node, + } + end + end + return ok + end +end + +function m.inferCheckUnary(status, source) + if source.type ~= 'unary' then + return + end + local op = source.op + if op.type == 'not' then + local checkTrue = m.checkTrue(status, source[1]) + local value = nil + if checkTrue == true then + value = false + elseif checkTrue == false then + value = true + end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '#' then + status.results = m.allocInfer { + type = 'integer', + source = source, + } + return true + elseif op.type == '~' then + local l = m.getInferLiteral(status, source[1], 'integer') + status.results = m.allocInfer { + type = 'integer', + value = l and ~l or nil, + source = source, + } + return true + elseif op.type == '-' then + local v = m.getInferLiteral(status, source[1], 'integer') + if v then + status.results = m.allocInfer { + type = 'integer', + value = - v, + source = source, + } + return true + end + v = m.getInferLiteral(status, source[1], 'number') + status.results = m.allocInfer { + type = 'number', + value = v and -v or nil, + source = source, + } + return true + end +end + +local function mathCheck(status, a, b) + local v1 = m.getInferLiteral(status, a, 'integer') + or m.getInferLiteral(status, a, 'number') + local v2 = m.getInferLiteral(status, b, 'integer') + or m.getInferLiteral(status, a, 'number') + local int = m.hasType(status, a, 'integer') + and m.hasType(status, b, 'integer') + and not m.hasType(status, a, 'number') + and not m.hasType(status, b, 'number') + return int and 'integer' or 'number', v1, v2 +end + +function m.inferCheckBinary(status, source) + if source.type ~= 'binary' then + return + end + local op = source.op + if op.type == 'and' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[2]) + return true + elseif isTrue == false then + m.searchInfer(status, source[1]) + return true + else + m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == 'or' then + local isTrue = m.checkTrue(status, source[1]) + if isTrue == true then + m.searchInfer(status, source[1]) + return true + elseif isTrue == false then + m.searchInfer(status, source[2]) + return true + else + m.searchInfer(status, source[1]) + m.searchInfer(status, source[2]) + return true + end + elseif op.type == '==' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = true + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '~=' then + local value = m.isSameValue(status, source[1], source[2]) + if value ~= nil then + status.results = m.allocInfer { + type = 'boolean', + value = not value, + source = source, + } + return true + end + --local isSame = m.isSameDef(status, source[1], source[2]) + --if isSame == true then + -- value = false + --else + -- value = nil + --end + status.results = m.allocInfer { + type = 'boolean', + value = value, + source = source, + } + return true + elseif op.type == '<=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 <= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '>=' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 >= v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 < v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'boolean', + value = v, + source = source, + } + return true + elseif op.type == '|' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 | v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '~' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 ~ v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '&' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 & v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '<<' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 << v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '>>' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + local v2 = m.getInferLiteral(status, source[2], 'integer') + local v + if v1 and v2 then + v = v1 >> v2 + end + status.results = m.allocInfer { + type = 'integer', + value = v, + source = source, + } + return true + elseif op.type == '..' then + local v1 = m.getInferLiteral(status, source[1], 'string') + local v2 = m.getInferLiteral(status, source[2], 'string') + local v + if v1 and v2 then + v = v1 .. v2 + end + status.results = m.allocInfer { + type = 'string', + value = v, + source = source, + } + return true + elseif op.type == '^' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 ^ v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + } + return true + elseif op.type == '/' then + local v1 = m.getInferLiteral(status, source[1], 'integer') + or m.getInferLiteral(status, source[1], 'number') + local v2 = m.getInferLiteral(status, source[2], 'integer') + or m.getInferLiteral(status, source[2], 'number') + local v + if v1 and v2 then + v = v1 > v2 + end + status.results = m.allocInfer { + type = 'number', + value = v, + source = source, + } + return true + -- 其他数学运算根据2侧的值决定,当2侧的值均为整数时返回整数 + elseif op.type == '+' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 + v2) or nil, + source = source, + } + return true + elseif op.type == '-' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer{ + type = int, + value = (v1 and v2) and (v1 - v2) or nil, + source = source, + } + return true + elseif op.type == '*' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 * v2) or nil, + source = source, + } + return true + elseif op.type == '%' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 % v2) or nil, + source = source, + } + return true + elseif op.type == '//' then + local int, v1, v2 = mathCheck(status, source[1], source[2]) + status.results = m.allocInfer { + type = int, + value = (v1 and v2) and (v1 // v2) or nil, + source = source, + } + return true + end +end + +function m.inferByDef(status, obj) + if not status.cache.inferedDef then + status.cache.inferedDef = {} + end + if status.cache.inferedDef[obj] then + return + end + status.cache.inferedDef[obj] = true + local mark = {} + local newStatus = m.status(status, status.interface) + m.searchRefs(newStatus, obj, 'def') + for _, src in ipairs(newStatus.results) do + local inferStatus = m.status(newStatus) + m.searchInfer(inferStatus, src) + for _, infer in ipairs(inferStatus.results) do + if not mark[infer.source] then + mark[infer.source] = true + status.results[#status.results+1] = infer + end + end + end +end + +local function inferBySetOfLocal(status, source) + if status.cache[source] then + return + end + status.cache[source] = true + local newStatus = m.status(status) + if source.value then + m.searchInfer(newStatus, source.value) + end + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + break + end + m.searchInfer(newStatus, ref) + end + for _, infer in ipairs(newStatus.results) do + status.results[#status.results+1] = infer + end + end +end + +function m.inferBySet(status, source) + if #status.results ~= 0 then + return + end + if source.type == 'local' then + inferBySetOfLocal(status, source) + elseif source.type == 'setlocal' + or source.type == 'getlocal' then + inferBySetOfLocal(status, source.node) + end +end + +function m.inferByCall(status, source) + if #status.results ~= 0 then + return + end + if not source.parent then + return + end + if source.parent.type ~= 'call' then + return + end + if source.parent.node == source then + status.results[#status.results+1] = { + type = 'function', + source = source, + } + return + end +end + +function m.inferByGetTable(status, source) + if #status.results ~= 0 then + return + end + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local next = source.next + if not next then + return + end + if next.type == 'getfield' + or next.type == 'getindex' + or next.type == 'getmethod' + or next.type == 'setfield' + or next.type == 'setindex' + or next.type == 'setmethod' then + status.results[#status.results+1] = { + type = 'table', + source = source, + } + end +end + +function m.inferByUnary(status, source) + if #status.results ~= 0 then + return + end + local parent = source.parent + if not parent or parent.type ~= 'unary' then + return + end + local op = parent.op + if op.type == '#' then + status.results[#status.results+1] = { + type = 'string', + source = source + } + status.results[#status.results+1] = { + type = 'table', + source = source + } + elseif op.type == '~' then + status.results[#status.results+1] = { + type = 'integer', + source = source + } + elseif op.type == '-' then + status.results[#status.results+1] = { + type = 'number', + source = source + } + end +end + +function m.inferByBinary(status, source) + if #status.results ~= 0 then + return + end + local parent = source.parent + if not parent or parent.type ~= 'binary' then + return + end + local op = parent.op + if op.type == '<=' + or op.type == '>=' + or op.type == '<' + or op.type == '>' + or op.type == '^' + or op.type == '/' + or op.type == '+' + or op.type == '-' + or op.type == '*' + or op.type == '%' then + status.results[#status.results+1] = { + type = 'number', + source = source, + } + elseif op.type == '|' + or op.type == '~' + or op.type == '&' + or op.type == '<<' + or op.type == '>>' + -- 整数的可能性比较高 + or op.type == '//' then + status.results[#status.results+1] = { + type = 'integer', + source = source, + } + elseif op.type == '..' then + status.results[#status.results+1] = { + type = 'string', + source = source, + } + end +end + +local function mergeFunctionReturnsByDoc(status, source, index, call) + if not source or source.type ~= 'function' then + return + end + if not source.bindDocs then + return + end + local returns = {} + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + returns[#returns+1] = rtn + end + end + end + local rtn = returns[index] + if not rtn then + return + end + local results = m.getDocTypeNames(status, rtn, function (typeName, typeUnit) + if not source.args or not call.args then + return + end + local name = typeUnit[1] + local generics = typeUnit.typeGeneric[name] + if not generics then + return + end + local first = generics[1] + if not first or first == typeUnit then + return + end + local docParam = m.getParentType(first, 'doc.param') + local paramName = docParam.param[1] + for i, arg in ipairs(source.args) do + if arg[1] == paramName then + local callArg = call.args[i] + if not callArg then + return + end + return m.viewInferType(m.searchInfer(status, callArg)) + end + end + end) + if #results == 0 then + return + end + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + return true +end + +local function mergeDocTypeFunctionReturns(status, source, index) + if not source.bindDocs then + return + end + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.type' then + for _, typeUnit in ipairs(doc.types) do + if typeUnit.type == 'doc.type.function' then + local rtn = typeUnit.returns[index] + if rtn then + local results = m.getDocTypeNames(status, rtn) + for _, res in ipairs(results) do + status.results[#status.results+1] = res + end + end + end + end + end + end +end + +local function mergeFunctionReturns(status, source, index, call) + local returns = source.returns + if not returns then + return + end + for i = 1, #returns do + local rtn = returns[i] + if rtn[index] then + if rtn[index].type == 'call' then + if not m.checkReturnMark(status, rtn[index]) then + m.checkReturnMark(status, rtn[index], true) + m.inferByCallReturnAndIndex(status, rtn[index], index) + end + else + local newStatus = m.status(status) + m.searchInfer(newStatus, rtn[index]) + if #newStatus.results == 0 then + status.results[#status.results+1] = { + type = 'any', + source = rtn[index], + } + else + for _, infer in ipairs(newStatus.results) do + status.results[#status.results+1] = infer + end + end + end + end + end +end + +function m.inferByCallReturnAndIndex(status, call, index) + local node = call.node + local newStatus = m.status(nil, status.interface) + m.searchRefs(newStatus, node, 'def') + local hasDocReturn + for _, src in ipairs(newStatus.results) do + if mergeDocTypeFunctionReturns(status, src, index) then + hasDocReturn = true + elseif mergeFunctionReturnsByDoc(status, src.value, index, call) then + hasDocReturn = true + end + end + if not hasDocReturn then + for _, src in ipairs(newStatus.results) do + if src.value and src.value.type == 'function' then + if not m.checkReturnMark(status, src.value, true) then + mergeFunctionReturns(status, src.value, index, call) + end + end + end + end +end + +function m.inferByCallReturn(status, source) + if source.type == 'call' then + m.inferByCallReturnAndIndex(status, source, 1) + return + end + if source.type ~= 'select' then + if source.value and source.value.type == 'select' then + source = source.value + else + return + end + end + if not source.vararg or source.vararg.type ~= 'call' then + return + end + m.inferByCallReturnAndIndex(status, source.vararg, source.index) +end + +function m.inferByPCallReturn(status, source) + if source.type ~= 'select' then + if source.value and source.value.type == 'select' then + source = source.value + else + return + end + end + local call = source.vararg + if not call or call.type ~= 'call' then + return + end + local node = call.node + local specialName = node.special + local func, index + if specialName == 'pcall' then + func = call.args[1] + index = source.index - 1 + elseif specialName == 'xpcall' then + func = call.args[1] + index = source.index - 2 + else + return + end + local newStatus = m.status(nil, status.interface) + m.searchRefs(newStatus, func, 'def') + for _, src in ipairs(newStatus.results) do + if src.value and src.value.type == 'function' then + mergeFunctionReturns(status, src.value, index) + end + end +end + +function m.cleanInfers(infers) + local mark = {} + for i = #infers, 1, -1 do + local infer = infers[i] + local key = ('%s|%p'):format(infer.type, infer.source) + if mark[key] then + infers[i] = infers[#infers] + infers[#infers] = nil + else + mark[key] = true + end + end +end + +function m.searchInfer(status, obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return + end + end + while true do + local value = m.getObjectValue(obj) + if not value or value == obj then + break + end + obj = value + end + + local cache, makeCache = m.getRefCache(status, obj, 'infer') + if cache then + for i = 1, #cache do + status.results[#status.results+1] = cache[i] + end + return + end + + if DEVELOP then + status.cache.clock = status.cache.clock or osClock() + end + + if not status.cache.lockInfer then + status.cache.lockInfer = {} + end + if status.cache.lockInfer[obj] then + return + end + status.cache.lockInfer[obj] = true + + local checked = m.inferCheckDoc(status, obj) + or m.inferCheckUpDoc(status, obj) + or m.inferCheckFieldDoc(status, obj) + or m.inferCheckLiteral(status, obj) + or m.inferCheckUnary(status, obj) + or m.inferCheckBinary(status, obj) + if checked then + m.cleanInfers(status.results) + if makeCache then + makeCache(status.results) + end + return + end + + if status.deep then + m.inferByDef(status, obj) + end + m.inferBySet(status, obj) + m.inferByCall(status, obj) + m.inferByGetTable(status, obj) + m.inferByUnary(status, obj) + m.inferByBinary(status, obj) + m.inferByCallReturn(status, obj) + m.inferByPCallReturn(status, obj) + m.cleanInfers(status.results) + if makeCache then + makeCache(status.results) + end +end + +--- 请求对象的引用,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 不穿透 `setmetatable` ,考虑由 +--- 业务层进行反向 def 搜索。 +function m.requestReference(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'ref') + + m.searchRefsAsFunction(status, obj, 'ref') + + if m.debugMode then + print('count:', status.cache.count) + end + + return status.results, status.cache.count +end + +--- 请求对象的定义,包括 `a.b.c` 形式 +--- 与 `return function` 形式。 +--- 穿透 `setmetatable` 。 +function m.requestDefinition(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + -- 根据 field 搜索定义 + m.searchRefs(status, obj, 'def') + + return status.results, status.cache.count +end + +--- 请求对象的域 +function m.requestFields(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + + m.searchFields(status, obj) + + return status.results, status.cache.count +end + +--- 请求对象的类型推测 +function m.requestInfer(obj, interface, deep) + local status = m.status(nil, interface) + status.deep = deep + m.searchInfer(status, obj) + + return status.results, status.cache.count +end + +return m diff --git a/script/parser/init.lua b/script/parser/init.lua new file mode 100644 index 00000000..ba40d145 --- /dev/null +++ b/script/parser/init.lua @@ -0,0 +1,12 @@ +local api = { + grammar = require 'parser.grammar', + parse = require 'parser.parse', + compile = require 'parser.compile', + split = require 'parser.split', + calcline = require 'parser.calcline', + lines = require 'parser.lines', + guide = require 'parser.guide', + luadoc = require 'parser.luadoc', +} + +return api diff --git a/script/parser/lines.lua b/script/parser/lines.lua new file mode 100644 index 00000000..ee6b4f41 --- /dev/null +++ b/script/parser/lines.lua @@ -0,0 +1,45 @@ +local m = require 'lpeglabel' + +_ENV = nil + +local function Line(start, line, range, finish) + line.start = start + line.finish = finish - 1 + line.range = range - 1 + return line +end + +local function Space(...) + local line = {...} + local sp = 0 + local tab = 0 + for i = 1, #line do + if line[i] == ' ' then + sp = sp + 1 + elseif line[i] == '\t' then + tab = tab + 1 + end + line[i] = nil + end + line.sp = sp + line.tab = tab + return line +end + +local parser = m.P{ +'Lines', +Lines = m.Ct(m.V'Line'^0 * m.V'LastLine'), +Line = m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.V'Nl' * m.Cp() / Line, +LastLine= m.Cp() * m.V'Indent' * (1 - m.V'Nl')^0 * m.Cp() * m.Cp() / Line, +Nl = m.P'\r\n' + m.S'\r\n', +Indent = m.C(m.S' \t')^0 / Space, +} + +return function (self, text) + local lines, err = parser:match(text) + if not lines then + return nil, err + end + + return lines +end diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua new file mode 100644 index 00000000..b31c4baf --- /dev/null +++ b/script/parser/luadoc.lua @@ -0,0 +1,991 @@ +local m = require 'lpeglabel' +local re = require 'parser.relabel' +local lines = require 'parser.lines' +local guide = require 'parser.guide' + +local TokenTypes, TokenStarts, TokenFinishs, TokenContents +local Ci, Offset, pushError, Ct, NextComment +local parseType +local Parser = re.compile([[ +Main <- (Token / Sp)* +Sp <- %s+ +X16 <- [a-fA-F0-9] +Word <- [a-zA-Z0-9_] +Token <- Name / String / Symbol +Name <- ({} {[a-zA-Z0-9_] [a-zA-Z0-9_.*]*} {}) + -> Name +String <- ({} StringDef {}) + -> String +StringDef <- '"' + {~(Esc / !'"' .)*~} -> 1 + ('"'?) + / "'" + {~(Esc / !"'" .)*~} -> 1 + ("'"?) + / ('[' {:eq: '='* :} '[' + {(!StringClose .)*} -> 1 + (StringClose?)) +StringClose <- ']' =eq ']' +Esc <- '\' -> '' + EChar +EChar <- 'a' -> ea + / 'b' -> eb + / 'f' -> ef + / 'n' -> en + / 'r' -> er + / 't' -> et + / 'v' -> ev + / '\' + / '"' + / "'" + / %nl + / ('z' (%nl / %s)*) -> '' + / ('x' {X16 X16}) -> Char16 + / ([0-9] [0-9]? [0-9]?) -> Char10 + / ('u{' {Word*} '}') -> CharUtf8 +Symbol <- ({} { + ':' + / '|' + / ',' + / '[]' + / '<' + / '>' + / '(' + / ')' + / '?' + / '...' + / '+' + } {}) + -> Symbol +]], { + s = m.S' \t', + ea = '\a', + eb = '\b', + ef = '\f', + en = '\n', + er = '\r', + et = '\t', + ev = '\v', + Char10 = function (char) + char = tonumber(char) + if not char or char < 0 or char > 255 then + return '' + end + return string.char(char) + end, + Char16 = function (char) + return string.char(tonumber(char, 16)) + end, + CharUtf8 = function (char) + if #char == 0 then + return '' + end + local v = tonumber(char, 16) + if not v then + return '' + end + if v >= 0 and v <= 0x10FFFF then + return utf8.char(v) + end + return '' + end, + Name = function (start, content, finish) + Ci = Ci + 1 + TokenTypes[Ci] = 'name' + TokenStarts[Ci] = start + TokenFinishs[Ci] = finish - 1 + TokenContents[Ci] = content + end, + String = function (start, content, finish) + Ci = Ci + 1 + TokenTypes[Ci] = 'string' + TokenStarts[Ci] = start + TokenFinishs[Ci] = finish - 1 + TokenContents[Ci] = content + end, + Symbol = function (start, content, finish) + Ci = Ci + 1 + TokenTypes[Ci] = 'symbol' + TokenStarts[Ci] = start + TokenFinishs[Ci] = finish - 1 + TokenContents[Ci] = content + end, +}) + +local function trim(str) + return str:match '^%s*(%S+)%s*$' +end + +local function parseTokens(text, offset) + Ct = offset + Ci = 0 + Offset = offset + TokenTypes = {} + TokenStarts = {} + TokenFinishs = {} + TokenContents = {} + Parser:match(text) + Ci = 0 +end + +local function peekToken() + return TokenTypes[Ci+1], TokenContents[Ci+1] +end + +local function nextToken() + Ci = Ci + 1 + if not TokenTypes[Ci] then + Ci = Ci - 1 + return nil + end + return TokenTypes[Ci], TokenContents[Ci] +end + +local function checkToken(tp, content, offset) + offset = offset or 0 + return TokenTypes[Ci + offset] == tp + and TokenContents[Ci + offset] == content +end + +local function getStart() + if Ci == 0 then + return Offset + end + return TokenStarts[Ci] + Offset +end + +local function getFinish() + if Ci == 0 then + return Offset + end + return TokenFinishs[Ci] + Offset +end + +local function try(callback) + local savePoint = Ci + -- rollback + local suc = callback() + if not suc then + Ci = savePoint + end + return suc +end + +local function parseName(tp, parent) + local nameTp, nameText = peekToken() + if nameTp ~= 'name' then + return nil + end + nextToken() + local class = { + type = tp, + start = getStart(), + finish = getFinish(), + parent = parent, + [1] = nameText, + } + return class +end + +local function parseClass(parent) + local result = { + type = 'doc.class', + parent = parent, + } + result.class = parseName('doc.class.name', result) + if not result.class then + pushError { + type = 'LUADOC_MISS_CLASS_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.start = getStart() + result.finish = getFinish() + if not peekToken() then + return result + end + nextToken() + if not checkToken('symbol', ':') then + pushError { + type = 'LUADOC_MISS_EXTENDS_SYMBOL', + start = result.finish + 1, + finish = getStart() - 1, + } + return result + end + result.extends = parseName('doc.extends.name', result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_CLASS_EXTENDS_NAME', + start = getFinish(), + finish = getFinish(), + } + return result + end + result.finish = getFinish() + return result +end + +local function nextSymbolOrError(symbol) + if checkToken('symbol', symbol, 1) then + nextToken() + return true + end + pushError { + type = 'LUADOC_MISS_SYMBOL', + start = getFinish(), + finish = getFinish(), + info = { + symbol = symbol, + } + } + return false +end + +local function parseTypeUnitArray(node) + if not checkToken('symbol', '[]', 1) then + return nil + end + nextToken() + local result = { + type = 'doc.type.array', + start = node.start, + finish = getFinish(), + node = node, + } + return result +end + +local function parseTypeUnitGeneric(node) + if not checkToken('symbol', '<', 1) then + return nil + end + if not nextSymbolOrError('<') then + return nil + end + local key = parseType(node) + if not key or not nextSymbolOrError(',') then + return nil + end + local value = parseType(node) + if not value then + return nil + end + nextSymbolOrError('>') + local result = { + type = 'doc.type.generic', + start = node.start, + finish = getFinish(), + node = node, + key = key, + value = value, + } + return result +end + +local function parseTypeUnitFunction() + local typeUnit = { + type = 'doc.type.function', + start = getStart(), + args = {}, + returns = {}, + } + if not nextSymbolOrError('(') then + return nil + end + while true do + if checkToken('symbol', ')', 1) then + nextToken() + break + end + local arg = { + type = 'doc.type.arg', + parent = typeUnit, + } + arg.name = parseName('doc.type.name', arg) + if not arg.name then + pushError { + type = 'LUADOC_MISS_ARG_NAME', + start = getFinish(), + finish = getFinish(), + } + break + end + if not arg.start then + arg.start = arg.name.start + end + if checkToken('symbol', '?', 1) then + nextToken() + arg.optional = true + end + arg.finish = getFinish() + if not nextSymbolOrError(':') then + break + end + arg.extends = parseType(arg) + if not arg.extends then + break + end + arg.finish = getFinish() + typeUnit.args[#typeUnit.args+1] = arg + if checkToken('symbol', ',', 1) then + nextToken() + else + nextSymbolOrError(')') + break + end + end + if checkToken('symbol', ':', 1) then + nextToken() + while true do + local rtn = parseType(typeUnit) + if not rtn then + break + end + if checkToken('symbol', '?', 1) then + nextToken() + rtn.optional = true + end + typeUnit.returns[#typeUnit.returns+1] = rtn + if checkToken('symbol', ',', 1) then + nextToken() + else + break + end + end + end + typeUnit.finish = getFinish() + return typeUnit +end + +local function parseTypeUnit(parent, content) + local result + if content == 'fun' then + result = parseTypeUnitFunction() + end + if not result then + result = { + type = 'doc.type.name', + start = getStart(), + finish = getFinish(), + [1] = content, + } + end + if not result then + return nil + end + result.parent = parent + while true do + local newResult = parseTypeUnitArray(result) + or parseTypeUnitGeneric(result) + if not newResult then + break + end + result = newResult + end + return result +end + +local function parseResume() + local result = { + type = 'doc.resume' + } + + if checkToken('symbol', '>', 1) then + nextToken() + result.default = true + end + + if checkToken('symbol', '+', 1) then + nextToken() + result.additional = true + end + + local tp = peekToken() + if tp ~= 'string' then + pushError { + type = 'LUADOC_MISS_STRING', + start = getFinish(), + finish = getFinish(), + } + return nil + end + local _, str = nextToken() + result[1] = str + result.start = getStart() + result.finish = getFinish() + return result +end + +function parseType(parent) + local result = { + type = 'doc.type', + parent = parent, + types = {}, + enums = {}, + resumes = {}, + } + result.start = getStart() + while true do + local tp, content = peekToken() + if not tp then + break + end + if tp == 'name' then + nextToken() + local typeUnit = parseTypeUnit(result, content) + if not typeUnit then + break + end + result.types[#result.types+1] = typeUnit + if not result.start then + result.start = typeUnit.start + end + elseif tp == 'string' then + nextToken() + local typeEnum = { + type = 'doc.type.enum', + start = getStart(), + finish = getFinish(), + parent = result, + [1] = content, + } + result.enums[#result.enums+1] = typeEnum + if not result.start then + result.start = typeEnum.start + end + elseif tp == 'symbol' and content == '...' then + nextToken() + local vararg = { + type = 'doc.type.name', + start = getStart(), + finish = getFinish(), + parent = result, + [1] = content, + } + result.types[#result.types+1] = vararg + if not result.start then + result.start = vararg.start + end + end + if not checkToken('symbol', '|', 1) then + break + end + nextToken() + end + result.finish = getFinish() + + while true do + local nextComm = NextComment('peek') + if nextComm and nextComm.text:sub(1, 2) == '-|' then + NextComment() + local finishPos = nextComm.text:find('#', 3) or #nextComm.text + parseTokens(nextComm.text:sub(3, finishPos), nextComm.start + 1) + local resume = parseResume() + if resume then + resume.comment = nextComm.text:match('#%s*(.+)', 3) + result.resumes[#result.resumes+1] = resume + result.finish = resume.finish + end + else + break + end + end + + if #result.types == 0 and #result.enums == 0 and #result.resumes == 0 then + pushError { + type = 'LUADOC_MISS_TYPE_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + return result +end + +local function parseAlias() + local result = { + type = 'doc.alias', + } + result.alias = parseName('doc.alias.name', result) + if not result.alias then + pushError { + type = 'LUADOC_MISS_ALIAS_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.start = getStart() + result.extends = parseType(result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_ALIAS_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.finish = getFinish() + return result +end + +local function parseParam() + local result = { + type = 'doc.param', + } + result.param = parseName('doc.param.name', result) + if not result.param then + pushError { + type = 'LUADOC_MISS_PARAM_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + if checkToken('symbol', '?', 1) then + nextToken() + result.optional = true + end + result.start = result.param.start + result.finish = getFinish() + result.extends = parseType(result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_PARAM_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return result + end + result.finish = getFinish() + return result +end + +local function parseReturn() + local result = { + type = 'doc.return', + returns = {}, + } + while true do + local docType = parseType(result) + if not docType then + break + end + if not result.start then + result.start = docType.start + end + if checkToken('symbol', '?', 1) then + nextToken() + docType.optional = true + end + docType.name = parseName('doc.return.name', docType) + result.returns[#result.returns+1] = docType + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + if #result.returns == 0 then + return nil + end + result.finish = getFinish() + return result +end + +local function parseField() + local result = { + type = 'doc.field', + } + try(function () + local tp, value = nextToken() + if tp == 'name' then + if value == 'public' + or value == 'protected' + or value == 'private' then + result.visible = value + result.start = getStart() + return true + end + end + return false + end) + result.field = parseName('doc.field.name', result) + if not result.field then + pushError { + type = 'LUADOC_MISS_FIELD_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + if not result.start then + result.start = result.field.start + end + result.extends = parseType(result) + if not result.extends then + pushError { + type = 'LUADOC_MISS_FIELD_EXTENDS', + start = getFinish(), + finish = getFinish(), + } + return nil + end + result.finish = getFinish() + return result +end + +local function parseGeneric() + local result = { + type = 'doc.generic', + generics = {}, + } + while true do + local object = { + type = 'doc.generic.object', + parent = result, + } + object.generic = parseName('doc.generic.name', object) + if not object.generic then + pushError { + type = 'LUADOC_MISS_GENERIC_NAME', + start = getFinish(), + finish = getFinish(), + } + return nil + end + object.start = object.generic.start + if not result.start then + result.start = object.start + end + if checkToken('symbol', ':', 1) then + nextToken() + object.extends = parseType(object) + end + object.finish = getFinish() + result.generics[#result.generics+1] = object + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + result.finish = getFinish() + return result +end + +local function parseVararg() + local result = { + type = 'doc.vararg', + } + result.vararg = parseType(result) + if not result.vararg then + pushError { + type = 'LUADOC_MISS_VARARG_TYPE', + start = getFinish(), + finish = getFinish(), + } + return + end + result.start = result.vararg.start + result.finish = result.vararg.finish + return result +end + +local function parseOverload() + local tp, name = peekToken() + if tp ~= 'name' or name ~= 'fun' then + pushError { + type = 'LUADOC_MISS_FUN_AFTER_OVERLOAD', + start = getFinish(), + finish = getFinish(), + } + return nil + end + nextToken() + local result = { + type = 'doc.overload', + } + result.overload = parseTypeUnitFunction() + if not result.overload then + return nil + end + result.overload.parent = result + result.start = result.overload.start + result.finish = result.overload.finish + return result +end + +local function parseDeprecated() + return { + type = 'doc.deprecated', + start = getFinish(), + finish = getFinish(), + } +end + +local function parseMeta() + return { + type = 'doc.meta', + start = getFinish(), + finish = getFinish(), + } +end + +local function parseVersion() + local result = { + type = 'doc.version', + versions = {}, + } + while true do + local tp, text = nextToken() + if not tp then + pushError { + type = 'LUADOC_MISS_VERSION', + start = getStart(), + finish = getFinish(), + } + break + end + if not result.start then + result.start = getStart() + end + local version = { + type = 'doc.version.unit', + start = getStart(), + } + if tp == 'symbol' then + if text == '>' then + version.ge = true + elseif text == '<' then + version.le = true + end + tp, text = nextToken() + end + if tp ~= 'name' then + pushError { + type = 'LUADOC_MISS_VERSION', + start = getStart(), + finish = getFinish(), + } + break + end + version.version = tonumber(text) or text + version.finish = getFinish() + result.versions[#result.versions+1] = version + if not checkToken('symbol', ',', 1) then + break + end + nextToken() + end + if #result.versions == 0 then + return nil + end + result.finish = getFinish() + return result +end + +local function convertTokens() + local tp, text = nextToken() + if not tp then + return + end + if tp ~= 'name' then + pushError { + type = 'LUADOC_MISS_CATE_NAME', + start = getStart(), + finish = getFinish(), + } + return nil + end + if text == 'class' then + return parseClass() + elseif text == 'type' then + return parseType() + elseif text == 'alias' then + return parseAlias() + elseif text == 'param' then + return parseParam() + elseif text == 'return' then + return parseReturn() + elseif text == 'field' then + return parseField() + elseif text == 'generic' then + return parseGeneric() + elseif text == 'vararg' then + return parseVararg() + elseif text == 'overload' then + return parseOverload() + elseif text == 'deprecated' then + return parseDeprecated() + elseif text == 'meta' then + return parseMeta() + elseif text == 'version' then + return parseVersion() + end +end + +local function buildLuaDoc(comment) + local text = comment.text + if text:sub(1, 1) ~= '-' then + return + end + if text:sub(2, 2) ~= '@' then + return { + type = 'doc.comment', + start = comment.start, + finish = comment.finish, + comment = comment, + } + end + local finishPos = text:find('@', 3) + local doc, lastComment + if finishPos then + doc = text:sub(3, finishPos - 1) + lastComment = text:sub(finishPos) + else + doc = text:sub(3) + end + + parseTokens(doc, comment.start + 1) + local result = convertTokens() + if result then + result.comment = lastComment + end + + return result +end + +local function isNextLine(lns, binded, doc) + if not binded then + return false + end + local lastDoc = binded[#binded] + local lastRow = guide.positionOf(lns, lastDoc.finish) + local newRow = guide.positionOf(lns, doc.start) + return newRow - lastRow == 1 +end + +local function bindGeneric(binded) + local generics = {} + for _, doc in ipairs(binded) do + if doc.type == 'doc.generic' then + for _, obj in ipairs(doc.generics) do + local name = obj.generic[1] + generics[name] = {} + end + elseif doc.type == 'doc.param' + or doc.type == 'doc.return' then + guide.eachSourceType(doc, 'doc.type.name', function (src) + local name = src[1] + if generics[name] then + generics[name][#generics[name]+1] = src + src.typeGeneric = generics + end + end) + end + end +end + +local function bindDoc(state, lns, binded) + if not binded then + return + end + local lastDoc = binded[#binded] + if not lastDoc then + return + end + local bindSources = {} + for _, doc in ipairs(binded) do + doc.bindGroup = binded + doc.bindSources = bindSources + end + bindGeneric(binded) + local row = guide.positionOf(lns, lastDoc.finish) + local start, finish = guide.lineRange(lns, row + 1) + if start >= finish then + -- 空行 + return + end + guide.eachSourceBetween(state.ast, start, finish, function (src) + if src.start and src.start < start then + return + end + if src.type == 'local' + or src.type == 'setlocal' + or src.type == 'setglobal' + or src.type == 'setfield' + or src.type == 'setmethod' + or src.type == 'setindex' + or src.type == 'tablefield' + or src.type == 'tableindex' + or src.type == 'function' + or src.type == '...' then + src.bindDocs = binded + bindSources[#bindSources+1] = src + end + end) +end + +local function bindDocs(state) + local lns = lines(nil, state.lua) + local binded + for _, doc in ipairs(state.ast.docs) do + if not isNextLine(lns, binded, doc) then + bindDoc(state, lns, binded) + binded = {} + state.ast.docs.groups[#state.ast.docs.groups+1] = binded + end + binded[#binded+1] = doc + end + bindDoc(state, lns, binded) +end + +return function (_, state) + local ast = state.ast + local comments = state.comms + table.sort(comments, function (a, b) + return a.start < b.start + end) + ast.docs = { + type = 'doc', + parent = ast, + groups = {}, + } + + pushError = state.pushError + + local ci = 1 + NextComment = function (peek) + local comment = comments[ci] + if not peek then + ci = ci + 1 + end + return comment + end + + while true do + local comment = NextComment() + if not comment then + break + end + local doc = buildLuaDoc(comment) + if doc then + ast.docs[#ast.docs+1] = doc + doc.parent = ast.docs + if ast.start > doc.start then + ast.start = doc.start + end + if ast.finish < doc.finish then + ast.finish = doc.finish + end + end + end + + if #ast.docs == 0 then + return + end + + bindDocs(state) +end diff --git a/script/parser/parse.lua b/script/parser/parse.lua new file mode 100644 index 00000000..f813cc59 --- /dev/null +++ b/script/parser/parse.lua @@ -0,0 +1,49 @@ +local ast = require 'parser.ast' + +return function (self, lua, mode, version) + local errs = {} + local diags = {} + local comms = {} + local state = { + version = version, + lua = lua, + root = {}, + errs = errs, + diags = diags, + comms = comms, + pushError = function (err) + if err.finish < err.start then + err.finish = err.start + end + local last = errs[#errs] + if last then + if last.start <= err.start and last.finish >= err.finish then + return + end + end + err.level = err.level or 'error' + errs[#errs+1] = err + return err + end, + pushDiag = function (code, info) + if not diags[code] then + diags[code] = {} + end + diags[code][#diags[code]+1] = info + end, + pushComment = function (comment) + comms[#comms+1] = comment + end + } + ast.init(state) + local suc, res, err = xpcall(self.grammar, debug.traceback, self, lua, mode) + ast.close() + if not suc then + return nil, res + end + if not res then + state.pushError(err) + end + state.ast = res + return state +end diff --git a/script/parser/relabel.lua b/script/parser/relabel.lua new file mode 100644 index 00000000..ac902403 --- /dev/null +++ b/script/parser/relabel.lua @@ -0,0 +1,361 @@ +-- $Id: re.lua,v 1.44 2013/03/26 20:11:40 roberto Exp $ + +-- imported functions and modules +local tonumber, type, print, error = tonumber, type, print, error +local pcall = pcall +local setmetatable = setmetatable +local tinsert, concat = table.insert, table.concat +local rep = string.rep +local m = require"lpeglabel" + +-- 'm' will be used to parse expressions, and 'mm' will be used to +-- create expressions; that is, 're' runs on 'm', creating patterns +-- on 'mm' +local mm = m + +-- pattern's metatable +local mt = getmetatable(mm.P(0)) + + + +-- No more global accesses after this point +_ENV = nil + + +local any = m.P(1) +local dummy = mm.P(false) + + +local errinfo = { + NoPatt = "no pattern found", + ExtraChars = "unexpected characters after the pattern", + + ExpPatt1 = "expected a pattern after '/'", + + ExpPatt2 = "expected a pattern after '&'", + ExpPatt3 = "expected a pattern after '!'", + + ExpPatt4 = "expected a pattern after '('", + ExpPatt5 = "expected a pattern after ':'", + ExpPatt6 = "expected a pattern after '{~'", + ExpPatt7 = "expected a pattern after '{|'", + + ExpPatt8 = "expected a pattern after '<-'", + + ExpPattOrClose = "expected a pattern or closing '}' after '{'", + + ExpNumName = "expected a number, '+', '-' or a name (no space) after '^'", + ExpCap = "expected a string, number, '{}' or name after '->'", + + ExpName1 = "expected the name of a rule after '=>'", + ExpName2 = "expected the name of a rule after '=' (no space)", + ExpName3 = "expected the name of a rule after '<' (no space)", + + ExpLab1 = "expected a label after '{'", + + ExpNameOrLab = "expected a name or label after '%' (no space)", + + ExpItem = "expected at least one item after '[' or '^'", + + MisClose1 = "missing closing ')'", + MisClose2 = "missing closing ':}'", + MisClose3 = "missing closing '~}'", + MisClose4 = "missing closing '|}'", + MisClose5 = "missing closing '}'", -- for the captures + + MisClose6 = "missing closing '>'", + MisClose7 = "missing closing '}'", -- for the labels + + MisClose8 = "missing closing ']'", + + MisTerm1 = "missing terminating single quote", + MisTerm2 = "missing terminating double quote", +} + +local function expect (pattern, label) + return pattern + m.T(label) +end + + +-- Pre-defined names +local Predef = { nl = m.P"\n" } + + +local mem +local fmem +local gmem + + +local function updatelocale () + mm.locale(Predef) + Predef.a = Predef.alpha + Predef.c = Predef.cntrl + Predef.d = Predef.digit + Predef.g = Predef.graph + Predef.l = Predef.lower + Predef.p = Predef.punct + Predef.s = Predef.space + Predef.u = Predef.upper + Predef.w = Predef.alnum + Predef.x = Predef.xdigit + Predef.A = any - Predef.a + Predef.C = any - Predef.c + Predef.D = any - Predef.d + Predef.G = any - Predef.g + Predef.L = any - Predef.l + Predef.P = any - Predef.p + Predef.S = any - Predef.s + Predef.U = any - Predef.u + Predef.W = any - Predef.w + Predef.X = any - Predef.x + mem = {} -- restart memoization + fmem = {} + gmem = {} + local mt = {__mode = "v"} + setmetatable(mem, mt) + setmetatable(fmem, mt) + setmetatable(gmem, mt) +end + + +updatelocale() + + + +local I = m.P(function (s,i) print(i, s:sub(1, i-1)); return i end) + + +local function getdef (id, defs) + local c = defs and defs[id] + if not c then + error("undefined name: " .. id) + end + return c +end + + +local function mult (p, n) + local np = mm.P(true) + while n >= 1 do + if n%2 >= 1 then np = np * p end + p = p * p + n = n/2 + end + return np +end + +local function equalcap (s, i, c) + if type(c) ~= "string" then return nil end + local e = #c + i + if s:sub(i, e - 1) == c then return e else return nil end +end + + +local S = (Predef.space + "--" * (any - Predef.nl)^0)^0 + +local name = m.C(m.R("AZ", "az", "__") * m.R("AZ", "az", "__", "09")^0) + +local arrow = S * "<-" + +-- a defined name only have meaning in a given environment +local Def = name * m.Carg(1) + +local num = m.C(m.R"09"^1) * S / tonumber + +local String = "'" * m.C((any - "'" - m.P"\n")^0) * expect("'", "MisTerm1") + + '"' * m.C((any - '"' - m.P"\n")^0) * expect('"', "MisTerm2") + + +local defined = "%" * Def / function (c,Defs) + local cat = Defs and Defs[c] or Predef[c] + if not cat then + error("name '" .. c .. "' undefined") + end + return cat +end + +local Range = m.Cs(any * (m.P"-"/"") * (any - "]")) / mm.R + +local item = defined + Range + m.C(any - m.P"\n") + +local Class = + "[" + * (m.C(m.P"^"^-1)) -- optional complement symbol + * m.Cf(expect(item, "ExpItem") * (item - "]")^0, mt.__add) + / function (c, p) return c == "^" and any - p or p end + * expect("]", "MisClose8") + +local function adddef (t, k, exp) + if t[k] then + -- TODO 改了一下这里的代码,重复定义不会抛错 + --error("'"..k.."' already defined as a rule") + else + t[k] = exp + end + return t +end + +local function firstdef (n, r) return adddef({n}, n, r) end + + +local function NT (n, b) + if not b then + error("rule '"..n.."' used outside a grammar") + else return mm.V(n) + end +end + + +local exp = m.P{ "Exp", + Exp = S * ( m.V"Grammar" + + m.Cf(m.V"Seq" * (S * "/" * expect(S * m.V"Seq", "ExpPatt1"))^0, mt.__add) ); + Seq = m.Cf(m.Cc(m.P"") * m.V"Prefix" * (S * m.V"Prefix")^0, mt.__mul); + Prefix = "&" * expect(S * m.V"Prefix", "ExpPatt2") / mt.__len + + "!" * expect(S * m.V"Prefix", "ExpPatt3") / mt.__unm + + m.V"Suffix"; + Suffix = m.Cf(m.V"Primary" * + ( S * ( m.P"+" * m.Cc(1, mt.__pow) + + m.P"*" * m.Cc(0, mt.__pow) + + m.P"?" * m.Cc(-1, mt.__pow) + + "^" * expect( m.Cg(num * m.Cc(mult)) + + m.Cg(m.C(m.S"+-" * m.R"09"^1) * m.Cc(mt.__pow) + + name * m.Cc"lab" + ), + "ExpNumName") + + "->" * expect(S * ( m.Cg((String + num) * m.Cc(mt.__div)) + + m.P"{}" * m.Cc(nil, m.Ct) + + m.Cg(Def / getdef * m.Cc(mt.__div)) + ), + "ExpCap") + + "=>" * expect(S * m.Cg(Def / getdef * m.Cc(m.Cmt)), + "ExpName1") + ) + )^0, function (a,b,f) if f == "lab" then return a + mm.T(b) else return f(a,b) end end ); + Primary = "(" * expect(m.V"Exp", "ExpPatt4") * expect(S * ")", "MisClose1") + + String / mm.P + + Class + + defined + + "%" * expect(m.P"{", "ExpNameOrLab") + * expect(S * m.V"Label", "ExpLab1") + * expect(S * "}", "MisClose7") / mm.T + + "{:" * (name * ":" + m.Cc(nil)) * expect(m.V"Exp", "ExpPatt5") + * expect(S * ":}", "MisClose2") + / function (n, p) return mm.Cg(p, n) end + + "=" * expect(name, "ExpName2") + / function (n) return mm.Cmt(mm.Cb(n), equalcap) end + + m.P"{}" / mm.Cp + + "{~" * expect(m.V"Exp", "ExpPatt6") + * expect(S * "~}", "MisClose3") / mm.Cs + + "{|" * expect(m.V"Exp", "ExpPatt7") + * expect(S * "|}", "MisClose4") / mm.Ct + + "{" * expect(m.V"Exp", "ExpPattOrClose") + * expect(S * "}", "MisClose5") / mm.C + + m.P"." * m.Cc(any) + + (name * -arrow + "<" * expect(name, "ExpName3") + * expect(">", "MisClose6")) * m.Cb("G") / NT; + Label = num + name; + Definition = name * arrow * expect(m.V"Exp", "ExpPatt8"); + Grammar = m.Cg(m.Cc(true), "G") + * m.Cf(m.V"Definition" / firstdef * (S * m.Cg(m.V"Definition"))^0, + adddef) / mm.P; +} + +local pattern = S * m.Cg(m.Cc(false), "G") * expect(exp, "NoPatt") / mm.P + * S * expect(-any, "ExtraChars") + +local function lineno (s, i) + if i == 1 then return 1, 1 end + local adjustment = 0 + -- report the current line if at end of line, not the next + if s:sub(i,i) == '\n' then + i = i-1 + adjustment = 1 + end + local rest, num = s:sub(1,i):gsub("[^\n]*\n", "") + local r = #rest + return 1 + num, (r ~= 0 and r or 1) + adjustment +end + +local function calcline (s, i) + if i == 1 then return 1, 1 end + local rest, line = s:sub(1,i):gsub("[^\n]*\n", "") + local col = #rest + return 1 + line, col ~= 0 and col or 1 +end + + +local function splitlines(str) + local t = {} + local function helper(line) tinsert(t, line) return "" end + helper((str:gsub("(.-)\r?\n", helper))) + return t +end + +local function compile (p, defs) + if mm.type(p) == "pattern" then return p end -- already compiled + p = p .. " " -- for better reporting of column numbers in errors when at EOF + local ok, cp, label, poserr = pcall(function() return pattern:match(p, 1, defs) end) + if not ok and cp then + if type(cp) == "string" then + cp = cp:gsub("^[^:]+:[^:]+: ", "") + end + error(cp, 3) + end + if not cp then + local lines = splitlines(p) + local line, col = lineno(p, poserr) + local err = {} + tinsert(err, "L" .. line .. ":C" .. col .. ": " .. errinfo[label]) + tinsert(err, lines[line]) + tinsert(err, rep(" ", col-1) .. "^") + error("syntax error(s) in pattern\n" .. concat(err, "\n"), 3) + end + return cp +end + +local function match (s, p, i) + local cp = mem[p] + if not cp then + cp = compile(p) + mem[p] = cp + end + return cp:match(s, i or 1) +end + +local function find (s, p, i) + local cp = fmem[p] + if not cp then + cp = compile(p) / 0 + cp = mm.P{ mm.Cp() * cp * mm.Cp() + 1 * mm.V(1) } + fmem[p] = cp + end + local i, e = cp:match(s, i or 1) + if i then return i, e - 1 + else return i + end +end + +local function gsub (s, p, rep) + local g = gmem[p] or {} -- ensure gmem[p] is not collected while here + gmem[p] = g + local cp = g[rep] + if not cp then + cp = compile(p) + cp = mm.Cs((cp / rep + 1)^0) + g[rep] = cp + end + return cp:match(s) +end + + +-- exported names +local re = { + compile = compile, + match = match, + find = find, + gsub = gsub, + updatelocale = updatelocale, + calcline = calcline +} + +return re diff --git a/script/parser/split.lua b/script/parser/split.lua new file mode 100644 index 00000000..6ce4a4e7 --- /dev/null +++ b/script/parser/split.lua @@ -0,0 +1,9 @@ +local m = require 'lpeglabel' + +local NL = m.P'\r\n' + m.S'\r\n' +local LINE = m.C(1 - NL) + +return function (str) + local MATCH = m.Ct((LINE * NL)^0 * LINE) + return MATCH:match(str) +end diff --git a/script/proto/define.lua b/script/proto/define.lua new file mode 100644 index 00000000..966a5161 --- /dev/null +++ b/script/proto/define.lua @@ -0,0 +1,287 @@ +local guide = require 'parser.guide' +local util = require 'utility' + +local m = {} + +--- 获取 position 对应的光标位置 +---@param lines table +---@param text string +---@param position position +---@return integer +function m.offset(lines, text, position) + local row = position.line + 1 + local start = guide.lineRange(lines, row) + if start <= 0 or start > #text then + return #text + 1 + end + local offset = utf8.offset(text, position.character + 1, start) + return offset - 1 +end + +--- 获取 position 对应的光标位置(根据附近的单词) +---@param lines table +---@param text string +---@param position position +---@return integer +function m.offsetOfWord(lines, text, position) + local row = position.line + 1 + local start = guide.lineRange(lines, row) + if start <= 0 or start > #text then + return #text + 1 + end + local offset = utf8.offset(text, position.character + 1, start) + if offset > #text + or text:sub(offset-1, offset):match '[%w_][^%w_]' then + offset = offset - 1 + end + return offset +end + +--- 将光标位置转化为 position +---@alias position table +---@param lines table +---@param text string +---@param offset integer +---@return position +function m.position(lines, text, offset) + local row, col = guide.positionOf(lines, offset) + local start = guide.lineRange(lines, row) + if start < 1 then + start = 1 + end + local ucol = util.utf8Len(text, start, start + col - 1) + if row < 1 then + row = 1 + end + return { + line = row - 1, + character = ucol, + } +end + +--- 将起点与终点位置转化为 range +---@alias range table +---@param lines table +---@param text string +---@param offset1 integer +---@param offset2 integer +function m.range(lines, text, offset1, offset2) + local range = { + start = m.position(lines, text, offset1), + ['end'] = m.position(lines, text, offset2), + } + if range.start.character > 0 then + range.start.character = range.start.character - 1 + end + return range +end + +---@alias location table +---@param uri string +---@param range range +---@return location +function m.location(uri, range) + return { + uri = uri, + range = range, + } +end + +---@alias locationLink table +---@param uri string +---@param range range +---@param selection range +---@param origin range +function m.locationLink(uri, range, selection, origin) + return { + targetUri = uri, + targetRange = range, + targetSelectionRange = selection, + originSelectionRange = origin, + } +end + +function m.textEdit(range, newtext) + return { + range = range, + newText = newtext, + } +end + +--- 诊断等级 +m.DiagnosticSeverity = { + Error = 1, + Warning = 2, + Information = 3, + Hint = 4, +} + +--- 诊断类型与默认等级 +m.DiagnosticDefaultSeverity = { + ['unused-local'] = 'Hint', + ['unused-function'] = 'Hint', + ['undefined-global'] = 'Warning', + ['global-in-nil-env'] = 'Warning', + ['unused-label'] = 'Hint', + ['unused-vararg'] = 'Hint', + ['trailing-space'] = 'Hint', + ['redefined-local'] = 'Hint', + ['newline-call'] = 'Information', + ['newfield-call'] = 'Warning', + ['redundant-parameter'] = 'Hint', + ['ambiguity-1'] = 'Warning', + ['lowercase-global'] = 'Information', + ['undefined-env-child'] = 'Information', + ['duplicate-index'] = 'Warning', + ['empty-block'] = 'Hint', + ['redundant-value'] = 'Hint', + ['code-after-break'] = 'Hint', + + ['duplicate-doc-class'] = 'Warning', + ['undefined-doc-class'] = 'Warning', + ['undefined-doc-name'] = 'Warning', + ['circle-doc-class'] = 'Warning', + ['undefined-doc-param'] = 'Warning', + ['duplicate-doc-param'] = 'Warning', + ['doc-field-no-class'] = 'Warning', + ['duplicate-doc-field'] = 'Warning', +} + +--- 诊断报告标签 +m.DiagnosticTag = { + Unnecessary = 1, + Deprecated = 2, +} + +m.DocumentHighlightKind = { + Text = 1, + Read = 2, + Write = 3, +} + +m.MessageType = { + Error = 1, + Warning = 2, + Info = 3, + Log = 4, +} + +m.FileChangeType = { + Created = 1, + Changed = 2, + Deleted = 3, +} + +m.CompletionItemKind = { + Text = 1, + Method = 2, + Function = 3, + Constructor = 4, + Field = 5, + Variable = 6, + Class = 7, + Interface = 8, + Module = 9, + Property = 10, + Unit = 11, + Value = 12, + Enum = 13, + Keyword = 14, + Snippet = 15, + Color = 16, + File = 17, + Reference = 18, + Folder = 19, + EnumMember = 20, + Constant = 21, + Struct = 22, + Event = 23, + Operator = 24, + TypeParameter = 25, +} + +m.DiagnosticSeverity = { + Error = 1, + Warning = 2, + Information = 3, + Hint = 4, +} + +m.ErrorCodes = { + -- Defined by JSON RPC + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + serverErrorStart = -32099, + serverErrorEnd = -32000, + ServerNotInitialized = -32002, + UnknownErrorCode = -32001, + + -- Defined by the protocol. + RequestCancelled = -32800, +} + +m.SymbolKind = { + File = 1, + Module = 2, + Namespace = 3, + Package = 4, + Class = 5, + Method = 6, + Property = 7, + Field = 8, + Constructor = 9, + Enum = 10, + Interface = 11, + Function = 12, + Variable = 13, + Constant = 14, + String = 15, + Number = 16, + Boolean = 17, + Array = 18, + Object = 19, + Key = 20, + Null = 21, + EnumMember = 22, + Struct = 23, + Event = 24, + Operator = 25, + TypeParameter = 26, +} + +m.TokenModifiers = { + ["declaration"] = 1 << 0, + ["documentation"] = 1 << 1, + ["static"] = 1 << 2, + ["abstract"] = 1 << 3, + ["deprecated"] = 1 << 4, + ["readonly"] = 1 << 5, +} + +m.TokenTypes = { + ["comment"] = 0, + ["keyword"] = 1, + ["number"] = 2, + ["regexp"] = 3, + ["operator"] = 4, + ["namespace"] = 5, + ["type"] = 6, + ["struct"] = 7, + ["class"] = 8, + ["interface"] = 9, + ["enum"] = 10, + ["typeParameter"] = 11, + ["function"] = 12, + ["member"] = 13, + ["macro"] = 14, + ["variable"] = 15, + ["parameter"] = 16, + ["property"] = 17, + ["label"] = 18, +} + + +return m diff --git a/script/proto/init.lua b/script/proto/init.lua new file mode 100644 index 00000000..33e637f6 --- /dev/null +++ b/script/proto/init.lua @@ -0,0 +1,3 @@ +local proto = require 'proto.proto' + +return proto diff --git a/script/proto/proto.lua b/script/proto/proto.lua new file mode 100644 index 00000000..d8538d8f --- /dev/null +++ b/script/proto/proto.lua @@ -0,0 +1,147 @@ +local subprocess = require 'bee.subprocess' +local util = require 'utility' +local await = require 'await' +local pub = require 'pub' +local jsonrpc = require 'jsonrpc' +local define = require 'proto.define' +local timer = require 'timer' +local json = require 'json' + +local reqCounter = util.counter() + +local m = {} + +m.ability = {} +m.waiting = {} +m.holdon = {} + +function m.getMethodName(proto) + if proto.method:sub(1, 2) == '$/' then + return proto.method:sub(3), true + else + return proto.method, false + end +end + +function m.on(method, callback) + m.ability[method] = callback +end + +function m.response(id, res) + if id == nil then + log.error('Response id is nil!', util.dump(res)) + return + end + assert(m.holdon[id]) + m.holdon[id] = nil + local data = {} + data.id = id + data.result = res == nil and json.null or res + local buf = jsonrpc.encode(data) + --log.debug('Response', id, #buf) + io.stdout:write(buf) +end + +function m.responseErr(id, code, message) + if id == nil then + log.error('Response id is nil!', util.dump(message)) + return + end + local buf = jsonrpc.encode { + id = id, + error = { + code = code, + message = message, + } + } + --log.debug('ResponseErr', id, #buf) + io.stdout:write(buf) +end + +function m.notify(name, params) + local buf = jsonrpc.encode { + method = name, + params = params, + } + --log.debug('Notify', name, #buf) + io.stdout:write(buf) +end + +function m.awaitRequest(name, params) + local id = reqCounter() + local buf = jsonrpc.encode { + id = id, + method = name, + params = params, + } + --log.debug('Request', name, #buf) + io.stdout:write(buf) + return await.wait(function (waker) + m.waiting[id] = waker + end) +end + +function m.doMethod(proto) + local method, optional = m.getMethodName(proto) + local abil = m.ability[method] + if not abil then + if not optional then + log.warn('Recieved unknown proto: ' .. method) + end + if proto.id then + m.responseErr(proto.id, define.ErrorCodes.MethodNotFound, method) + end + return + end + if proto.id then + m.holdon[proto.id] = method + end + await.call(function () + --log.debug('Start method:', method) + local clock = os.clock() + local ok = true + local res + -- 任务可能在执行过程中被中断,通过close来捕获 + local response <close> = util.defer(function () + local passed = os.clock() - clock + if passed > 0.2 then + log.debug(('Method [%s] takes [%.3f]sec.'):format(method, passed)) + end + --log.debug('Finish method:', method) + if not proto.id then + return + end + if ok then + m.response(proto.id, res) + else + m.responseErr(proto.id, define.ErrorCodes.InternalError, res) + end + end) + ok, res = xpcall(abil, log.error, proto.params) + end) +end + +function m.doResponse(proto) + local id = proto.id + local waker = m.waiting[id] + if not waker then + log.warn('Response id not found: ' .. util.dump(proto)) + return + end + m.waiting[id] = nil + if proto.error then + log.warn(('Response error [%d]: %s'):format(proto.error.code, proto.error.message)) + return + end + waker(proto.result) +end + +function m.listen() + subprocess.filemode(io.stdin, 'b') + subprocess.filemode(io.stdout, 'b') + io.stdin:setvbuf 'no' + io.stdout:setvbuf 'no' + pub.task('loadProto') +end + +return m diff --git a/script/provider/capability.lua b/script/provider/capability.lua new file mode 100644 index 00000000..23ec27b0 --- /dev/null +++ b/script/provider/capability.lua @@ -0,0 +1,61 @@ +local sp = require 'bee.subprocess' +local nonil = require 'without-check-nil' +local client = require 'provider.client' + +local m = {} + +local function allWords() + local str = [[abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.:('"[,#*@| ]] + local list = {} + for c in str:gmatch '.' do + list[#list+1] = c + end + return list +end + +function m.getIniter() + local initer = { + -- 文本同步方式 + textDocumentSync = { + -- 打开关闭文本时通知 + openClose = true, + -- 文本改变时完全通知 TODO 支持差量更新(2) + change = 1, + }, + + hoverProvider = true, + definitionProvider = true, + referencesProvider = true, + renameProvider = { + prepareProvider = true, + }, + documentSymbolProvider = true, + workspaceSymbolProvider = true, + documentHighlightProvider = true, + codeActionProvider = true, + signatureHelpProvider = { + triggerCharacters = { '(', ',' }, + }, + executeCommandProvider = { + commands = { + 'lua.removeSpace:' .. sp:get_id(), + 'lua.solve:' .. sp:get_id(), + }, + } + --documentOnTypeFormattingProvider = { + -- firstTriggerCharacter = '}', + --}, + } + + nonil.enable() + if not client.info.capabilities.textDocument.completion.dynamicRegistration then + initer.completionProvider = { + triggerCharacters = allWords(), + } + end + nonil.disable() + + return initer +end + +return m diff --git a/script/provider/client.lua b/script/provider/client.lua new file mode 100644 index 00000000..c1b16f0f --- /dev/null +++ b/script/provider/client.lua @@ -0,0 +1,18 @@ +local nonil = require 'without-check-nil' +local util = require 'utility' + +local m = {} + +function m.client() + nonil.enable() + local name = m.info.clientInfo.name + nonil.disable() + return name +end + +function m.init(t) + log.debug('Client init', util.dump(t)) + m.info = t +end + +return m diff --git a/script/provider/completion.lua b/script/provider/completion.lua new file mode 100644 index 00000000..e506cd7b --- /dev/null +++ b/script/provider/completion.lua @@ -0,0 +1,54 @@ +local proto = require 'proto' + +local isEnable = false + +local function allWords() + local str = [[abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.:('"[,#*@| ]] + local list = {} + for c in str:gmatch '.' do + list[#list+1] = c + end + return list +end + +local function enable() + -- TODO 检查客户端是否支持动态注册自动完成 + if isEnable then + return + end + isEnable = true + log.debug('Enable completion.') + proto.awaitRequest('client/registerCapability', { + registrations = { + { + id = 'completion', + method = 'textDocument/completion', + registerOptions = { + resolveProvider = true, + triggerCharacters = allWords(), + }, + }, + } + }) +end + +local function disable() + if not isEnable then + return + end + isEnable = false + log.debug('Disable completion.') + proto.awaitRequest('client/unregisterCapability', { + unregisterations = { + { + id = 'completion', + method = 'textDocument/completion', + }, + } + }) +end + +return { + enable = enable, + disable = disable, +} diff --git a/script/provider/diagnostic.lua b/script/provider/diagnostic.lua new file mode 100644 index 00000000..845b9f44 --- /dev/null +++ b/script/provider/diagnostic.lua @@ -0,0 +1,303 @@ +local await = require 'await' +local proto = require 'proto.proto' +local define = require 'proto.define' +local lang = require 'language' +local files = require 'files' +local config = require 'config' +local core = require 'core.diagnostics' +local util = require 'utility' +local ws = require 'workspace' + +local m = {} +m._start = false +m.cache = {} +m.sleepRest = 0.0 + +local function concat(t, sep) + if type(t) ~= 'table' then + return t + end + return table.concat(t, sep) +end + +local function buildSyntaxError(uri, err) + local lines = files.getLines(uri) + local text = files.getText(uri) + local message = lang.script('PARSER_'..err.type, err.info) + + if err.version then + local version = err.info and err.info.version or config.config.runtime.version + message = message .. ('(%s)'):format(lang.script('DIAG_NEED_VERSION' + , concat(err.version, '/') + , version + )) + end + + local related = err.info and err.info.related + local relatedInformation + if related then + relatedInformation = {} + for _, rel in ipairs(related) do + local rmessage + if rel.message then + rmessage = lang.script('PARSER_'..rel.message) + else + rmessage = text:sub(rel.start, rel.finish) + end + relatedInformation[#relatedInformation+1] = { + message = rmessage, + location = define.location(uri, define.range(lines, text, rel.start, rel.finish)), + } + end + end + + return { + code = err.type:lower():gsub('_', '-'), + range = define.range(lines, text, err.start, err.finish), + severity = define.DiagnosticSeverity.Error, + source = lang.script.DIAG_SYNTAX_CHECK, + message = message, + relatedInformation = relatedInformation, + } +end + +local function buildDiagnostic(uri, diag) + local lines = files.getLines(uri) + local text = files.getText(uri) + + local relatedInformation + if diag.related then + relatedInformation = {} + for _, rel in ipairs(diag.related) do + local rtext = files.getText(rel.uri) + local rlines = files.getLines(rel.uri) + relatedInformation[#relatedInformation+1] = { + message = rel.message or rtext:sub(rel.start, rel.finish), + location = define.location(rel.uri, define.range(rlines, rtext, rel.start, rel.finish)) + } + end + end + + return { + range = define.range(lines, text, diag.start, diag.finish), + source = lang.script.DIAG_DIAGNOSTICS, + severity = diag.level, + message = diag.message, + code = diag.code, + tags = diag.tags, + relatedInformation = relatedInformation, + } +end + +local function merge(a, b) + if not a and not b then + return nil + end + local t = {} + if a then + for i = 1, #a do + t[#t+1] = a[i] + end + end + if b then + for i = 1, #b do + t[#t+1] = b[i] + end + end + return t +end + +function m.clear(uri) + local luri = uri:lower() + if not m.cache[luri] then + return + end + m.cache[luri] = nil + proto.notify('textDocument/publishDiagnostics', { + uri = files.getOriginUri(luri) or uri, + diagnostics = {}, + }) +end + +function m.clearAll() + for luri in pairs(m.cache) do + m.clear(luri) + end +end + +function m.syntaxErrors(uri, ast) + if #ast.errs == 0 then + return nil + end + + local results = {} + + for _, err in ipairs(ast.errs) do + results[#results+1] = buildSyntaxError(uri, err) + end + + return results +end + +function m.diagnostics(uri, diags) + if not m._start then + return + end + + core(uri, function (results) + if #results == 0 then + return + end + for i = 1, #results do + diags[#diags+1] = buildDiagnostic(uri, results[i]) + end + end) +end + +function m.doDiagnostic(uri) + if not config.config.diagnostics.enable then + return + end + uri = uri:lower() + if files.isLibrary(uri) then + return + end + + await.delay() + + local ast = files.getAst(uri) + if not ast then + m.clear(uri) + return + end + + local syntax = m.syntaxErrors(uri, ast) + local diags = {} + + local function pushResult() + local full = merge(syntax, diags) + if not full then + m.clear(uri) + return + end + + if util.equal(m.cache, full) then + return + end + m.cache[uri] = full + + proto.notify('textDocument/publishDiagnostics', { + uri = files.getOriginUri(uri), + diagnostics = full, + }) + end + + if await.hasID 'diagnosticsAll' then + m.checkStepResult = nil + else + local clock = os.clock() + m.checkStepResult = function () + if os.clock() - clock >= 0.2 then + pushResult() + clock = os.clock() + end + end + end + + m.diagnostics(uri, diags) + pushResult() +end + +function m.refresh(uri) + if not m._start then + return + end + await.call(function () + if uri then + m.doDiagnostic(uri) + end + m.diagnosticsAll() + end, 'files.version') +end + +function m.diagnosticsAll() + if not config.config.diagnostics.enable then + m.clearAll() + return + end + if not m._start then + return + end + local delay = config.config.diagnostics.workspaceDelay / 1000 + if delay < 0 then + return + end + await.close 'diagnosticsAll' + await.call(function () + await.sleep(delay) + m.diagnosticsAllClock = os.clock() + local clock = os.clock() + for uri in files.eachFile() do + m.doDiagnostic(uri) + await.delay() + end + log.debug('全文诊断耗时:', os.clock() - clock) + end, 'files.version', 'diagnosticsAll') +end + +function m.start() + m._start = true + m.diagnosticsAll() +end + +function m.checkStepResult() + if await.hasID 'diagnosticsAll' then + return + end +end + +function m.checkWorkspaceDiag() + if not await.hasID 'diagnosticsAll' then + return + end + local speedRate = config.config.diagnostics.workspaceRate + if speedRate <= 0 or speedRate >= 100 then + return + end + local currentClock = os.clock() + local passed = currentClock - m.diagnosticsAllClock + local sleepTime = passed * (100 - speedRate) / speedRate + m.sleepRest + m.sleepRest = 0.0 + if sleepTime < 0.001 then + m.sleepRest = m.sleepRest + sleepTime + return + end + if sleepTime > 0.1 then + m.sleepRest = sleepTime - 0.1 + sleepTime = 0.1 + end + await.sleep(sleepTime) + m.diagnosticsAllClock = os.clock() + return false +end + +files.watch(function (ev, uri) + if ev == 'remove' then + m.clear(uri) + elseif ev == 'update' then + m.refresh(uri) + elseif ev == 'open' then + m.doDiagnostic(uri) + end +end) + +await.watch(function (ev, co) + if ev == 'delay' then + if m.checkStepResult then + m.checkStepResult() + end + return m.checkWorkspaceDiag() + end +end) + +return m diff --git a/script/provider/init.lua b/script/provider/init.lua new file mode 100644 index 00000000..7eafb70a --- /dev/null +++ b/script/provider/init.lua @@ -0,0 +1 @@ +require 'provider.provider' diff --git a/script/provider/markdown.lua b/script/provider/markdown.lua new file mode 100644 index 00000000..ca76ec89 --- /dev/null +++ b/script/provider/markdown.lua @@ -0,0 +1,26 @@ +local mt = {} +mt.__index = mt +mt.__name = 'markdown' + +function mt:add(language, text) + if not text or #text == 0 then + return + end + if language == 'md' then + if self._last == 'md' then + self[#self+1] = '' + end + self[#self+1] = text + else + self[#self+1] = ('```%s\n%s\n```'):format(language, text) + end + self._last = language +end + +function mt:string() + return table.concat(self, '\n') +end + +return function () + return setmetatable({}, mt) +end diff --git a/script/provider/provider.lua b/script/provider/provider.lua new file mode 100644 index 00000000..3508116f --- /dev/null +++ b/script/provider/provider.lua @@ -0,0 +1,642 @@ +local util = require 'utility' +local cap = require 'provider.capability' +local completion= require 'provider.completion' +local semantic = require 'provider.semantic-tokens' +local await = require 'await' +local files = require 'files' +local proto = require 'proto.proto' +local define = require 'proto.define' +local workspace = require 'workspace' +local config = require 'config' +local library = require 'library' +local markdown = require 'provider.markdown' +local client = require 'provider.client' +local furi = require 'file-uri' +local pub = require 'pub' +local fs = require 'bee.filesystem' +local lang = require 'language' + +local function updateConfig() + local diagnostics = require 'provider.diagnostic' + local vm = require 'vm' + local configs = proto.awaitRequest('workspace/configuration', { + items = { + { + scopeUri = workspace.uri, + section = 'Lua', + }, + { + scopeUri = workspace.uri, + section = 'files.associations', + }, + { + scopeUri = workspace.uri, + section = 'files.exclude', + } + }, + }) + + local updated = configs[1] + local other = { + associations = configs[2], + exclude = configs[3], + } + + local oldConfig = util.deepCopy(config.config) + local oldOther = util.deepCopy(config.other) + config.setConfig(updated, other) + local newConfig = config.config + local newOther = config.other + if not util.equal(oldConfig.runtime, newConfig.runtime) then + library.init() + workspace.reload() + end + if not util.equal(oldConfig.diagnostics, newConfig.diagnostics) then + diagnostics.diagnosticsAll() + end + if not util.equal(oldConfig.plugin, newConfig.plugin) then + end + if not util.equal(oldConfig.workspace, newConfig.workspace) + or not util.equal(oldConfig.plugin, newConfig.plugin) + or not util.equal(oldOther.associations, newOther.associations) + or not util.equal(oldOther.exclude, newOther.exclude) + then + workspace.reload() + end + if not util.equal(oldConfig.luadoc, newConfig.luadoc) then + files.flushCache() + end + if not util.equal(oldConfig.intelliSense, newConfig.intelliSense) then + files.flushCache() + end + + if newConfig.completion.enable then + completion.enable() + else + completion.disable() + end + if newConfig.color.mode == 'Semantic' then + semantic.enable() + else + semantic.disable() + end +end + +proto.on('initialize', function (params) + client.init(params) + library.init() + workspace.init(params.rootUri) + return { + capabilities = cap.getIniter(), + serverInfo = { + name = 'sumneko.lua', + }, + } +end) + +proto.on('initialized', function (params) + updateConfig() + proto.awaitRequest('client/registerCapability', { + registrations = { + -- 监视文件变化 + { + id = '0', + method = 'workspace/didChangeWatchedFiles', + registerOptions = { + watchers = { + { + globPattern = '**/', + kind = 1 | 2 | 4, + } + }, + }, + }, + -- 配置变化 + { + id = '1', + method = 'workspace/didChangeConfiguration', + } + } + }) + await.call(workspace.awaitPreload) + return true +end) + +proto.on('exit', function () + log.info('Server exited.') + os.exit(true) +end) + +proto.on('shutdown', function () + log.info('Server shutdown.') + return true +end) + +proto.on('workspace/didChangeConfiguration', function () + updateConfig() +end) + +proto.on('workspace/didChangeWatchedFiles', function (params) + for _, change in ipairs(params.changes) do + local uri = change.uri + -- TODO 创建文件与删除文件直接重新扫描(文件改名、文件夹删除等情况太复杂了) + if change.type == define.FileChangeType.Created + or change.type == define.FileChangeType.Deleted then + workspace.reload() + break + elseif change.type == define.FileChangeType.Changed then + -- 如果文件处于关闭状态,则立即更新;否则等待didChange协议来更新 + if files.isLua(uri) and not files.isOpen(uri) then + files.setText(uri, pub.awaitTask('loadFile', uri)) + else + local path = furi.decode(uri) + local filename = fs.path(path):filename():string() + -- 排除类文件发生更改需要重新扫描 + if files.eq(filename, '.gitignore') + or files.eq(filename, '.gitmodules') then + workspace.reload() + break + end + end + end + end +end) + +proto.on('textDocument/didOpen', function (params) + local doc = params.textDocument + local uri = doc.uri + local text = doc.text + files.open(uri) + files.setText(uri, text) +end) + +proto.on('textDocument/didClose', function (params) + local doc = params.textDocument + local uri = doc.uri + files.close(uri) + if not files.isLua(uri) then + files.remove(uri) + end +end) + +proto.on('textDocument/didChange', function (params) + local doc = params.textDocument + local change = params.contentChanges + local uri = doc.uri + local text = change[1].text + if files.isLua(uri) or files.isOpen(uri) then + --log.debug('didChange:', uri) + files.setText(uri, text) + --log.debug('setText:', #text) + end +end) + +proto.on('textDocument/hover', function (params) + await.close 'hover' + await.setID 'hover' + local core = require 'core.hover' + local doc = params.textDocument + local uri = doc.uri + if not files.exists(uri) then + return nil + end + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offsetOfWord(lines, text, params.position) + local hover = core.byUri(uri, offset) + if not hover then + return nil + end + local md = markdown() + md:add('lua', hover.label) + md:add('md', hover.description) + return { + contents = { + value = md:string(), + kind = 'markdown', + }, + range = define.range(lines, text, hover.source.start, hover.source.finish), + } +end) + +proto.on('textDocument/definition', function (params) + local core = require 'core.definition' + local uri = params.textDocument.uri + if not files.exists(uri) then + return nil + end + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offsetOfWord(lines, text, params.position) + local result = core(uri, offset) + if not result then + return nil + end + local response = {} + for i, info in ipairs(result) do + local targetUri = info.uri + if targetUri then + local targetLines = files.getLines(targetUri) + local targetText = files.getText(targetUri) + response[i] = define.locationLink(targetUri + , define.range(targetLines, targetText, info.target.start, info.target.finish) + , define.range(targetLines, targetText, info.target.start, info.target.finish) + , define.range(lines, text, info.source.start, info.source.finish) + ) + end + end + return response +end) + +proto.on('textDocument/references', function (params) + local core = require 'core.reference' + local uri = params.textDocument.uri + if not files.exists(uri) then + return nil + end + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offsetOfWord(lines, text, params.position) + local result = core(uri, offset) + if not result then + return nil + end + local response = {} + for i, info in ipairs(result) do + local targetUri = info.uri + local targetLines = files.getLines(targetUri) + local targetText = files.getText(targetUri) + response[i] = define.location(targetUri + , define.range(targetLines, targetText, info.target.start, info.target.finish) + ) + end + return response +end) + +proto.on('textDocument/documentHighlight', function (params) + local core = require 'core.highlight' + local uri = params.textDocument.uri + if not files.exists(uri) then + return nil + end + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offsetOfWord(lines, text, params.position) + local result = core(uri, offset) + if not result then + return nil + end + local response = {} + for _, info in ipairs(result) do + response[#response+1] = { + range = define.range(lines, text, info.start, info.finish), + kind = info.kind, + } + end + return response +end) + +proto.on('textDocument/rename', function (params) + local core = require 'core.rename' + local uri = params.textDocument.uri + if not files.exists(uri) then + return nil + end + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offsetOfWord(lines, text, params.position) + local result = core.rename(uri, offset, params.newName) + if not result then + return nil + end + local workspaceEdit = { + changes = {}, + } + for _, info in ipairs(result) do + local ruri = info.uri + local rlines = files.getLines(ruri) + local rtext = files.getText(ruri) + if not workspaceEdit.changes[ruri] then + workspaceEdit.changes[ruri] = {} + end + local textEdit = define.textEdit(define.range(rlines, rtext, info.start, info.finish), info.text) + workspaceEdit.changes[ruri][#workspaceEdit.changes[ruri]+1] = textEdit + end + return workspaceEdit +end) + +proto.on('textDocument/prepareRename', function (params) + local core = require 'core.rename' + local uri = params.textDocument.uri + if not files.exists(uri) then + return nil + end + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offsetOfWord(lines, text, params.position) + local result = core.prepareRename(uri, offset) + if not result then + return nil + end + return { + range = define.range(lines, text, result.start, result.finish), + placeholder = result.text, + } +end) + +proto.on('textDocument/completion', function (params) + --log.info(util.dump(params)) + local core = require 'core.completion' + --log.debug('completion:', params.context and params.context.triggerKind, params.context and params.context.triggerCharacter) + local uri = params.textDocument.uri + if not files.exists(uri) then + return nil + end + await.setPriority(1000) + local clock = os.clock() + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offset(lines, text, params.position) + local result = core.completion(uri, offset) + local passed = os.clock() - clock + if passed > 0.1 then + log.warn(('Completion takes %.3f sec.'):format(passed)) + end + if not result then + return nil + end + local easy = false + local items = {} + for i, res in ipairs(result) do + local item = { + label = res.label, + kind = res.kind, + deprecated = res.deprecated, + sortText = ('%04d'):format(i), + insertText = res.insertText, + insertTextFormat = res.insertTextFormat, + textEdit = res.textEdit and { + range = define.range( + lines, + text, + res.textEdit.start, + res.textEdit.finish + ), + newText = res.textEdit.newText, + }, + additionalTextEdits = res.additionalTextEdits and (function () + local t = {} + for j, edit in ipairs(res.additionalTextEdits) do + t[j] = { + range = define.range( + lines, + text, + edit.start, + edit.finish + ) + } + end + return t + end)(), + documentation = res.description and { + value = res.description, + kind = 'markdown', + }, + } + if res.id then + if easy and os.clock() - clock < 0.05 then + local resolved = core.resolve(res.id) + if resolved then + item.detail = resolved.detail + item.documentation = resolved.description and { + value = resolved.description, + kind = 'markdown', + } + end + else + easy = false + item.data = { + version = files.globalVersion, + id = res.id, + } + end + end + items[i] = item + end + return { + isIncomplete = false, + items = items, + } +end) + +proto.on('completionItem/resolve', function (item) + local core = require 'core.completion' + if not item.data then + return item + end + local globalVersion = item.data.version + local id = item.data.id + if globalVersion ~= files.globalVersion then + return item + end + --await.setPriority(1000) + local resolved = core.resolve(id) + if not resolved then + return nil + end + item.detail = resolved.detail + item.documentation = resolved.description and { + value = resolved.description, + kind = 'markdown', + } + return item +end) + +proto.on('textDocument/signatureHelp', function (params) + if not config.config.signatureHelp.enable then + return nil + end + local uri = params.textDocument.uri + if not files.exists(uri) then + return nil + end + await.close('signatureHelp') + await.setID('signatureHelp') + local lines = files.getLines(uri) + local text = files.getText(uri) + local offset = define.offset(lines, text, params.position) + local core = require 'core.signature' + local results = core(uri, offset) + if not results then + return nil + end + local infos = {} + for i, result in ipairs(results) do + local parameters = {} + for j, param in ipairs(result.params) do + parameters[j] = { + label = { + param.label[1] - 1, + param.label[2], + } + } + end + infos[i] = { + label = result.label, + parameters = parameters, + activeParameter = result.index - 1, + documentation = result.description and { + value = result.description, + kind = 'markdown', + }, + } + end + return { + signatures = infos, + } +end) + +proto.on('textDocument/documentSymbol', function (params) + local core = require 'core.document-symbol' + local uri = params.textDocument.uri + local lines = files.getLines(uri) + local text = files.getText(uri) + while not lines or not text do + await.sleep(0.1) + lines = files.getLines(uri) + text = files.getText(uri) + end + + local symbols = core(uri) + if not symbols then + return nil + end + + local function convert(symbol) + await.delay() + symbol.range = define.range( + lines, + text, + symbol.range[1], + symbol.range[2] + ) + symbol.selectionRange = define.range( + lines, + text, + symbol.selectionRange[1], + symbol.selectionRange[2] + ) + if symbol.name == '' then + symbol.name = lang.script.SYMBOL_ANONYMOUS + end + symbol.valueRange = nil + if symbol.children then + for _, child in ipairs(symbol.children) do + convert(child) + end + end + end + + for _, symbol in ipairs(symbols) do + convert(symbol) + end + + return symbols +end) + +proto.on('textDocument/codeAction', function (params) + local core = require 'core.code-action' + local uri = params.textDocument.uri + local range = params.range + local diagnostics = params.context.diagnostics + local results = core(uri, range, diagnostics) + + if not results or #results == 0 then + return nil + end + + return results +end) + +proto.on('workspace/executeCommand', function (params) + local command = params.command:gsub(':.+', '') + if command == 'lua.removeSpace' then + local core = require 'core.command.removeSpace' + return core(params.arguments[1]) + elseif command == 'lua.solve' then + local core = require 'core.command.solve' + return core(params.arguments[1]) + end +end) + +proto.on('workspace/symbol', function (params) + local core = require 'core.workspace-symbol' + + await.close('workspace/symbol') + await.setID('workspace/symbol') + + local symbols = core(params.query) + if not symbols or #symbols == 0 then + return nil + end + + local function convert(symbol) + symbol.location = define.location( + symbol.uri, + define.range( + files.getLines(symbol.uri), + files.getText(symbol.uri), + symbol.range[1], + symbol.range[2] + ) + ) + symbol.uri = nil + end + + for _, symbol in ipairs(symbols) do + convert(symbol) + end + + return symbols +end) + + +proto.on('textDocument/semanticTokens/full', function (params) + local core = require 'core.semantic-tokens' + local uri = params.textDocument.uri + log.debug('semanticTokens/full', uri) + local text = files.getText(uri) + while not text do + await.sleep(0.1) + text = files.getText(uri) + end + local results = core(uri, 0, #text) + if not results or #results == 0 then + return nil + end + return { + data = results + } +end) + +proto.on('textDocument/semanticTokens/range', function (params) + local core = require 'core.semantic-tokens' + local uri = params.textDocument.uri + log.debug('semanticTokens/range', uri) + local lines = files.getLines(uri) + local text = files.getText(uri) + while not lines or not text do + await.sleep(0.1) + lines = files.getLines(uri) + text = files.getText(uri) + end + local start = define.offset(lines, text, params.range.start) + local finish = define.offset(lines, text, params.range['end']) + local results = core(uri, start, finish) + if not results or #results == 0 then + return nil + end + return { + data = results + } +end) diff --git a/script/provider/semantic-tokens.lua b/script/provider/semantic-tokens.lua new file mode 100644 index 00000000..17985bcd --- /dev/null +++ b/script/provider/semantic-tokens.lua @@ -0,0 +1,64 @@ +local proto = require 'proto' +local define = require 'proto.define' +local client = require 'provider.client' + +local isEnable = false + +local function toArray(map) + local array = {} + for k in pairs(map) do + array[#array+1] = k + end + table.sort(array, function (a, b) + return map[a] < map[b] + end) + return array +end + +local function enable() + if isEnable then + return + end + if not client.info.capabilities.textDocument.semanticTokens then + return + end + isEnable = true + log.debug('Enable semantic tokens.') + proto.awaitRequest('client/registerCapability', { + registrations = { + { + id = 'semantic-tokens', + method = 'textDocument/semanticTokens', + registerOptions = { + legend = { + tokenTypes = toArray(define.TokenTypes), + tokenModifiers = toArray(define.TokenModifiers), + }, + range = true, + full = true, + }, + }, + } + }) +end + +local function disable() + if not isEnable then + return + end + isEnable = false + log.debug('Disable semantic tokens.') + proto.awaitRequest('client/unregisterCapability', { + unregisterations = { + { + id = 'semantic-tokens', + method = 'textDocument/semanticTokens', + }, + } + }) +end + +return { + enable = enable, + disable = disable, +} diff --git a/script/pub/init.lua b/script/pub/init.lua new file mode 100644 index 00000000..61b43da7 --- /dev/null +++ b/script/pub/init.lua @@ -0,0 +1,4 @@ +local pub = require 'pub.pub' +require 'pub.report' + +return pub diff --git a/script/pub/pub.lua b/script/pub/pub.lua new file mode 100644 index 00000000..ad1cd749 --- /dev/null +++ b/script/pub/pub.lua @@ -0,0 +1,242 @@ +local thread = require 'bee.thread' +local utility = require 'utility' +local await = require 'await' +local timer = require 'timer' + +local errLog = thread.channel 'errlog' +local type = type +local counter = utility.counter() + +local braveTemplate = [[ +package.path = %q +package.cpath = %q +DEVELOP = %s +DBGPORT = %d +DBGWAIT = %s + +collectgarbage 'generational' + +log = require 'brave.log' + +xpcall(dofile, log.error, %q) +local brave = require 'brave' +brave.register(%d) +]] + +---@class pub +local m = {} +m.type = 'pub' +m.braves = {} +m.ability = {} +m.taskQueue = {} + +--- 注册酒馆的功能 +function m.on(name, callback) + m.ability[name] = callback +end + +--- 招募勇者,勇者会从公告板上领取任务,完成任务后到看板娘处交付任务 +---@param num integer +function m.recruitBraves(num) + for _ = 1, num do + local id = #m.braves + 1 + log.info('Create brave:', id) + thread.newchannel('taskpad' .. id) + thread.newchannel('waiter' .. id) + m.braves[id] = { + id = id, + taskpad = thread.channel('taskpad' .. id), + waiter = thread.channel('waiter' .. id), + thread = thread.thread(braveTemplate:format( + package.path, + package.cpath, + DEVELOP, + DBGPORT or 11412, + DBGWAIT or 'nil', + (ROOT / 'debugger.lua'):string(), + id + )), + taskMap = {}, + currentTask = nil, + memory = 0, + } + end +end + +--- 勇者是否有空 +function m.isIdle(brave) + return next(brave.taskMap) == nil +end + +--- 给勇者推送任务 +function m.pushTask(brave, info) + if info.removed then + return false + end + brave.taskpad:push(info.name, info.id, info.params) + brave.taskMap[info.id] = info + --log.info(('Push task %q(%d) to # %d, queue length %d'):format(info.name, info.id, brave.id, #m.taskQueue)) + return true +end + +--- 从勇者处接收任务反馈 +function m.popTask(brave, id, result) + local info = brave.taskMap[id] + if not info then + log.warn(('Brave pushed unknown task result: # %d => [%d]'):format(brave.id, id)) + return + end + brave.taskMap[id] = nil + --log.info(('Pop task %q(%d) from # %d'):format(info.name, info.id, brave.id)) + m.checkWaitingTask(brave) + if not info.removed then + info.removed = true + if info.callback then + xpcall(info.callback, log.error, result) + end + end +end + +--- 从勇者处接收报告 +function m.popReport(brave, name, params) + local abil = m.ability[name] + if not abil then + log.warn(('Brave pushed unknown report: # %d => %q'):format(brave.id, name)) + return + end + xpcall(abil, log.error, params, brave) +end + +--- 发布任务 +---@parma name string +---@param params any +function m.awaitTask(name, params) + local info = { + id = counter(), + name = name, + params = params, + } + for _, brave in ipairs(m.braves) do + if m.isIdle(brave) then + if m.pushTask(brave, info) then + return await.wait(function (waker) + info.callback = waker + end) + else + return nil + end + end + end + -- 如果所有勇者都在战斗,那么把任务缓存到队列里 + -- 当有勇者提交任务反馈后,尝试把按顺序将堆积任务 + -- 交给该勇者 + m.taskQueue[#m.taskQueue+1] = info + --log.info(('Add task %q(%d) in queue, length %d.'):format(name, info.id, #m.taskQueue)) + return await.wait(function (waker) + info.callback = waker + end) +end + +--- 发布同步任务,如果任务进入了队列,会返回执行器 +--- 通过 jumpQueue 可以插队 +---@parma name string +---@param params any +---@param callback function +function m.task(name, params, callback) + local info = { + id = counter(), + name = name, + params = params, + callback = callback, + } + for _, brave in ipairs(m.braves) do + if m.isIdle(brave) then + m.pushTask(brave, info) + return nil + end + end + -- 如果所有勇者都在战斗,那么把任务缓存到队列里 + -- 当有勇者提交任务反馈后,尝试把按顺序将堆积任务 + -- 交给该勇者 + m.taskQueue[#m.taskQueue+1] = info + --log.info(('Add task %q(%d) in queue, length %d.'):format(name, info.id, #m.taskQueue)) + return info +end + +--- 插队 +function m.jumpQueue(info) + for i = 2, #m.taskQueue do + if m.taskQueue[i] == info then + m.taskQueue[i] = nil + table.move(m.taskQueue, 1, i - 1, 2) + m.taskQueue[1] = info + return + end + end +end + +--- 移除任务 +function m.remove(info) + info.removed = true + for i = 1, #m.taskQueue do + if m.taskQueue[i] == info then + table.remove(m.taskQueue[i], i) + return + end + end +end + +--- 检查堆积任务 +function m.checkWaitingTask(brave) + if #m.taskQueue == 0 then + return + end + -- 如果勇者还有其他活要忙,那么让他继续忙去吧 + if next(brave.taskMap) then + return + end + while #m.taskQueue > 0 do + local info = table.remove(m.taskQueue, 1) + if m.pushTask(brave, info) then + break + end + end +end + +--- 接收反馈 +---|返回接收到的反馈数量 +---@return integer +function m.recieve() + for _, brave in ipairs(m.braves) do + while true do + local suc, id, result = brave.waiter:pop() + if not suc then + goto CONTINUE + end + if type(id) == 'string' then + m.popReport(brave, id, result) + else + m.popTask(brave, id, result) + end + end + ::CONTINUE:: + end +end + +--- 检查伤亡情况 +function m.checkDead() + while true do + local suc, err = errLog:pop() + if not suc then + break + end + log.error('Brave is dead!: ' .. err) + end +end + +function m.step() + m.checkDead() + m.recieve() +end + +return m diff --git a/script/pub/report.lua b/script/pub/report.lua new file mode 100644 index 00000000..549277e1 --- /dev/null +++ b/script/pub/report.lua @@ -0,0 +1,26 @@ +local pub = require 'pub.pub' +local await = require 'await' + +pub.on('log', function (params, brave) + log.raw(brave.id, params.level, params.msg, params.src, params.line, params.clock) +end) + +pub.on('mem', function (count, brave) + brave.memory = count +end) + +pub.on('proto', function (params) + local proto = require 'proto' + await.call(function () + if params.method then + proto.doMethod(params) + else + proto.doResponse(params) + end + end) +end) + +pub.on('protoerror', function (err) + log.warn('Load proto error:', err) + os.exit(true) +end) diff --git a/script/service/init.lua b/script/service/init.lua new file mode 100644 index 00000000..eb0bd057 --- /dev/null +++ b/script/service/init.lua @@ -0,0 +1,3 @@ +local service = require 'service.service' + +return service diff --git a/script/service/service.lua b/script/service/service.lua new file mode 100644 index 00000000..11cc7b19 --- /dev/null +++ b/script/service/service.lua @@ -0,0 +1,158 @@ +local pub = require 'pub' +local thread = require 'bee.thread' +local await = require 'await' +local timer = require 'timer' +local proto = require 'proto' +local vm = require 'vm' + +local m = {} +m.type = 'service' + +local function countMemory() + local mems = {} + local total = 0 + mems[0] = collectgarbage 'count' + total = total + collectgarbage 'count' + for id, brave in ipairs(pub.braves) do + mems[id] = brave.memory + total = total + brave.memory + end + return total, mems +end + +function m.reportMemoryCollect() + local totalMemBefore = countMemory() + local clock = os.clock() + collectgarbage() + local passed = os.clock() - clock + local totalMemAfter, mems = countMemory() + + local lines = {} + lines[#lines+1] = ' --------------- Memory ---------------' + lines[#lines+1] = (' Total: %.3f(%.3f) MB'):format(totalMemAfter / 1000.0, totalMemBefore / 1000.0) + for i = 0, #mems do + lines[#lines+1] = (' # %02d : %.3f MB'):format(i, mems[i] / 1000.0) + end + lines[#lines+1] = (' Collect garbage takes [%.3f] sec'):format(passed) + return table.concat(lines, '\n') +end + +function m.reportMemory() + local totalMem, mems = countMemory() + + local lines = {} + lines[#lines+1] = ' --------------- Memory ---------------' + lines[#lines+1] = (' Total: %.3f MB'):format(totalMem / 1000.0) + for i = 0, #mems do + lines[#lines+1] = (' # %02d : %.3f MB'):format(i, mems[i] / 1000.0) + end + return table.concat(lines, '\n') +end + +function m.reportTask() + local total = 0 + local running = 0 + local suspended = 0 + local normal = 0 + local dead = 0 + + for co in pairs(await.coMap) do + total = total + 1 + local status = coroutine.status(co) + if status == 'running' then + running = running + 1 + elseif status == 'suspended' then + suspended = suspended + 1 + elseif status == 'normal' then + normal = normal + 1 + elseif status == 'dead' then + dead = dead + 1 + end + end + + local lines = {} + lines[#lines+1] = ' --------------- Coroutine ---------------' + lines[#lines+1] = (' Total: %d'):format(total) + lines[#lines+1] = (' Running: %d'):format(running) + lines[#lines+1] = (' Suspended: %d'):format(suspended) + lines[#lines+1] = (' Normal: %d'):format(normal) + lines[#lines+1] = (' Dead: %d'):format(dead) + return table.concat(lines, '\n') +end + +function m.reportCache() + local total = 0 + local dead = 0 + + for cache in pairs(vm.cacheTracker) do + total = total + 1 + if cache.dead then + dead = dead + 1 + end + end + + local lines = {} + lines[#lines+1] = ' --------------- Cache ---------------' + lines[#lines+1] = (' Total: %d'):format(total) + lines[#lines+1] = (' Dead: %d'):format(dead) + return table.concat(lines, '\n') +end + +function m.reportProto() + local holdon = 0 + local waiting = 0 + + for _ in pairs(proto.holdon) do + holdon = holdon + 1 + end + for _ in pairs(proto.waiting) do + waiting = waiting + 1 + end + + local lines = {} + lines[#lines+1] = ' --------------- Proto ---------------' + lines[#lines+1] = (' Holdon: %d'):format(holdon) + lines[#lines+1] = (' Waiting: %d'):format(waiting) + return table.concat(lines, '\n') +end + +function m.report() + local t = timer.loop(60.0, function () + local lines = {} + lines[#lines+1] = '' + lines[#lines+1] = '========= Medical Examination Report =========' + lines[#lines+1] = m.reportMemory() + lines[#lines+1] = m.reportTask() + lines[#lines+1] = m.reportCache() + lines[#lines+1] = m.reportProto() + lines[#lines+1] = '==============================================' + + log.debug(table.concat(lines, '\n')) + end) + t:onTimer() +end + +function m.startTimer() + while true do + ::CONTINUE:: + pub.step() + if await.step() then + goto CONTINUE + end + thread.sleep(0.001) + timer.update() + end +end + +function m.start() + await.setErrorHandle(log.error) + pub.recruitBraves(4) + proto.listen() + m.report() + + require 'provider' + + m.startTimer() +end + +return m diff --git a/script/timer.lua b/script/timer.lua new file mode 100644 index 00000000..1d4343f1 --- /dev/null +++ b/script/timer.lua @@ -0,0 +1,218 @@ +local setmetatable = setmetatable +local mathMax = math.max +local mathFloor = math.floor +local osClock = os.clock + +_ENV = nil + +local curFrame = 0 +local maxFrame = 0 +local curIndex = 0 +local freeQueue = {} +local timer = {} + +local function allocQueue() + local n = #freeQueue + if n > 0 then + local r = freeQueue[n] + freeQueue[n] = nil + return r + else + return {} + end +end + +local function mTimeout(self, timeout) + if self._pauseRemaining or self._running then + return + end + local ti = curFrame + timeout + local q = timer[ti] + if q == nil then + q = allocQueue() + timer[ti] = q + end + self._timeoutFrame = ti + self._running = true + q[#q + 1] = self +end + +local function mWakeup(self) + if self._removed then + return + end + self._running = false + if self._onTimer then + self:_onTimer() + end + if self._removed then + return + end + if self._timerCount then + if self._timerCount > 1 then + self._timerCount = self._timerCount - 1 + mTimeout(self, self._timeout) + else + self._removed = true + end + else + mTimeout(self, self._timeout) + end +end + +local function getRemaining(self) + if self._removed then + return 0 + end + if self._pauseRemaining then + return self._pauseRemaining + end + if self._timeoutFrame == curFrame then + return self._timeout or 0 + end + return self._timeoutFrame - curFrame +end + +local function onTick() + local q = timer[curFrame] + if q == nil then + curIndex = 0 + return + end + for i = curIndex + 1, #q do + local callback = q[i] + curIndex = i + q[i] = nil + if callback then + mWakeup(callback) + end + end + curIndex = 0 + timer[curFrame] = nil + freeQueue[#freeQueue + 1] = q +end + +local m = {} +local mt = {} +mt.__index = mt +mt.type = 'timer' + +function mt:__tostring() + return '[table:timer]' +end + +function mt:__call() + if self._onTimer then + self:_onTimer() + end +end + +function mt:remove() + self._removed = true +end + +function mt:pause() + if self._removed or self._pauseRemaining then + return + end + self._pauseRemaining = getRemaining(self) + self._running = false + local ti = self._timeoutFrame + local q = timer[ti] + if q then + for i = #q, 1, -1 do + if q[i] == self then + q[i] = false + return + end + end + end +end + +function mt:resume() + if self._removed or not self._pauseRemaining then + return + end + local timeout = self._pauseRemaining + self._pauseRemaining = nil + mTimeout(self, timeout) +end + +function mt:restart() + if self._removed or self._pauseRemaining or not self._running then + return + end + local ti = self._timeoutFrame + local q = timer[ti] + if q then + for i = #q, 1, -1 do + if q[i] == self then + q[i] = false + break + end + end + end + self._running = false + mTimeout(self, self._timeout) +end + +function mt:remaining() + return getRemaining(self) / 1000.0 +end + +function mt:onTimer() + self:_onTimer() +end + +function m.wait(timeout, onTimer) + local t = setmetatable({ + ['_timeout'] = mathMax(mathFloor(timeout * 1000.0), 1), + ['_onTimer'] = onTimer, + ['_timerCount'] = 1, + }, mt) + mTimeout(t, t._timeout) + return t +end + +function m.loop(timeout, onTimer) + local t = setmetatable({ + ['_timeout'] = mathFloor(timeout * 1000.0), + ['_onTimer'] = onTimer, + }, mt) + mTimeout(t, t._timeout) + return t +end + +function m.timer(timeout, count, onTimer) + if count == 0 then + return m.loop(timeout, onTimer) + end + local t = setmetatable({ + ['_timeout'] = mathFloor(timeout * 1000.0), + ['_onTimer'] = onTimer, + ['_timerCount'] = count, + }, mt) + mTimeout(t, t._timeout) + return t +end + +function m.clock() + return curFrame / 1000.0 +end + +local lastClock = osClock() +function m.update() + local currentClock = osClock() + local delta = currentClock - lastClock + lastClock = currentClock + if curIndex ~= 0 then + curFrame = curFrame - 1 + end + maxFrame = maxFrame + delta * 1000.0 + while curFrame < maxFrame do + curFrame = curFrame + 1 + onTick() + end +end + +return m diff --git a/script/utility.lua b/script/utility.lua new file mode 100644 index 00000000..a1ea1804 --- /dev/null +++ b/script/utility.lua @@ -0,0 +1,559 @@ +local tableSort = table.sort +local stringRep = string.rep +local tableConcat = table.concat +local tostring = tostring +local type = type +local pairs = pairs +local ipairs = ipairs +local next = next +local rawset = rawset +local move = table.move +local setmetatable = setmetatable +local mathType = math.type +local mathCeil = math.ceil +local getmetatable = getmetatable +local mathAbs = math.abs +local mathRandom = math.random +local ioOpen = io.open +local utf8Len = utf8.len +local mathHuge = math.huge +local inf = 1 / 0 +local nan = 0 / 0 + +_ENV = nil + +local function formatNumber(n) + if n == inf + or n == -inf + or n == nan + or n ~= n then -- IEEE 标准中,NAN 不等于自己。但是某些实现中没有遵守这个规则 + return ('%q'):format(n) + end + local str = ('%.10f'):format(n) + str = str:gsub('%.?0*$', '') + return str +end + +local function isInteger(n) + if mathType then + return mathType(n) == 'integer' + else + return type(n) == 'number' and n % 1 == 0 + end +end + +local TAB = setmetatable({}, { __index = function (self, n) + self[n] = stringRep(' ', n) + return self[n] +end}) + +local RESERVED = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['goto'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +local m = {} + +--- 打印表的结构 +---@param tbl table +---@param option table {optional = 'self'} +---@return string +function m.dump(tbl, option) + if not option then + option = {} + end + if type(tbl) ~= 'table' then + return ('%s'):format(tbl) + end + local lines = {} + local mark = {} + lines[#lines+1] = '{' + local function unpack(tbl, deep) + mark[tbl] = (mark[tbl] or 0) + 1 + local keys = {} + local keymap = {} + local integerFormat = '[%d]' + local alignment = 0 + if #tbl >= 10 then + local width = #tostring(#tbl) + integerFormat = ('[%%0%dd]'):format(mathCeil(width)) + end + for key in pairs(tbl) do + if type(key) == 'string' then + if not key:match('^[%a_][%w_]*$') + or RESERVED[key] + or option['longStringKey'] + then + keymap[key] = ('[%q]'):format(key) + else + keymap[key] = ('%s'):format(key) + end + elseif isInteger(key) then + keymap[key] = integerFormat:format(key) + else + keymap[key] = ('["<%s>"]'):format(tostring(key)) + end + keys[#keys+1] = key + if option['alignment'] then + if #keymap[key] > alignment then + alignment = #keymap[key] + end + end + end + local mt = getmetatable(tbl) + if not mt or not mt.__pairs then + if option['sorter'] then + option['sorter'](keys, keymap) + else + tableSort(keys, function (a, b) + return keymap[a] < keymap[b] + end) + end + end + for _, key in ipairs(keys) do + local keyWord = keymap[key] + if option['noArrayKey'] + and isInteger(key) + and key <= #tbl + then + keyWord = '' + else + if #keyWord < alignment then + keyWord = keyWord .. (' '):rep(alignment - #keyWord) .. ' = ' + else + keyWord = keyWord .. ' = ' + end + end + local value = tbl[key] + local tp = type(value) + if option['format'] and option['format'][key] then + lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, option['format'][key](value, unpack, deep+1)) + elseif tp == 'table' then + if mark[value] and mark[value] > 0 then + lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, option['loop'] or '"<Loop>"') + elseif deep >= (option['deep'] or mathHuge) then + lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, '"<Deep>"') + else + lines[#lines+1] = ('%s%s{'):format(TAB[deep+1], keyWord) + unpack(value, deep+1) + lines[#lines+1] = ('%s},'):format(TAB[deep+1]) + end + elseif tp == 'string' then + lines[#lines+1] = ('%s%s%q,'):format(TAB[deep+1], keyWord, value) + elseif tp == 'number' then + lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, (option['number'] or formatNumber)(value)) + elseif tp == 'nil' then + else + lines[#lines+1] = ('%s%s%s,'):format(TAB[deep+1], keyWord, tostring(value)) + end + end + mark[tbl] = mark[tbl] - 1 + end + unpack(tbl, 0) + lines[#lines+1] = '}' + return tableConcat(lines, '\r\n') +end + +--- 递归判断A与B是否相等 +---@param a any +---@param b any +---@return boolean +function m.equal(a, b) + local tp1 = type(a) + local tp2 = type(b) + if tp1 ~= tp2 then + return false + end + if tp1 == 'table' then + local mark = {} + for k, v in pairs(a) do + mark[k] = true + local res = m.equal(v, b[k]) + if not res then + return false + end + end + for k in pairs(b) do + if not mark[k] then + return false + end + end + return true + elseif tp1 == 'number' then + return mathAbs(a - b) <= 1e-10 + else + return a == b + end +end + +local function sortTable(tbl) + if not tbl then + tbl = {} + end + local mt = {} + local keys = {} + local mark = {} + local n = 0 + for key in next, tbl do + n=n+1;keys[n] = key + mark[key] = true + end + tableSort(keys) + function mt:__newindex(key, value) + rawset(self, key, value) + n=n+1;keys[n] = key + mark[key] = true + if type(value) == 'table' then + sortTable(value) + end + end + function mt:__pairs() + local list = {} + local m = 0 + for key in next, self do + if not mark[key] then + m=m+1;list[m] = key + end + end + if m > 0 then + move(keys, 1, n, m+1) + tableSort(list) + for i = 1, m do + local key = list[i] + keys[i] = key + mark[key] = true + end + n = n + m + end + local i = 0 + return function () + i = i + 1 + local key = keys[i] + return key, self[key] + end + end + + return setmetatable(tbl, mt) +end + +--- 创建一个有序表 +---@param tbl table {optional = 'self'} +---@return table +function m.container(tbl) + return sortTable(tbl) +end + +--- 读取文件 +---@param path string +function m.loadFile(path) + local f, e = ioOpen(path, 'rb') + if not f then + return nil, e + end + if f:read(3) ~= '\xEF\xBB\xBF' then + f:seek("set") + end + local buf = f:read 'a' + f:close() + return buf +end + +--- 写入文件 +---@param path string +---@param content string +function m.saveFile(path, content) + local f, e = ioOpen(path, "wb") + + if f then + f:write(content) + f:close() + return true + else + return false, e + end +end + +--- 计数器 +---@param init integer {optional = 'after'} +---@param step integer {optional = 'after'} +---@return fun():integer +function m.counter(init, step) + if not step then + step = 1 + end + local current = init and (init - 1) or 0 + return function () + current = current + step + return current + end +end + +--- 排序后遍历 +---@param t table +function m.sortPairs(t) + local keys = {} + for k in pairs(t) do + keys[#keys+1] = k + end + tableSort(keys) + local i = 0 + return function () + i = i + 1 + local k = keys[i] + return k, t[k] + end +end + +--- 深拷贝(不处理元表) +---@param source table +---@param target table {optional = 'self'} +function m.deepCopy(source, target) + local mark = {} + local function copy(a, b) + if type(a) ~= 'table' then + return a + end + if mark[a] then + return mark[a] + end + if not b then + b = {} + end + mark[a] = b + for k, v in pairs(a) do + b[copy(k)] = copy(v) + end + return b + end + return copy(source, target) +end + +--- 序列化 +function m.unpack(t) + local result = {} + local tid = 0 + local cache = {} + local function unpack(o) + local id = cache[o] + if not id then + tid = tid + 1 + id = tid + cache[o] = tid + if type(o) == 'table' then + local new = {} + result[tid] = new + for k, v in next, o do + new[unpack(k)] = unpack(v) + end + else + result[id] = o + end + end + return id + end + unpack(t) + return result +end + +--- 反序列化 +function m.pack(t) + local cache = {} + local function pack(id) + local o = cache[id] + if o then + return o + end + o = t[id] + if type(o) == 'table' then + local new = {} + cache[id] = new + for k, v in next, o do + new[pack(k)] = pack(v) + end + return new + else + cache[id] = o + return o + end + end + return pack(1) +end + +--- defer +local deferMT = { __close = function (self) self[1]() end } +function m.defer(callback) + return setmetatable({ callback }, deferMT) +end + +local esc = { + ["'"] = [[\']], + ['"'] = [[\"]], + ['\r'] = [[\r]], + ['\n'] = '\\\n', +} + +function m.viewString(str, quo) + if not quo then + if str:find('[\r\n]') then + quo = '[[' + elseif not str:find("'", 1, true) and str:find('"', 1, true) then + quo = "'" + else + quo = '"' + end + end + if quo == "'" then + str = str:gsub('[\000-\008\011-\012\014-\031\127]', function (char) + return ('\\%03d'):format(char:byte()) + end) + return quo .. str:gsub([=[['\r\n]]=], esc) .. quo + elseif quo == '"' then + str = str:gsub('[\000-\008\011-\012\014-\031\127]', function (char) + return ('\\%03d'):format(char:byte()) + end) + return quo .. str:gsub([=[["\r\n]]=], esc) .. quo + else + local eqnum = #quo - 2 + local fsymb = ']' .. ('='):rep(eqnum) .. ']' + if not str:find(fsymb, 1, true) then + str = str:gsub('[\000-\008\011-\012\014-\031\127]', '') + return quo .. str .. fsymb + end + for i = 0, 10 do + local fsymb = ']' .. ('='):rep(i) .. ']' + if not str:find(fsymb, 1, true) then + local ssymb = '[' .. ('='):rep(i) .. '[' + str = str:gsub('[\000-\008\011-\012\014-\031\127]', '') + return ssymb .. str .. fsymb + end + end + return m.viewString(str, '"') + end +end + +function m.viewLiteral(v) + local tp = type(v) + if tp == 'nil' then + return 'nil' + elseif tp == 'string' then + return m.viewString(v) + elseif tp == 'boolean' then + return tostring(v) + elseif tp == 'number' then + if isInteger(v) then + return tostring(v) + else + return formatNumber(v) + end + end + return nil +end + +function m.utf8Len(str, start, finish) + local len, pos = utf8Len(str, start, finish, true) + if len then + return len + end + return 1 + m.utf8Len(str, start, pos-1) + m.utf8Len(str, pos+1, finish) +end + +function m.revertTable(t) + local len = #t + if len <= 1 then + return t + end + for x = 1, len // 2 do + local y = len - x + 1 + t[x], t[y] = t[y], t[x] + end + return t +end + +function m.randomSortTable(t, max) + local len = #t + if len <= 1 then + return t + end + if not max or max > len then + max = len + end + for x = 1, max do + local y = mathRandom(len) + t[x], t[y] = t[y], t[x] + end + return t +end + +function m.tableMultiRemove(t, index) + local mark = {} + for i = 1, #index do + local v = index[i] + mark[v] = true + end + local offset = 0 + local me = 1 + local len = #t + while true do + local it = me + offset + if it > len then + for i = me, len do + t[i] = nil + end + break + end + if mark[it] then + offset = offset + 1 + else + if me ~= it then + t[me] = t[it] + end + me = me + 1 + end + end +end + +function m.eachLine(text) + local offset = 1 + local lineCount = 0 + return function () + if offset > #text then + return nil + end + lineCount = lineCount + 1 + local nl = text:find('[\r\n]', offset) + if not nl then + local lastLine = text:sub(offset) + offset = #text + 1 + return lastLine + end + local line = text:sub(offset, nl - 1) + if text:sub(nl, nl + 1) == '\r\n' then + offset = nl + 2 + else + offset = nl + 1 + end + return line + end +end + +return m diff --git a/script/vm/eachDef.lua b/script/vm/eachDef.lua new file mode 100644 index 00000000..5ff58889 --- /dev/null +++ b/script/vm/eachDef.lua @@ -0,0 +1,40 @@ +local vm = require 'vm.vm' +local guide = require 'parser.guide' +local files = require 'files' +local util = require 'utility' +local await = require 'await' + +local function eachDef(source, deep) + local results = {} + local lock = vm.lock('eachDef', source) + if not lock then + return results + end + + await.delay() + + local clock = os.clock() + local myResults, count = guide.requestDefinition(source, vm.interface, deep) + if DEVELOP and os.clock() - clock > 0.1 then + log.warn('requestDefinition', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) + end + vm.mergeResults(results, myResults) + + lock() + + return results +end + +function vm.getDefs(source, deep) + if guide.isGlobal(source) then + local key = guide.getKeyName(source) + return vm.getGlobalSets(key) + else + local cache = vm.getCache('eachDef')[source] + or eachDef(source, deep) + if deep then + vm.getCache('eachDef')[source] = cache + end + return cache + end +end diff --git a/script/vm/eachField.lua b/script/vm/eachField.lua new file mode 100644 index 00000000..ce0e3928 --- /dev/null +++ b/script/vm/eachField.lua @@ -0,0 +1,45 @@ +local vm = require 'vm.vm' +local guide = require 'parser.guide' +local await = require 'await' + +local function eachField(source, deep) + local unlock = vm.lock('eachField', source) + if not unlock then + return {} + end + + while source.type == 'paren' do + source = source.exp + if not source then + return {} + end + end + + await.delay() + local results = guide.requestFields(source, vm.interface, deep) + + unlock() + return results +end + +function vm.getFields(source, deep) + if source.special == '_G' then + return vm.getGlobals '*' + end + if guide.isGlobal(source) then + local name = guide.getKeyName(source) + local cache = vm.getCache('eachFieldOfGlobal')[name] + or vm.getCache('eachField')[source] + or eachField(source, 'deep') + vm.getCache('eachFieldOfGlobal')[name] = cache + vm.getCache('eachField')[source] = cache + return cache + else + local cache = vm.getCache('eachField')[source] + or eachField(source, deep) + if deep then + vm.getCache('eachField')[source] = cache + end + return cache + end +end diff --git a/script/vm/eachRef.lua b/script/vm/eachRef.lua new file mode 100644 index 00000000..4e735abf --- /dev/null +++ b/script/vm/eachRef.lua @@ -0,0 +1,39 @@ +local vm = require 'vm.vm' +local guide = require 'parser.guide' +local util = require 'utility' +local await = require 'await' + +local function getRefs(source, deep) + local results = {} + local lock = vm.lock('eachRef', source) + if not lock then + return results + end + + await.delay() + + local clock = os.clock() + local myResults, count = guide.requestReference(source, vm.interface, deep) + if DEVELOP and os.clock() - clock > 0.1 then + log.warn('requestReference', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) + end + vm.mergeResults(results, myResults) + + lock() + + return results +end + +function vm.getRefs(source, deep) + if guide.isGlobal(source) then + local key = guide.getKeyName(source) + return vm.getGlobals(key) + else + local cache = vm.getCache('eachRef')[source] + or getRefs(source, deep) + if deep then + vm.getCache('eachRef')[source] = cache + end + return cache + end +end diff --git a/script/vm/getClass.lua b/script/vm/getClass.lua new file mode 100644 index 00000000..a8bd7e40 --- /dev/null +++ b/script/vm/getClass.lua @@ -0,0 +1,62 @@ +local vm = require 'vm.vm' +local guide = require 'parser.guide' + +local function lookUpDocClass(source) + local infers = vm.getInfers(source, 'deep') + for _, infer in ipairs(infers) do + if infer.source.type == 'doc.class' + or infer.source.type == 'doc.type' then + return infer.type + end + end +end + +local function getClass(source, classes, depth, deep) + local docClass = lookUpDocClass(source) + if docClass then + classes[#classes+1] = docClass + return + end + if depth > 3 then + return + end + local value = guide.getObjectValue(source) or source + if not deep then + if value and value.type == 'string' then + classes[#classes+1] = value[1] + end + else + for _, src in ipairs(vm.getFields(value)) do + local key = vm.getKeyName(src) + if not key then + goto CONTINUE + end + local lkey = key:lower() + if lkey == 's|type' + or lkey == 's|__name' + or lkey == 's|name' + or lkey == 's|class' then + local value = guide.getObjectValue(src) + if value and value.type == 'string' then + classes[#classes+1] = value[1] + end + end + ::CONTINUE:: + end + end + if #classes ~= 0 then + return + end + vm.eachMeta(source, function (mt) + getClass(mt, classes, depth + 1, deep) + end) +end + +function vm.getClass(source, deep) + local classes = {} + getClass(source, classes, 1, deep) + if #classes == 0 then + return nil + end + return guide.mergeTypes(classes) +end diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua new file mode 100644 index 00000000..c06efc11 --- /dev/null +++ b/script/vm/getDocs.lua @@ -0,0 +1,175 @@ +local files = require 'files' +local util = require 'utility' +local guide = require 'parser.guide' +local vm = require 'vm.vm' +local config = require 'config' + +local function getTypesOfFile(uri) + local types = {} + local ast = files.getAst(uri) + if not ast or not ast.ast.docs then + return types + end + guide.eachSource(ast.ast.docs, function (src) + if src.type == 'doc.type.name' + or src.type == 'doc.class.name' + or src.type == 'doc.extends.name' + or src.type == 'doc.alias.name' then + local name = src[1] + if name then + if not types[name] then + types[name] = {} + end + types[name][#types[name]+1] = src + end + end + end) + return types +end + +local function getDocTypes(name) + local results = {} + for uri in files.eachFile() do + local cache = files.getCache(uri) + cache.classes = cache.classes or getTypesOfFile(uri) + if name == '*' then + for _, sources in util.sortPairs(cache.classes) do + for _, source in ipairs(sources) do + results[#results+1] = source + end + end + else + if cache.classes[name] then + for _, source in ipairs(cache.classes[name]) do + results[#results+1] = source + end + end + end + end + return results +end + +function vm.getDocEnums(doc, mark, results) + mark = mark or {} + if mark[doc] then + return nil + end + mark[doc] = true + results = results or {} + for _, enum in ipairs(doc.enums) do + results[#results+1] = enum + end + for _, resume in ipairs(doc.resumes) do + results[#results+1] = resume + end + for _, unit in ipairs(doc.types) do + if unit.type == 'doc.type.name' then + for _, other in ipairs(vm.getDocTypes(unit[1])) do + if other.type == 'doc.alias.name' then + vm.getDocEnums(other.parent.extends, mark, results) + end + end + end + end + return results +end + +function vm.getDocTypes(name) + local cache = vm.getCache('getDocTypes')[name] + if cache ~= nil then + return cache + end + cache = getDocTypes(name) + vm.getCache('getDocTypes')[name] = cache + return cache +end + +function vm.isMetaFile(uri) + local status = files.getAst(uri) + if not status then + return false + end + local cache = files.getCache(uri) + if cache.isMeta ~= nil then + return cache.isMeta + end + cache.isMeta = false + if not status.ast.docs then + return false + end + for _, doc in ipairs(status.ast.docs) do + if doc.type == 'doc.meta' then + cache.isMeta = true + return true + end + end + return false +end + +function vm.getValidVersions(doc) + if doc.type ~= 'doc.version' then + return + end + local valids = { + ['Lua 5.1'] = false, + ['Lua 5.2'] = false, + ['Lua 5.3'] = false, + ['Lua 5.4'] = false, + ['LuaJIT'] = false, + } + for _, version in ipairs(doc.versions) do + if version.ge and type(version.version) == 'number' then + for ver in pairs(valids) do + local verNumber = tonumber(ver:sub(-3)) + if verNumber and verNumber >= version.version then + valids[ver] = true + end + end + elseif version.le and type(version.version) == 'number' then + for ver in pairs(valids) do + local verNumber = tonumber(ver:sub(-3)) + if verNumber and verNumber <= version.version then + valids[ver] = true + end + end + elseif type(version.version) == 'number' then + valids[('Lua %.1f'):format(version.version)] = true + elseif 'JIT' == version.version then + valids['LuaJIT'] = true + end + end + if valids['Lua 5.1'] then + valids['LuaJIT'] = true + end + return valids +end + +local function isDeprecated(value) + if not value.bindDocs then + return false + end + for _, doc in ipairs(value.bindDocs) do + if doc.type == 'doc.deprecated' then + return true + elseif doc.type == 'doc.version' then + local valids = vm.getValidVersions(doc) + if not valids[config.config.runtime.version] then + return true + end + end + end + return false +end + +function vm.isDeprecated(value) + local defs = vm.getDefs(value, 'deep') + if #defs == 0 then + return false + end + for _, def in ipairs(defs) do + if not isDeprecated(def) then + return false + end + end + return true +end diff --git a/script/vm/getGlobals.lua b/script/vm/getGlobals.lua new file mode 100644 index 00000000..83d6a5e6 --- /dev/null +++ b/script/vm/getGlobals.lua @@ -0,0 +1,192 @@ +local guide = require 'parser.guide' +local vm = require 'vm.vm' +local files = require 'files' +local util = require 'utility' +local config = require 'config' + +local function getGlobalsOfFile(uri) + local cache = files.getCache(uri) + if cache.globals then + return cache.globals + end + local globals = {} + cache.globals = globals + local ast = files.getAst(uri) + if not ast then + return globals + end + local results = guide.findGlobals(ast.ast) + local mark = {} + for _, res in ipairs(results) do + if mark[res] then + goto CONTINUE + end + mark[res] = true + local name = guide.getSimpleName(res) + if name then + if not globals[name] then + globals[name] = {} + end + globals[name][#globals[name]+1] = res + end + ::CONTINUE:: + end + return globals +end + +local function getGlobalSetsOfFile(uri) + local cache = files.getCache(uri) + if cache.globalSets then + return cache.globalSets + end + local globals = {} + cache.globalSets = globals + local ast = files.getAst(uri) + if not ast then + return globals + end + local results = guide.findGlobals(ast.ast) + local mark = {} + for _, res in ipairs(results) do + if mark[res] then + goto CONTINUE + end + mark[res] = true + if vm.isSet(res) then + local name = guide.getSimpleName(res) + if name then + if not globals[name] then + globals[name] = {} + end + globals[name][#globals[name]+1] = res + end + end + ::CONTINUE:: + end + return globals +end + +local function getGlobals(name) + local results = {} + for uri in files.eachFile() do + local globals = getGlobalsOfFile(uri) + if name == '*' then + for _, sources in util.sortPairs(globals) do + for _, source in ipairs(sources) do + results[#results+1] = source + end + end + else + if globals[name] then + for _, source in ipairs(globals[name]) do + results[#results+1] = source + end + end + end + end + return results +end + +local function getGlobalSets(name) + local results = {} + for uri in files.eachFile() do + local globals = getGlobalSetsOfFile(uri) + if name == '*' then + for _, sources in util.sortPairs(globals) do + for _, source in ipairs(sources) do + results[#results+1] = source + end + end + else + if globals[name] then + for _, source in ipairs(globals[name]) do + results[#results+1] = source + end + end + end + end + return results +end + +local function fastGetAnyGlobals() + local results = {} + local mark = {} + for uri in files.eachFile() do + --local globalSets = getGlobalsOfFile(uri) + --for destName, sources in util.sortPairs(globalSets) do + -- if not mark[destName] then + -- mark[destName] = true + -- results[#results+1] = sources[1] + -- end + --end + local globals = getGlobalsOfFile(uri) + for destName, sources in util.sortPairs(globals) do + if not mark[destName] then + mark[destName] = true + results[#results+1] = sources[1] + end + end + end + return results +end + +local function fastGetAnyGlobalSets() + local results = {} + local mark = {} + for uri in files.eachFile() do + local globals = getGlobalSetsOfFile(uri) + for destName, sources in util.sortPairs(globals) do + if not mark[destName] then + mark[destName] = true + results[#results+1] = sources[1] + end + end + end + return results +end + +function vm.getGlobals(key) + if key == '*' and config.config.intelliSense.fastGlobal then + local cache = vm.getCache('fastGetAnyGlobals')[key] + if cache ~= nil then + return cache + end + cache = fastGetAnyGlobals() + vm.getCache('fastGetAnyGlobals')[key] = cache + return cache + else + local cache = vm.getCache('getGlobals')[key] + if cache ~= nil then + return cache + end + cache = getGlobals(key) + vm.getCache('getGlobals')[key] = cache + return cache + end +end + +function vm.getGlobalSets(key) + if key == '*' and config.config.intelliSense.fastGlobal then + local cache = vm.getCache('fastGetAnyGlobalSets')[key] + if cache ~= nil then + return cache + end + cache = fastGetAnyGlobalSets() + vm.getCache('fastGetAnyGlobalSets')[key] = cache + return cache + end + local cache = vm.getCache('getGlobalSets')[key] + if cache ~= nil then + return cache + end + cache = getGlobalSets(key) + vm.getCache('getGlobalSets')[key] = cache + return cache +end + +files.watch(function (ev, uri) + if ev == 'update' then + getGlobalsOfFile(uri) + getGlobalSetsOfFile(uri) + end +end) diff --git a/script/vm/getInfer.lua b/script/vm/getInfer.lua new file mode 100644 index 00000000..5272c389 --- /dev/null +++ b/script/vm/getInfer.lua @@ -0,0 +1,96 @@ +local vm = require 'vm.vm' +local guide = require 'parser.guide' +local util = require 'utility' +local await = require 'await' + +NIL = setmetatable({'<nil>'}, { __tostring = function () return 'nil' end }) + +--- 是否包含某种类型 +function vm.hasType(source, type, deep) + local defs = vm.getDefs(source, deep) + for i = 1, #defs do + local def = defs[i] + local value = guide.getObjectValue(def) or def + if value.type == type then + return true + end + end + return false +end + +--- 是否包含某种类型 +function vm.hasInferType(source, type, deep) + local infers = vm.getInfers(source, deep) + for i = 1, #infers do + local infer = infers[i] + if infer.type == type then + return true + end + end + return false +end + +function vm.getInferType(source, deep) + local infers = vm.getInfers(source, deep) + return guide.viewInferType(infers) +end + +function vm.getInferLiteral(source, deep) + local infers = vm.getInfers(source, deep) + local literals = {} + local mark = {} + for _, infer in ipairs(infers) do + local value = infer.value + if value and not mark[value] then + mark[value] = true + literals[#literals+1] = util.viewLiteral(value) + end + end + if #literals == 0 then + return nil + end + table.sort(literals) + return table.concat(literals, '|') +end + +local function getInfers(source, deep) + local results = {} + local lock = vm.lock('getInfers', source) + if not lock then + return results + end + + await.delay() + + local clock = os.clock() + local myResults, count = guide.requestInfer(source, vm.interface, deep) + if DEVELOP and os.clock() - clock > 0.1 then + log.warn('requestInfer', count, os.clock() - clock, guide.getUri(source), util.dump(source, { deep = 1 })) + end + vm.mergeResults(results, myResults) + + lock() + + return results +end + +--- 获取对象的值 +--- 会尝试穿透函数调用 +function vm.getInfers(source, deep) + if guide.isGlobal(source) then + local name = guide.getKeyName(source) + local cache = vm.getCache('getInfersOfGlobal')[name] + or vm.getCache('getInfers')[source] + or getInfers(source, 'deep') + vm.getCache('getInfersOfGlobal')[name] = cache + vm.getCache('getInfers')[source] = cache + return cache + else + local cache = vm.getCache('getInfers')[source] + or getInfers(source, deep) + if deep then + vm.getCache('getInfers')[source] = cache + end + return cache + end +end diff --git a/script/vm/getLibrary.lua b/script/vm/getLibrary.lua new file mode 100644 index 00000000..5803a73b --- /dev/null +++ b/script/vm/getLibrary.lua @@ -0,0 +1,32 @@ +local vm = require 'vm.vm' + +function vm.getLibraryName(source, deep) + local defs = vm.getDefs(source, deep) + for _, def in ipairs(defs) do + if def.special then + return def.special + end + end + return nil +end + +local globalLibraryNames = { + 'arg', 'assert', 'collectgarbage', 'dofile', '_G', 'getfenv', + 'getmetatable', 'ipairs', 'load', 'loadfile', 'loadstring', + 'module', 'next', 'pairs', 'pcall', 'print', 'rawequal', + 'rawget', 'rawlen', 'rawset', 'select', 'setfenv', + 'setmetatable', 'tonumber', 'tostring', 'type', '_VERSION', + 'warn', 'xpcall', 'require', 'unpack', 'bit32', 'coroutine', + 'debug', 'io', 'math', 'os', 'package', 'string', 'table', + 'utf8', +} +local globalLibraryNamesMap +function vm.isGlobalLibraryName(name) + if not globalLibraryNamesMap then + globalLibraryNamesMap = {} + for _, v in ipairs(globalLibraryNames) do + globalLibraryNamesMap[v] = true + end + end + return globalLibraryNamesMap[name] or false +end diff --git a/script/vm/getLinks.lua b/script/vm/getLinks.lua new file mode 100644 index 00000000..0bb1c6ff --- /dev/null +++ b/script/vm/getLinks.lua @@ -0,0 +1,61 @@ +local guide = require 'parser.guide' +local vm = require 'vm.vm' +local files = require 'files' + +local function getFileLinks(uri) + local ws = require 'workspace' + local links = {} + local ast = files.getAst(uri) + if not ast then + return links + end + guide.eachSpecialOf(ast.ast, 'require', function (source) + local call = source.parent + if not call or call.type ~= 'call' then + return + end + local args = call.args + if not args[1] or args[1].type ~= 'string' then + return + end + local uris = ws.findUrisByRequirePath(args[1][1]) + for _, u in ipairs(uris) do + u = files.asKey(u) + if not links[u] then + links[u] = {} + end + links[u][#links[u]+1] = call + end + end) + return links +end + +local function getLinksTo(uri) + uri = files.asKey(uri) + local links = {} + for u in files.eachFile() do + local ls = vm.getFileLinks(u) + if ls[uri] then + for _, l in ipairs(ls[uri]) do + links[#links+1] = l + end + end + end + return links +end + +function vm.getLinksTo(uri) + local cache = vm.getCache('getLinksTo')[uri] + if cache ~= nil then + return cache + end + cache = getLinksTo(uri) + vm.getCache('getLinksTo')[uri] = cache + return cache +end + +function vm.getFileLinks(uri) + local cache = files.getCache(uri) + cache.links = cache.links or getFileLinks(uri) + return cache.links +end diff --git a/script/vm/getMeta.lua b/script/vm/getMeta.lua new file mode 100644 index 00000000..aebef1a7 --- /dev/null +++ b/script/vm/getMeta.lua @@ -0,0 +1,52 @@ +local vm = require 'vm.vm' + +local function eachMetaOfArg1(source, callback) + local node, index = vm.getArgInfo(source) + local special = vm.getSpecial(node) + if special == 'setmetatable' and index == 1 then + local mt = node.next.args[2] + if mt then + callback(mt) + end + end +end + +local function eachMetaOfRecv(source, callback) + if not source or source.type ~= 'select' then + return + end + if source.index ~= 1 then + return + end + local call = source.vararg + if not call or call.type ~= 'call' then + return + end + local special = vm.getSpecial(call.node) + if special ~= 'setmetatable' then + return + end + local mt = call.args[2] + if mt then + callback(mt) + end +end + +function vm.eachMetaValue(source, callback) + vm.eachMeta(source, function (mt) + for _, src in ipairs(vm.getFields(mt)) do + if vm.getKeyName(src) == 's|__index' then + if src.value then + for _, valueSrc in ipairs(vm.getFields(src.value)) do + callback(valueSrc) + end + end + end + end + end) +end + +function vm.eachMeta(source, callback) + eachMetaOfArg1(source, callback) + eachMetaOfRecv(source.value, callback) +end diff --git a/script/vm/guideInterface.lua b/script/vm/guideInterface.lua new file mode 100644 index 00000000..e646def8 --- /dev/null +++ b/script/vm/guideInterface.lua @@ -0,0 +1,106 @@ +local vm = require 'vm.vm' +local files = require 'files' +local ws = require 'workspace' +local guide = require 'parser.guide' +local await = require 'await' +local config = require 'config' + +local m = {} + +function m.searchFileReturn(results, ast, index) + local returns = ast.returns + if not returns then + return + end + for _, ret in ipairs(returns) do + local exp = ret[index] + if exp then + vm.mergeResults(results, { exp }) + end + end +end + +function m.require(args, index) + local reqName = args[1] and args[1][1] + if not reqName then + return nil + end + local results = {} + local myUri = guide.getUri(args[1]) + local uris = ws.findUrisByRequirePath(reqName) + for _, uri in ipairs(uris) do + if not files.eq(myUri, uri) then + local ast = files.getAst(uri) + if ast then + m.searchFileReturn(results, ast.ast, index) + end + end + end + + return results +end + +function m.dofile(args, index) + local reqName = args[1] and args[1][1] + if not reqName then + return + end + local results = {} + local myUri = guide.getUri(args[1]) + local uris = ws.findUrisByFilePath(reqName) + for _, uri in ipairs(uris) do + if not files.eq(myUri, uri) then + local ast = files.getAst(uri) + if ast then + m.searchFileReturn(results, ast.ast, index) + end + end + end + return results +end + +vm.interface = {} + +-- 向前寻找引用的层数限制,一般情况下都为0 +-- 在自动完成/漂浮提示等情况时设置为5(需要清空缓存) +-- 在查找引用时设置为10(需要清空缓存) +vm.interface.searchLevel = 0 + +function vm.interface.call(func, args, index) + if func.special == 'require' and index == 1 then + await.delay() + return m.require(args, index) + end + if func.special == 'dofile' then + await.delay() + return m.dofile(args, index) + end +end + +function vm.interface.global(name) + await.delay() + return vm.getGlobals(name) +end + +function vm.interface.docType(name) + await.delay() + return vm.getDocTypes(name) +end + +function vm.interface.link(uri) + await.delay() + return vm.getLinksTo(uri) +end + +function vm.interface.index(obj) + return nil +end + +function vm.interface.cache() + await.delay() + return vm.getCache('cache') +end + +function vm.interface.getSearchDepth() + return config.config.intelliSense.searchDepth +end diff --git a/script/vm/init.lua b/script/vm/init.lua new file mode 100644 index 00000000..b9e8e147 --- /dev/null +++ b/script/vm/init.lua @@ -0,0 +1,13 @@ +local vm = require 'vm.vm' +require 'vm.getGlobals' +require 'vm.getDocs' +require 'vm.getLibrary' +require 'vm.getInfer' +require 'vm.getClass' +require 'vm.getMeta' +require 'vm.eachField' +require 'vm.eachDef' +require 'vm.eachRef' +require 'vm.getLinks' +require 'vm.guideInterface' +return vm diff --git a/script/vm/vm.lua b/script/vm/vm.lua new file mode 100644 index 00000000..e942d55e --- /dev/null +++ b/script/vm/vm.lua @@ -0,0 +1,167 @@ +local guide = require 'parser.guide' +local util = require 'utility' +local files = require 'files' +local timer = require 'timer' + +local setmetatable = setmetatable +local assert = assert +local require = require +local type = type +local running = coroutine.running +local ipairs = ipairs +local log = log +local xpcall = xpcall +local mathHuge = math.huge +local collectgarbage = collectgarbage + +_ENV = nil + +---@class vm +local m = {} + +function m.lock(tp, source) + local co = running() + local master = m.locked[co] + if not master then + master = {} + m.locked[co] = master + end + if not master[tp] then + master[tp] = {} + end + if master[tp][source] then + return nil + end + master[tp][source] = true + return function () + master[tp][source] = nil + end +end + +function m.isSet(src) + local tp = src.type + if tp == 'setglobal' + or tp == 'local' + or tp == 'setlocal' + or tp == 'setfield' + or tp == 'setmethod' + or tp == 'setindex' + or tp == 'tablefield' + or tp == 'tableindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(src.node) + if special == 'rawset' then + return true + end + end + return false +end + +function m.isGet(src) + local tp = src.type + if tp == 'getglobal' + or tp == 'getlocal' + or tp == 'getfield' + or tp == 'getmethod' + or tp == 'getindex' then + return true + end + if tp == 'call' then + local special = m.getSpecial(src.node) + if special == 'rawget' then + return true + end + end + return false +end + +function m.getArgInfo(source) + local callargs = source.parent + if not callargs or callargs.type ~= 'callargs' then + return nil + end + local call = callargs.parent + if not call or call.type ~= 'call' then + return nil + end + for i = 1, #callargs do + if callargs[i] == source then + return call.node, i + end + end + return nil +end + +function m.getSpecial(source) + if not source then + return nil + end + return source.special +end + +function m.getKeyName(source) + if not source then + return nil + end + if source.type == 'call' then + local special = m.getSpecial(source.node) + if special == 'rawset' + or special == 'rawget' then + return guide.getKeyNameOfLiteral(source.args[2]) + end + end + return guide.getKeyName(source) +end + +function m.mergeResults(a, b) + for _, r in ipairs(b) do + if not a[r] then + a[r] = true + a[#a+1] = r + end + end + return a +end + +m.cacheTracker = setmetatable({}, { __mode = 'kv' }) + +function m.flushCache() + if m.cache then + m.cache.dead = true + end + m.cacheVersion = files.globalVersion + m.cache = {} + m.cacheActiveTime = mathHuge + m.locked = setmetatable({}, { __mode = 'k' }) + m.cacheTracker[m.cache] = true +end + +function m.getCache(name) + if m.cacheVersion ~= files.globalVersion then + m.flushCache() + end + m.cacheActiveTime = timer.clock() + if not m.cache[name] then + m.cache[name] = {} + end + return m.cache[name] +end + +local function init() + m.flushCache() + + -- 可以在一段时间不活动后清空缓存,不过目前看起来没有必要 + --timer.loop(1, function () + -- if timer.clock() - m.cacheActiveTime > 10.0 then + -- log.info('Flush cache: Inactive') + -- m.flushCache() + -- collectgarbage() + -- end + --end) +end + +xpcall(init, log.error) + +return m diff --git a/script/without-check-nil.lua b/script/without-check-nil.lua new file mode 100644 index 00000000..cc7da9d4 --- /dev/null +++ b/script/without-check-nil.lua @@ -0,0 +1,126 @@ +local m = {} + +local mt = {} +mt.__add = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a + b +end +mt.__sub = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a - b +end +mt.__mul = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a * b +end +mt.__div = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a / b +end +mt.__mod = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a % b +end +mt.__pow = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a ^ b +end +mt.__unm = function () + return 0 +end +mt.__concat = function (a, b) + if a == nil then a = '' end + if b == nil then b = '' end + return a .. b +end +mt.__len = function () + return 0 +end +mt.__lt = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a < b +end +mt.__le = function (a, b) + if a == nil then a = 0 end + if b == nil then b = 0 end + return a <= b +end +mt.__index = function () end +mt.__newindex = function () end +mt.__call = function () end +mt.__pairs = function () end +mt.__ipairs = function () end +if _VERSION == 'Lua 5.3' or _VERSION == 'Lua 5.4' then + mt.__idiv = load[[ + local a, b = ... + if a == nil then a = 0 end + if b == nil then b = 0 end + return a // b + ]] + mt.__band = load[[ + local a, b = ... + if a == nil then a = 0 end + if b == nil then b = 0 end + return a & b + ]] + mt.__bor = load[[ + local a, b = ... + if a == nil then a = 0 end + if b == nil then b = 0 end + return a | b + ]] + mt.__bxor = load[[ + local a, b = ... + if a == nil then a = 0 end + if b == nil then b = 0 end + return a ~ b + ]] + mt.__bnot = load[[ + return ~ 0 + ]] + mt.__shl = load[[ + local a, b = ... + if a == nil then a = 0 end + if b == nil then b = 0 end + return a << b + ]] + mt.__shr = load[[ + local a, b = ... + if a == nil then a = 0 end + if b == nil then b = 0 end + return a >> b + ]] +end + +for event, func in pairs(mt) do + mt[event] = function (...) + local watch = m.watch + if not watch then + return func(...) + end + local care, result = watch(event, ...) + if not care then + return func(...) + end + return result + end +end + +function m.enable() + debug.setmetatable(nil, mt) +end + +function m.disable() + if debug.getmetatable(nil) == mt then + debug.setmetatable(nil, nil) + end +end + +return m diff --git a/script/workspace/init.lua b/script/workspace/init.lua new file mode 100644 index 00000000..7cbe15d7 --- /dev/null +++ b/script/workspace/init.lua @@ -0,0 +1,3 @@ +local workspace = require 'workspace.workspace' + +return workspace diff --git a/script/workspace/require-path.lua b/script/workspace/require-path.lua new file mode 100644 index 00000000..cfdc0455 --- /dev/null +++ b/script/workspace/require-path.lua @@ -0,0 +1,74 @@ +local platform = require 'bee.platform' +local files = require 'files' +local furi = require 'file-uri' +local m = {} + +m.cache = {} + +--- `aaa/bbb/ccc.lua` 与 `?.lua` 将返回 `aaa.bbb.cccc` +local function getOnePath(path, searcher) + local stemPath = path + : gsub('%.[^%.]+$', '') + : gsub('[/\\]+', '.') + local stemSearcher = searcher + : gsub('%.[^%.]+$', '') + : gsub('[/\\]+', '.') + local start = stemSearcher:match '()%?' or 1 + for pos = start, #stemPath do + local word = stemPath:sub(start, pos) + local newSearcher = stemSearcher:gsub('%?', word) + if newSearcher == stemPath then + return word + end + end + return nil +end + +function m.getVisiblePath(path, searchers) + path = path:gsub('^[/\\]+', '') + local uri = furi.encode(path) + local libraryPath = files.getLibraryPath(uri) + if not m.cache[path] then + local result = {} + m.cache[path] = result + local pos = 1 + if libraryPath then + pos = #libraryPath + 2 + end + repeat + local cutedPath = path:sub(pos) + local head + if pos > 1 then + head = path:sub(1, pos - 1) + end + pos = path:match('[/\\]+()', pos) + for _, searcher in ipairs(searchers) do + if platform.OS == 'Windows' then + searcher = searcher:gsub('[/\\]+', '\\') + else + searcher = searcher:gsub('[/\\]+', '/') + end + local expect = getOnePath(cutedPath, searcher) + if expect then + if head then + searcher = head .. searcher + end + result[#result+1] = { + searcher = searcher, + expect = expect, + } + end + end + if not pos then + break + end + until not pos + end + return m.cache[path] +end + +function m.flush() + m.cache = {} +end + +return m diff --git a/script/workspace/workspace.lua b/script/workspace/workspace.lua new file mode 100644 index 00000000..96b982b8 --- /dev/null +++ b/script/workspace/workspace.lua @@ -0,0 +1,330 @@ +local pub = require 'pub' +local fs = require 'bee.filesystem' +local furi = require 'file-uri' +local files = require 'files' +local config = require 'config' +local glob = require 'glob' +local platform = require 'bee.platform' +local await = require 'await' +local rpath = require 'workspace.require-path' +local proto = require 'proto.proto' +local lang = require 'language' +local library = require 'library' +local sp = require 'bee.subprocess' + +local m = {} +m.type = 'workspace' +m.nativeVersion = -1 +m.libraryVersion = -1 +m.nativeMatcher = nil +m.requireCache = {} +m.matchOption = { + ignoreCase = platform.OS == 'Windows', +} + +--- 初始化工作区 +function m.init(uri) + log.info('Workspace inited: ', uri) + if not uri then + return + end + m.uri = uri + m.path = m.normalize(furi.decode(uri)) + local logPath = ROOT / 'log' / (uri:gsub('[/:]+', '_') .. '.log') + log.info('Log path: ', logPath) + log.init(ROOT, logPath) +end + +local function interfaceFactory(root) + return { + type = function (path) + if fs.is_directory(fs.path(root .. '/' .. path)) then + return 'directory' + else + return 'file' + end + end, + list = function (path) + local fullPath = fs.path(root .. '/' .. path) + if not fs.exists(fullPath) then + return nil + end + local paths = {} + pcall(function () + for fullpath in fullPath:list_directory() do + paths[#paths+1] = fullpath:string() + end + end) + return paths + end + } +end + +--- 创建排除文件匹配器 +function m.getNativeMatcher() + if not m.path then + return nil + end + if m.nativeVersion == config.version then + return m.nativeMatcher + end + + local interface = interfaceFactory(m.path) + local pattern = {} + -- config.workspace.ignoreDir + for path in pairs(config.config.workspace.ignoreDir) do + log.info('Ignore directory:', path) + pattern[#pattern+1] = path + end + -- config.files.exclude + for path, ignore in pairs(config.other.exclude) do + if ignore then + log.info('Ignore by exclude:', path) + pattern[#pattern+1] = path + end + end + -- config.workspace.ignoreSubmodules + if config.config.workspace.ignoreSubmodules then + local buf = pub.awaitTask('loadFile', furi.encode(m.path .. '/.gitmodules')) + if buf then + for path in buf:gmatch('path = ([^\r\n]+)') do + log.info('Ignore by .gitmodules:', path) + pattern[#pattern+1] = path + end + end + end + -- config.workspace.useGitIgnore + if config.config.workspace.useGitIgnore then + local buf = pub.awaitTask('loadFile', furi.encode(m.path .. '/.gitignore')) + if buf then + for line in buf:gmatch '[^\r\n]+' do + log.info('Ignore by .gitignore:', line) + pattern[#pattern+1] = line + end + end + buf = pub.awaitTask('loadFile', furi.encode(m.path .. '/.git/info/exclude')) + if buf then + for line in buf:gmatch '[^\r\n]+' do + log.info('Ignore by .git/info/exclude:', line) + pattern[#pattern+1] = line + end + end + end + -- config.workspace.library + for path in pairs(config.config.workspace.library) do + log.info('Ignore by library:', path) + pattern[#pattern+1] = path + end + + m.nativeMatcher = glob.gitignore(pattern, m.matchOption, interface) + + m.nativeVersion = config.version + return m.nativeMatcher +end + +--- 创建代码库筛选器 +function m.getLibraryMatchers() + if m.libraryVersion == config.version then + return m.libraryMatchers + end + + local librarys = {} + for path, pattern in pairs(config.config.workspace.library) do + librarys[path] = pattern + end + if library.metaPath then + librarys[library.metaPath] = true + end + m.libraryMatchers = {} + for path, pattern in pairs(librarys) do + local nPath = fs.absolute(fs.path(path)):string() + local matcher = glob.gitignore(pattern, m.matchOption) + if platform.OS == 'Windows' then + matcher:setOption 'ignoreCase' + end + log.debug('getLibraryMatchers', path, nPath) + m.libraryMatchers[#m.libraryMatchers+1] = { + path = nPath, + matcher = matcher + } + end + + m.libraryVersion = config.version + return m.libraryMatchers +end + +--- 文件是否被忽略 +function m.isIgnored(uri) + local path = furi.decode(uri) + local ignore = m.getNativeMatcher() + if not ignore then + return false + end + return ignore(path) +end + +local function loadFileFactory(root, progress, isLibrary) + return function (path) + local uri = furi.encode(root .. '/' .. path) + if not files.isLua(uri) then + return + end + if progress.preload >= config.config.workspace.maxPreload then + if not m.hasHitMaxPreload then + m.hasHitMaxPreload = true + proto.notify('window/showMessage', { + type = 3, + message = lang.script('MWS_MAX_PRELOAD', config.config.workspace.maxPreload), + }) + end + return + end + if not isLibrary then + progress.preload = progress.preload + 1 + end + progress.max = progress.max + 1 + pub.task('loadFile', uri, function (text) + progress.read = progress.read + 1 + --log.info(('Preload file at: %s , size = %.3f KB'):format(uri, #text / 1000.0)) + if isLibrary then + files.setLibraryPath(uri, root) + end + files.setText(uri, text) + end) + end +end + +--- 预读工作区内所有文件 +function m.awaitPreload() + await.close 'preload' + await.setID 'preload' + local progress = { + max = 0, + read = 0, + preload = 0, + } + log.info('Preload start.') + local nativeLoader = loadFileFactory(m.path, progress) + local native = m.getNativeMatcher() + local librarys = m.getLibraryMatchers() + if native then + native:scan(nativeLoader) + end + for _, library in ipairs(librarys) do + local libraryInterface = interfaceFactory(library.path) + local libraryLoader = loadFileFactory(library.path, progress, true) + for k, v in pairs(libraryInterface) do + library.matcher:setInterface(k, v) + end + library.matcher:scan(libraryLoader) + end + + log.info(('Found %d files.'):format(progress.max)) + while true do + log.info(('Loaded %d/%d files'):format(progress.read, progress.max)) + if progress.read >= progress.max then + break + end + await.sleep(0.1) + end + + --for i = 1, 100 do + -- await.sleep(0.1) + -- log.info('sleep', i) + --end + + log.info('Preload finish.') + + local diagnostic = require 'provider.diagnostic' + diagnostic.start() +end + +--- 查找符合指定file path的所有uri +---@param path string +function m.findUrisByFilePath(path) + if type(path) ~= 'string' then + return {} + end + local results = {} + local posts = {} + for uri in files.eachFile() do + local pathLen = #path + local uriLen = #uri + local seg = uri:sub(uriLen - pathLen, uriLen - pathLen) + if seg == '/' or seg == '\\' or seg == '' then + local see = uri:sub(uriLen - pathLen + 1, uriLen) + if files.eq(see, path) then + results[#results+1] = uri + posts[uri] = files.getOriginUri(uri):sub(1, uriLen - pathLen) + end + end + end + return results, posts +end + +--- 查找符合指定require path的所有uri +---@param path string +function m.findUrisByRequirePath(path) + if type(path) ~= 'string' then + return {} + end + local results = {} + local mark = {} + local searchers = {} + local input = path:gsub('%.', '/') + :gsub('%%', '%%%%') + for _, luapath in ipairs(config.config.runtime.path) do + local part = luapath:gsub('%?', input) + local uris, posts = m.findUrisByFilePath(part) + for _, uri in ipairs(uris) do + if not mark[uri] then + mark[uri] = true + results[#results+1] = uri + searchers[uri] = posts[uri] .. luapath + end + end + end + return results, searchers +end + +function m.normalize(path) + if platform.OS == 'Windows' then + path = path:gsub('[/\\]+', '\\') + :gsub('^%a+%:', function (str) + return str:upper() + end) + else + path = path:gsub('[/\\]+', '/') + end + return path:gsub('^[/\\]+', '') +end + +function m.getRelativePath(uri) + local path = furi.decode(uri) + if not m.path then + return m.normalize(path) + end + local _, pos = m.normalize(path):lower():find(m.path:lower(), 1, true) + if pos then + return m.normalize(path:sub(pos + 1)) + else + return m.normalize(path) + end +end + +function m.reload() + files.flushAllLibrary() + files.removeAllClosed() + rpath.flush() + await.call(m.awaitPreload) +end + +files.watch(function (ev, uri) + if ev == 'close' + and m.isIgnored(uri) + and not files.isLibrary(uri) then + files.remove(uri) + end +end) + +return m |