diff options
Diffstat (limited to 'script/service/net.lua')
-rw-r--r-- | script/service/net.lua | 207 |
1 files changed, 89 insertions, 118 deletions
diff --git a/script/service/net.lua b/script/service/net.lua index 86edc9a6..2019406e 100644 --- a/script/service/net.lua +++ b/script/service/net.lua @@ -1,42 +1,39 @@ local socket = require "bee.socket" +local select = require "bee.select" +local selector = select.create() +local SELECT_READ <const> = select.SELECT_READ +local SELECT_WRITE <const> = select.SELECT_WRITE -local readfds = {} -local writefds = {} -local map = {} - -local function FD_SET(set, fd) - for i = 1, #set do - if fd == set[i] then - return - end +local function fd_set_read(s) + if s._flags & SELECT_READ ~= 0 then + return end - set[#set+1] = fd + s._flags = s._flags | SELECT_READ + selector:event_mod(s._fd, s._flags) end -local function FD_CLR(set, fd) - for i = 1, #set do - if fd == set[i] then - set[i] = set[#set] - set[#set] = nil - return - end +local function fd_clr_read(s) + if s._flags & SELECT_READ == 0 then + return end + s._flags = s._flags & (~SELECT_READ) + selector:event_mod(s._fd, s._flags) end -local function fd_set_read(fd) - FD_SET(readfds, fd) -end - -local function fd_clr_read(fd) - FD_CLR(readfds, fd) -end - -local function fd_set_write(fd) - FD_SET(writefds, fd) +local function fd_set_write(s) + if s._flags & SELECT_WRITE ~= 0 then + return + end + s._flags = s._flags | SELECT_WRITE + selector:event_mod(s._fd, s._flags) end -local function fd_clr_write(fd) - FD_CLR(writefds, fd) +local function fd_clr_write(s) + if s._flags & SELECT_WRITE == 0 then + return + end + s._flags = s._flags & (~SELECT_WRITE) + selector:event_mod(s._fd, s._flags) end local function on_event(self, name, ...) @@ -49,8 +46,8 @@ end local function close(self) local fd = self._fd on_event(self, "close") + selector:event_del(fd) fd:close() - map[fd] = nil end local stream_mt = {} @@ -69,7 +66,7 @@ function stream:write(data) return end if self._writebuf == "" then - fd_set_write(self._fd) + fd_set_write(self) end self._writebuf = self._writebuf .. data end @@ -79,35 +76,17 @@ end function stream:close() if not self.shutdown_r then self.shutdown_r = true - fd_clr_read(self._fd) + fd_clr_read(self) end if self.shutdown_w or self._writebuf == "" then self.shutdown_w = true - fd_clr_write(self._fd) + fd_clr_write(self) close(self) end end -function stream:update(timeout) - local fd = self._fd - local r = {fd} - local w = r - if self._writebuf == "" then - w = nil - end - local rd, wr = socket.select(r, w, timeout or 0) - if rd then - if #rd > 0 then - self:select_r() - end - if #wr > 0 then - self:select_w() - end - end -end local function close_write(self) - fd_clr_write(self._fd) + fd_clr_write(self) if self.shutdown_r then - fd_clr_read(self._fd) close(self) end end @@ -133,26 +112,43 @@ function stream:select_w() end end end +local function update_stream(s, event) + if event & SELECT_READ ~= 0 then + s:select_r() + end + if event & SELECT_WRITE ~= 0 then + s:select_w() + end +end local function accept_stream(fd) - local self = setmetatable({ + local s = setmetatable({ _fd = fd, + _flags = SELECT_READ, _event = {}, _writebuf = "", shutdown_r = false, shutdown_w = false, }, stream_mt) - map[fd] = self - fd_set_read(fd) - return self -end -local function connect_stream(self) - setmetatable(self, stream_mt) - fd_set_read(self._fd) - if self._writebuf ~= "" then - self:select_w() + selector:event_add(fd, SELECT_READ, function (event) + update_stream(s, event) + end) + return s +end +local function connect_stream(s) + setmetatable(s, stream_mt) + selector:event_del(s._fd) + if s._writebuf ~= "" then + s._flags = SELECT_READ | SELECT_WRITE + selector:event_add(s._fd, SELECT_READ | SELECT_WRITE, function (event) + update_stream(s, event) + end) + s:select_w() else - fd_clr_write(self._fd) + s._flags = SELECT_READ + selector:event_add(s._fd, SELECT_READ, function (event) + update_stream(s, event) + end) end end @@ -170,35 +166,32 @@ function listen:is_closed() end function listen:close() self.shutdown_r = true - fd_clr_read(self._fd) close(self) end -function listen:update(timeout) - local fd = self._fd - local r = {fd} - local rd = socket.select(r, nil, timeout or 0) - if rd then - if #rd > 0 then - self:select_r() - end - end -end -function listen:select_r() - local newfd = self._fd:accept() - if newfd:status() then - local news = accept_stream(newfd) - on_event(self, "accept", news) - end -end local function new_listen(fd) local s = { _fd = fd, + _flags = SELECT_READ, _event = {}, shutdown_r = false, shutdown_w = true, } - map[fd] = s - fd_set_read(fd) + selector:event_add(fd, SELECT_READ, function () + local newfd, err = fd:accept() + if not newfd then + on_event(s, "error", err) + return + end + local ok, err = newfd:status() + if not ok then + on_event(s, "error", err) + return + end + if newfd:status() then + local news = accept_stream(newfd) + on_event(s, "accept", news) + end + end) return setmetatable(s, listen_mt) end @@ -221,39 +214,27 @@ function connect:is_closed() end function connect:close() self.shutdown_w = true - fd_clr_write(self._fd) close(self) end -function connect:update(timeout) - local fd = self._fd - local w = {fd} - local rd, wr = socket.select(nil, w, timeout or 0) - if rd then - if #wr > 0 then - self:select_w() - end - end -end -function connect:select_w() - local ok, err = self._fd:status() - if ok then - connect_stream(self) - on_event(self, "connect") - else - on_event(self, "error", err) - self:close() - end -end local function new_connect(fd) local s = { _fd = fd, + _flags = SELECT_WRITE, _event = {}, _writebuf = "", shutdown_r = false, shutdown_w = false, } - map[fd] = s - fd_set_write(fd) + selector:event_add(fd, SELECT_WRITE, function () + local ok, err = fd:status() + if ok then + connect_stream(s) + on_event(s, "connect") + else + on_event(s, "error", err) + s:close() + end + end) return setmetatable(s, connect_mt) end @@ -293,18 +274,8 @@ function m.connect(protocol, ...) end function m.update(timeout) - local rd, wr = socket.select(readfds, writefds, timeout or 0) - if rd then - for i = 1, #rd do - local fd = rd[i] - local s = map[fd] - s:select_r() - end - for i = 1, #wr do - local fd = wr[i] - local s = map[fd] - s:select_w() - end + for func, event in selector:wait(timeout or 0) do + func(event) end end |