diff options
Diffstat (limited to 'script/await.lua')
-rw-r--r-- | script/await.lua | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/script/await.lua b/script/await.lua new file mode 100644 index 00000000..d8e2a9ad --- /dev/null +++ b/script/await.lua @@ -0,0 +1,227 @@ +local timer = require 'timer' +local util = require 'utility' + +---@class await +local m = {} +m.type = 'await' + +m.coMap = setmetatable({}, { __mode = 'k' }) +m.idMap = {} +m.delayQueue = {} +m.delayQueueIndex = 1 +m.watchList = {} +m._enable = true + +--- 设置错误处理器 +---@param errHandle function {comment = '当有错误发生时,会以错误堆栈为参数调用该函数'} +function m.setErrorHandle(errHandle) + m.errorHandle = errHandle +end + +function m.checkResult(co, ...) + local suc, err = ... + if not suc and m.errorHandle then + m.errorHandle(debug.traceback(co, err)) + end + return ... +end + +--- 创建一个任务 +function m.call(callback, ...) + local co = coroutine.create(callback) + local closers = {} + m.coMap[co] = { + closers = closers, + priority = false, + } + for i = 1, select('#', ...) do + local id = select(i, ...) + if not id then + break + end + m.setID(id, co) + end + + local currentCo = coroutine.running() + local current = m.coMap[currentCo] + if current then + for closer in pairs(current.closers) do + closers[closer] = true + closer(co) + end + end + return m.checkResult(co, coroutine.resume(co)) +end + +--- 创建一个任务,并挂起当前线程,当任务完成后再延续当前线程/若任务被关闭,则返回nil +function m.await(callback, ...) + if not coroutine.isyieldable() then + return callback(...) + end + return m.wait(function (waker, ...) + m.call(function () + local returnNil <close> = util.defer(waker) + waker(callback()) + end, ...) + end, ...) +end + +--- 设置一个id,用于批量关闭任务 +function m.setID(id, co) + co = co or coroutine.running() + if not m.idMap[id] then + m.idMap[id] = setmetatable({}, { __mode = 'k' }) + end + m.idMap[id][co] = true +end + +--- 根据id批量关闭任务 +function m.close(id) + local map = m.idMap[id] + if not map then + return + end + local count = 0 + for co in pairs(map) do + map[co] = nil + coroutine.close(co) + count = count + 1 + end + log.debug('Close await:', id, count) +end + +function m.hasID(id, co) + co = co or coroutine.running() + return m.idMap[id] and m.idMap[id][co] ~= nil +end + +--- 休眠一段时间 +---@param time number +function m.sleep(time) + if not coroutine.isyieldable() then + if m.errorHandle then + m.errorHandle(debug.traceback('Cannot yield')) + end + return + end + local co = coroutine.running() + timer.wait(time, function () + if coroutine.status(co) ~= 'suspended' then + return + end + return m.checkResult(co, coroutine.resume(co)) + end) + return coroutine.yield() +end + +--- 等待直到唤醒 +---@param callback function +function m.wait(callback, ...) + if not coroutine.isyieldable() then + return + end + local co = coroutine.running() + local waked + callback(function (...) + if waked then + return + end + waked = true + if coroutine.status(co) ~= 'suspended' then + return + end + return m.checkResult(co, coroutine.resume(co, ...)) + end, ...) + return coroutine.yield() +end + +--- 延迟 +function m.delay() + if not m._enable then + return + end + if not coroutine.isyieldable() then + return + end + local co = coroutine.running() + local current = m.coMap[co] + if m.onWatch('delay', co) == false then + return + end + -- TODO + if current.priority then + return + end + m.delayQueue[#m.delayQueue+1] = function () + if coroutine.status(co) ~= 'suspended' then + return + end + return m.checkResult(co, coroutine.resume(co)) + end + return coroutine.yield() +end + +local function warnStepTime(passed, waker) + if passed < 1 then + log.warn(('Await step takes [%.3f] sec.'):format(passed)) + return + end + for i = 1, 100 do + local name, v = debug.getupvalue(waker, i) + if not name then + return + end + if name == 'co' then + log.warn(debug.traceback(v, ('[fire]Await step takes [%.3f] sec.'):format(passed))) + return + end + end +end + +--- 步进 +function m.step() + local waker = m.delayQueue[m.delayQueueIndex] + if waker then + m.delayQueue[m.delayQueueIndex] = false + m.delayQueueIndex = m.delayQueueIndex + 1 + local clock = os.clock() + waker() + local passed = os.clock() - clock + if passed > 0.1 then + warnStepTime(passed, waker) + end + return true + else + m.delayQueue = {} + m.delayQueueIndex = 1 + return false + end +end + +function m.setPriority(n) + m.coMap[coroutine.running()].priority = true +end + +function m.enable() + m._enable = true +end + +function m.disable() + m._enable = false +end + +--- 注册事件 +function m.watch(callback) + m.watchList[#m.watchList+1] = callback +end + +function m.onWatch(ev, ...) + for _, callback in ipairs(m.watchList) do + local res = callback(ev, ...) + if res ~= nil then + return res + end + end +end + +return m |