From 793fa4abf906362ba43bf2ea1ffbe7083499805a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Thu, 18 Apr 2024 11:56:30 +0800 Subject: =?UTF-8?q?=E6=9B=B4=E6=96=B0=20bee.net?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- script/meta/bee/socket.lua | 6 +- script/service/net.lua | 237 ++++++++++++++++++++++----------------------- 2 files changed, 120 insertions(+), 123 deletions(-) (limited to 'script') diff --git a/script/meta/bee/socket.lua b/script/meta/bee/socket.lua index 1724cbb3..55c349a6 100644 --- a/script/meta/bee/socket.lua +++ b/script/meta/bee/socket.lua @@ -8,9 +8,13 @@ ---| 'udp6' ---@class bee.socket ----@overload fun(protocol: bee.socket.protocol): bee.socket.fd?, string? local socket = {} +---@param protocol bee.socket.protocol +---@return bee.socket.fd? +---@return string? +function socket.create(protocol) end + ---@param readfds? bee.socket.fd[] ---@param writefds? bee.socket.fd[] ---@param timeout number diff --git a/script/service/net.lua b/script/service/net.lua index 2019406e..bf77e7df 100644 --- a/script/service/net.lua +++ b/script/service/net.lua @@ -1,5 +1,7 @@ local socket = require "bee.socket" local select = require "bee.select" +local fs = require "bee.filesystem" + local selector = select.create() local SELECT_READ = select.SELECT_READ local SELECT_WRITE = select.SELECT_WRITE @@ -39,7 +41,7 @@ end local function on_event(self, name, ...) local f = self._event[name] if f then - f(self, ...) + return f(self, ...) end end @@ -90,69 +92,31 @@ local function close_write(self) close(self) end end -function stream:select_r() - local data = self._fd:recv() - if data == nil then - self:close() - elseif data == false then - else - on_event(self, "data", data) - end -end -function stream:select_w() - local n = self._fd:send(self._writebuf) - if n == nil then - self.shutdown_w = true - close_write(self) - elseif n == false then - else - self._writebuf = self._writebuf:sub(n + 1) - if self._writebuf == "" then - close_write(self) - end - end -end local function update_stream(s, event) if event & SELECT_READ ~= 0 then - s:select_r() + local data = s._fd:recv() + if data == nil then + s:close() + elseif data == false then + else + on_event(s, "data", data) + end end if event & SELECT_WRITE ~= 0 then - s:select_w() - end -end - -local function accept_stream(fd) - local s = setmetatable({ - _fd = fd, - _flags = SELECT_READ, - _event = {}, - _writebuf = "", - shutdown_r = false, - shutdown_w = false, - }, stream_mt) - selector:event_add(fd, SELECT_READ, function (event) - update_stream(s, event) - end) - return s -end -local function connect_stream(s) - setmetatable(s, stream_mt) - selector:event_del(s._fd) - if s._writebuf ~= "" then - s._flags = SELECT_READ | SELECT_WRITE - selector:event_add(s._fd, SELECT_READ | SELECT_WRITE, function (event) - update_stream(s, event) - end) - s:select_w() - else - s._flags = SELECT_READ - selector:event_add(s._fd, SELECT_READ, function (event) - update_stream(s, event) - end) + local n = s._fd:send(s._writebuf) + if n == nil then + s.shutdown_w = true + close_write(s) + elseif n == false then + else + s._writebuf = s._writebuf:sub(n + 1) + if s._writebuf == "" then + close_write(s) + end + end end end - local listen_mt = {} local listen = {} listen_mt.__index = listen @@ -168,32 +132,6 @@ function listen:close() self.shutdown_r = true close(self) end -local function new_listen(fd) - local s = { - _fd = fd, - _flags = SELECT_READ, - _event = {}, - shutdown_r = false, - shutdown_w = true, - } - selector:event_add(fd, SELECT_READ, function () - local newfd, err = fd:accept() - if not newfd then - on_event(s, "error", err) - return - end - local ok, err = newfd:status() - if not ok then - on_event(s, "error", err) - return - end - if newfd:status() then - local news = accept_stream(newfd) - on_event(s, "accept", news) - end - end) - return setmetatable(s, listen_mt) -end local connect_mt = {} local connect = {} @@ -216,7 +154,84 @@ function connect:close() self.shutdown_w = true close(self) end -local function new_connect(fd) + +local m = {} + +function m.listen(protocol, address, port) + local fd; do + local err + fd, err = socket.create(protocol) + if not fd then + return nil, err + end + if protocol == "unix" then + fs.remove(address) + end + end + do + local ok, err = fd:bind(address, port) + if not ok then + fd:close() + return nil, err + end + end + do + local ok, err = fd:listen() + if not ok then + fd:close() + return nil, err + end + end + local s = { + _fd = fd, + _flags = SELECT_READ, + _event = {}, + shutdown_r = false, + shutdown_w = true, + } + selector:event_add(fd, SELECT_READ, function () + local new_fd, err = fd:accept() + if new_fd == nil then + fd:close() + on_event(s, "error", err) + return + elseif new_fd == false then + else + local new_s = setmetatable({ + _fd = new_fd, + _flags = SELECT_READ, + _event = {}, + _writebuf = "", + shutdown_r = false, + shutdown_w = false, + }, stream_mt) + if on_event(s, "accepted", new_s) then + selector:event_add(new_fd, new_s._flags, function (event) + update_stream(new_s, event) + end) + else + new_fd:close() + end + end + end) + return setmetatable(s, listen_mt) +end + +function m.connect(protocol, address, port) + local fd; do + local err + fd, err = socket.create(protocol) + if not fd then + return nil, err + end + end + do + local ok, err = fd:connect(address, port) + if ok == nil then + fd:close() + return nil, err + end + end local s = { _fd = fd, _flags = SELECT_WRITE, @@ -228,51 +243,29 @@ local function new_connect(fd) selector:event_add(fd, SELECT_WRITE, function () local ok, err = fd:status() if ok then - connect_stream(s) - on_event(s, "connect") + on_event(s, "connected") + setmetatable(s, stream_mt) + if s._writebuf ~= "" then + update_stream(s, SELECT_WRITE) + if s._writebuf ~= "" then + s._flags = SELECT_READ | SELECT_WRITE + else + s._flags = SELECT_READ + end + else + s._flags = SELECT_READ + end + selector:event_add(s._fd, s._flags, function (event) + update_stream(s, event) + end) else - on_event(s, "error", err) s:close() + on_event(s, "error", err) end end) return setmetatable(s, connect_mt) end -local m = {} - -function m.listen(protocol, ...) - local fd, err = socket(protocol) - if not fd then - return nil, err - end - local ok - ok, err = fd:bind(...) - if not ok then - fd:close() - return nil, err - end - ok, err = fd:listen() - if not ok then - fd:close() - return nil, err - end - return new_listen(fd) -end - -function m.connect(protocol, ...) - local fd, err = socket(protocol) - if not fd then - return nil, err - end - local ok - ok, err = fd:connect(...) - if ok == nil then - fd:close() - return nil, err - end - return new_connect(fd) -end - function m.update(timeout) for func, event in selector:wait(timeout or 0) do func(event) -- cgit v1.2.3