summaryrefslogtreecommitdiff
path: root/script/service/net.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/service/net.lua')
-rw-r--r--script/service/net.lua237
1 files changed, 115 insertions, 122 deletions
diff --git a/script/service/net.lua b/script/service/net.lua
index 2019406e..bf77e7df 100644
--- a/script/service/net.lua
+++ b/script/service/net.lua
@@ -1,5 +1,7 @@
local socket = require "bee.socket"
local select = require "bee.select"
+local fs = require "bee.filesystem"
+
local selector = select.create()
local SELECT_READ <const> = select.SELECT_READ
local SELECT_WRITE <const> = select.SELECT_WRITE
@@ -39,7 +41,7 @@ end
local function on_event(self, name, ...)
local f = self._event[name]
if f then
- f(self, ...)
+ return f(self, ...)
end
end
@@ -90,69 +92,31 @@ local function close_write(self)
close(self)
end
end
-function stream:select_r()
- local data = self._fd:recv()
- if data == nil then
- self:close()
- elseif data == false then
- else
- on_event(self, "data", data)
- end
-end
-function stream:select_w()
- local n = self._fd:send(self._writebuf)
- if n == nil then
- self.shutdown_w = true
- close_write(self)
- elseif n == false then
- else
- self._writebuf = self._writebuf:sub(n + 1)
- if self._writebuf == "" then
- close_write(self)
- end
- end
-end
local function update_stream(s, event)
if event & SELECT_READ ~= 0 then
- s:select_r()
+ local data = s._fd:recv()
+ if data == nil then
+ s:close()
+ elseif data == false then
+ else
+ on_event(s, "data", data)
+ end
end
if event & SELECT_WRITE ~= 0 then
- s:select_w()
- end
-end
-
-local function accept_stream(fd)
- local s = setmetatable({
- _fd = fd,
- _flags = SELECT_READ,
- _event = {},
- _writebuf = "",
- shutdown_r = false,
- shutdown_w = false,
- }, stream_mt)
- selector:event_add(fd, SELECT_READ, function (event)
- update_stream(s, event)
- end)
- return s
-end
-local function connect_stream(s)
- setmetatable(s, stream_mt)
- selector:event_del(s._fd)
- if s._writebuf ~= "" then
- s._flags = SELECT_READ | SELECT_WRITE
- selector:event_add(s._fd, SELECT_READ | SELECT_WRITE, function (event)
- update_stream(s, event)
- end)
- s:select_w()
- else
- s._flags = SELECT_READ
- selector:event_add(s._fd, SELECT_READ, function (event)
- update_stream(s, event)
- end)
+ local n = s._fd:send(s._writebuf)
+ if n == nil then
+ s.shutdown_w = true
+ close_write(s)
+ elseif n == false then
+ else
+ s._writebuf = s._writebuf:sub(n + 1)
+ if s._writebuf == "" then
+ close_write(s)
+ end
+ end
end
end
-
local listen_mt = {}
local listen = {}
listen_mt.__index = listen
@@ -168,32 +132,6 @@ function listen:close()
self.shutdown_r = true
close(self)
end
-local function new_listen(fd)
- local s = {
- _fd = fd,
- _flags = SELECT_READ,
- _event = {},
- shutdown_r = false,
- shutdown_w = true,
- }
- selector:event_add(fd, SELECT_READ, function ()
- local newfd, err = fd:accept()
- if not newfd then
- on_event(s, "error", err)
- return
- end
- local ok, err = newfd:status()
- if not ok then
- on_event(s, "error", err)
- return
- end
- if newfd:status() then
- local news = accept_stream(newfd)
- on_event(s, "accept", news)
- end
- end)
- return setmetatable(s, listen_mt)
-end
local connect_mt = {}
local connect = {}
@@ -216,7 +154,84 @@ function connect:close()
self.shutdown_w = true
close(self)
end
-local function new_connect(fd)
+
+local m = {}
+
+function m.listen(protocol, address, port)
+ local fd; do
+ local err
+ fd, err = socket.create(protocol)
+ if not fd then
+ return nil, err
+ end
+ if protocol == "unix" then
+ fs.remove(address)
+ end
+ end
+ do
+ local ok, err = fd:bind(address, port)
+ if not ok then
+ fd:close()
+ return nil, err
+ end
+ end
+ do
+ local ok, err = fd:listen()
+ if not ok then
+ fd:close()
+ return nil, err
+ end
+ end
+ local s = {
+ _fd = fd,
+ _flags = SELECT_READ,
+ _event = {},
+ shutdown_r = false,
+ shutdown_w = true,
+ }
+ selector:event_add(fd, SELECT_READ, function ()
+ local new_fd, err = fd:accept()
+ if new_fd == nil then
+ fd:close()
+ on_event(s, "error", err)
+ return
+ elseif new_fd == false then
+ else
+ local new_s = setmetatable({
+ _fd = new_fd,
+ _flags = SELECT_READ,
+ _event = {},
+ _writebuf = "",
+ shutdown_r = false,
+ shutdown_w = false,
+ }, stream_mt)
+ if on_event(s, "accepted", new_s) then
+ selector:event_add(new_fd, new_s._flags, function (event)
+ update_stream(new_s, event)
+ end)
+ else
+ new_fd:close()
+ end
+ end
+ end)
+ return setmetatable(s, listen_mt)
+end
+
+function m.connect(protocol, address, port)
+ local fd; do
+ local err
+ fd, err = socket.create(protocol)
+ if not fd then
+ return nil, err
+ end
+ end
+ do
+ local ok, err = fd:connect(address, port)
+ if ok == nil then
+ fd:close()
+ return nil, err
+ end
+ end
local s = {
_fd = fd,
_flags = SELECT_WRITE,
@@ -228,51 +243,29 @@ local function new_connect(fd)
selector:event_add(fd, SELECT_WRITE, function ()
local ok, err = fd:status()
if ok then
- connect_stream(s)
- on_event(s, "connect")
+ on_event(s, "connected")
+ setmetatable(s, stream_mt)
+ if s._writebuf ~= "" then
+ update_stream(s, SELECT_WRITE)
+ if s._writebuf ~= "" then
+ s._flags = SELECT_READ | SELECT_WRITE
+ else
+ s._flags = SELECT_READ
+ end
+ else
+ s._flags = SELECT_READ
+ end
+ selector:event_add(s._fd, s._flags, function (event)
+ update_stream(s, event)
+ end)
else
- on_event(s, "error", err)
s:close()
+ on_event(s, "error", err)
end
end)
return setmetatable(s, connect_mt)
end
-local m = {}
-
-function m.listen(protocol, ...)
- local fd, err = socket(protocol)
- if not fd then
- return nil, err
- end
- local ok
- ok, err = fd:bind(...)
- if not ok then
- fd:close()
- return nil, err
- end
- ok, err = fd:listen()
- if not ok then
- fd:close()
- return nil, err
- end
- return new_listen(fd)
-end
-
-function m.connect(protocol, ...)
- local fd, err = socket(protocol)
- if not fd then
- return nil, err
- end
- local ok
- ok, err = fd:connect(...)
- if ok == nil then
- fd:close()
- return nil, err
- end
- return new_connect(fd)
-end
-
function m.update(timeout)
for func, event in selector:wait(timeout or 0) do
func(event)