1 module yu.asyncsocket.udpsocket;
2 
3 import core.stdc.errno;
4 
5 import std.socket;
6 import std.functional;
7 import std.exception;
8 
9 import yu.eventloop;
10 import yu.asyncsocket.transport;
11 import yu.exception : yuCathException, showException;
12 
/// Callback invoked by `UDPSocket` for every received datagram.
///
/// `buffer` is a slice of the socket's internal read buffer, which is reused for
/// the next datagram: copy the bytes if you need them after the callback returns.
/// `adr` is the sender's address. The socket does not keep or reuse it after the
/// call, so the callback owns it and should release it with `yDel` when done.
alias UDPReadCallBack = void delegate(in ubyte[] buffer, Address adr) nothrow;
15 
/// Asynchronous UDP socket driven by an `EventLoop`.
///
/// Incoming datagrams are delivered to the `UDPReadCallBack` installed with
/// `setReadCallBack`; outgoing data is sent directly with `sendTo`. On IOCP
/// builds reads are posted as overlapped `WSARecvFrom` calls; otherwise the
/// socket is registered for read readiness and read with `receiveFrom`.
@trusted class UDPSocket : AsyncTransport, EventCallInterface {
    /// Creates an IPv4 socket, or an IPv6 one when `isIpV6` is true.
    this(EventLoop loop, bool isIpV6 = false) {
        auto family = isIpV6 ? AddressFamily.INET6 : AddressFamily.INET;
        this(loop, family);
    }

    /// Creates a UDP socket of `family`, which must be INET or INET6.
    this(EventLoop loop, AddressFamily family)
    in {
        assert(family == AddressFamily.INET6 || family == AddressFamily.INET,
            "the AddressFamily must be AddressFamily.INET or AddressFamily.INET6");
    }
    body {
        super(loop, TransportType.UDP);

        _socket = yNew!UdpSocket(family);

        // The socket is serviced from the event loop, so it must never block it:
        // a readiness notification does not guarantee recvfrom() will find a
        // datagram (e.g. one discarded for a bad checksum), and a full send
        // buffer would otherwise stall every other event on the loop.
        _socket.blocking = false;
        _readBuffer = makeArray!ubyte(yuAlloctor, UDP_READ_BUFFER_SIZE);
        _event = AsyncEvent(AsynType.UDP, this, _socket.handle, true, false, false);
        static if (IO_MODE.iocp == IOMode) {
            // Describe exactly the buffer allocated above (UDP-sized, not TCP-sized).
            _iocpBuffer.len = cast(uint) _readBuffer.length;
            _iocpBuffer.buf = cast(char*) _readBuffer.ptr;
            _iocpread.event = &_event;
            _iocpread.operationType = IOCP_OP_TYPE.read;

            // Wildcard address used by start() when bind() was never called.
            if (family == AddressFamily.INET)
                _bindddr = new InternetAddress(InternetAddress.PORT_ANY);
            else if (family == AddressFamily.INET6)
                _bindddr = new Internet6Address(Internet6Address.PORT_ANY);
            else
                _bindddr = new UnknownAddress();
        }
    }

    ~this() {
        if (_event.isActive)
            eventLoop.delEvent(&_event);
        yDel(_socket);
        yDel(_readBuffer);
        _readBuffer = null;
        static if (IO_MODE.iocp == IOMode) {
            // _bindddr is either GC-allocated by the constructor or borrowed from
            // the caller of bind(); neither may be released with yDel here.
            _bindddr = null;
        }
    }

    /// Enables/disables SO_REUSEADDR (and SO_REUSEPORT on POSIX).
    @property reusePort(bool use) {
        _socket.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, use);
        version (Posix)
            _socket.setOption(SocketOptionLevel.SOCKET, cast(SocketOption) SO_REUSEPORT,
                use);
    }

    /// Installs the handler for received datagrams; required before start().
    pragma(inline) final void setReadCallBack(UDPReadCallBack cback) {
        _readCallBack = cback;
    }

    /// Binds the socket to `addr`. Throws (from std.socket) if the bind fails.
    void bind(Address addr) @trusted {
        _socket.bind(addr);
        static if (IO_MODE.iocp == IOMode) {
            // Record the binding only once it succeeded, so start() still binds
            // the default address if this call threw.
            _isBind = true;
            _bindddr = addr;
            trace("udp bind : ", addr.toAddrString());
        }
    }

    /// Remembers `to` as the default destination for `sendTo(buf)`.
    /// Returns false when the socket is no longer usable.
    bool connect(Address to) {
        if (!_socket.isAlive())
            return false;
        _connecto = to;
        return true;
    }

    /// Sends `buf` to `to`; returns the byte count or -1 on error.
    pragma(inline) @safe ptrdiff_t sendTo(const(void)[] buf, Address to) {
        return _socket.sendTo(buf, to);
    }

    /// Sends `buf` to the address given to connect(); -1 if none was set.
    pragma(inline) @safe ptrdiff_t sendTo(const(void)[] buf) {
        ptrdiff_t len = -1;
        if (_connecto)
            len = _socket.sendTo(buf, _connecto);
        return len;
    }

    final override @property int fd() {
        return cast(int) _socket.handle();
    }

    pragma(inline, true) final @property localAddress() {
        return _socket.localAddress();
    }

    /// Registers the socket with the event loop and begins receiving.
    /// Returns false if already started, closed, without a read callback, or if
    /// the event loop refuses the registration.
    override bool start() {
        if (_event.isActive || !_socket.isAlive() || !_readCallBack)
            return false;
        _event = AsyncEvent(AsynType.UDP, this, _socket.handle, true, false, false);
        static if (IOMode == IO_MODE.iocp) {
            // Overlapped receives need a bound socket.
            if (!_isBind)
                bind(_bindddr);
            if (!_loop.addEvent(&_event))
                return false;
            return doRead();
        } else {
            return _loop.addEvent(&_event);
        }
    }

    /// Unregisters from the loop (if started) and closes the socket.
    override void close() {
        if (isAlive) {
            onClose();
        } else if (_socket.isAlive()) {
            _socket.close();
        }
    }

    /// True while the socket is registered with the loop and has a valid handle.
    override @property bool isAlive() @trusted nothrow {
        bool alive;
        yuCathException((_event.isActive && _socket.handle() != socket_t.init), alive);
        return alive;
    }

    mixin TransportSocketOption;

protected:
    override void onRead() nothrow {
        try {
            static if (IO_MODE.iocp == IOMode) {
                // Registered first so the next overlapped read is posted however
                // this handler exits, including via an exception below.
                scope (exit) {
                    _event.readLen = 0;
                    if (_socket.isAlive)
                        doRead();
                }
                if (_event.readLen > 0 && _readCallBack !is null) {
                    setReadAddr();
                    _readCallBack(_readBuffer[0 .. _event.readLen], _readAddr);
                }
            } else {
                if (_readAddr is null)
                    _readAddr = createAddress();
                auto len = _socket.receiveFrom(_readBuffer, _readAddr);
                // Nothing read (would-block or error), or nobody to deliver to:
                // keep the cached address for the next datagram.
                if (len <= 0 || _readCallBack is null)
                    return;
                // The callback takes the address; allocate a fresh one next time.
                Address sender = _readAddr;
                _readAddr = null;
                _readCallBack(_readBuffer[0 .. len], sender);
            }
        }
        catch (Exception e) {
            showException(e);
        }
    }

    override void onWrite() nothrow {
    }

    override void onClose() nothrow {
        if (!isAlive)
            return;
        eventLoop.delEvent(&_event);
        _socket.close();
        static if (IO_MODE.iocp == IOMode)
            _isBind = false;
    }

    static if (IO_MODE.iocp == IOMode) {
    package:
        /// Converts the sender address filled in by WSARecvFrom into an Address.
        pragma(inline, true) void setReadAddr() {
            if (remoteAddr.sa.sa_family == AddressFamily.INET)
                _readAddr = yNew!InternetAddress(remoteAddr.sin);
            else
                _readAddr = yNew!Internet6Address(remoteAddr.sin6);
        }

        /// Posts one overlapped WSARecvFrom into _readBuffer.
        /// Returns false (after closing the socket) if it could not be posted.
        bool doRead() nothrow {
            _iocpBuffer.len = cast(uint) _readBuffer.length;
            _iocpBuffer.buf = cast(char*) _readBuffer.ptr;
            _iocpread.event = &_event;
            _iocpread.operationType = IOCP_OP_TYPE.read;
            // The OVERLAPPED block is reused for every read; reset it first.
            _iocpread.ol = typeof(_iocpread.ol).init;
            // Advertise the full storage so IPv6 senders fit without overflow.
            remoteAddrLen = cast(int) remoteAddr.sizeof;

            DWORD dwReceived = 0;
            DWORD dwFlags = 0;

            int nRet = WSARecvFrom(cast(SOCKET) _socket.handle, &_iocpBuffer,
                cast(uint) 1, &dwReceived, &dwFlags,
                cast(SOCKADDR*)&remoteAddr, &remoteAddrLen, &_iocpread.ol,
                cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);
            if (nRet == SOCKET_ERROR) {
                DWORD dwLastError = GetLastError();
                if (ERROR_IO_PENDING != dwLastError) {
                    yuCathException(error("WSARecvFrom failed with error: ", dwLastError));
                    onClose();
                    return false;
                }
            }
            return true;
        }

    private:
        /// Sender-address storage large enough for both IPv4 and IPv6.
        union RemoteSockAddr {
            sockaddr sa;
            sockaddr_in sin;
            sockaddr_in6 sin6;
        }

        IOCP_DATA _iocpread;
        WSABUF _iocpBuffer;

        RemoteSockAddr remoteAddr; // source address of the received datagram
        int remoteAddrLen; // length of the source address (in/out for WSARecvFrom)

        Address _bindddr; // address bound (or to bind) before posting reads
        bool _isBind = false;
    }

private:
    Address _connecto = null;
    Address _readAddr = null;
    UdpSocket _socket;
    AsyncEvent _event;
    ubyte[] _readBuffer;
    UDPReadCallBack _readCallBack;
}
237 
unittest {
    // Disabled end-to-end echo test, kept for reference. It needs rework before
    // it can be re-enabled:
    //  - serverHandle/clientHandle take `ubyte[]` and are not `nothrow`, so their
    //    addresses do not match `UDPReadCallBack`
    //    (`void delegate(in ubyte[] buffer, Address adr) nothrow`);
    //  - it binds the fixed address 127.0.0.1:9008 and runs the event loop until
    //    the server has seen more than 10 messages, so it depends on that port
    //    being free and on working loopback networking.
    /*    import std.conv;
    import std.stdio;

    EventLoop loop = new EventLoop();

    UDPSocket server = new UDPSocket(loop);
    UDPSocket client = new UDPSocket(loop);

    server.bind(new InternetAddress("127.0.0.1", 9008));
    Address adr = new InternetAddress("127.0.0.1", 9008);
    client.connect(adr);

    int i = 0;

    void serverHandle(ubyte[] data, Address adr2)
    {
        string tstr = cast(string) data;
        writeln("Server revec data : ", tstr);
        string str = "hello " ~ to!string(i);
        server.sendTo(data, adr2);
        assert(str == tstr);
        if (i > 10)
            loop.stop();
    }

    void clientHandle(ubyte[] data, Address adr23)
    {
        writeln("Client revec data : ", cast(string) data);
        ++i;
        string str = "hello " ~ to!string(i);
        client.sendTo(str);
    }

    client.setReadCallBack(&clientHandle);
    server.setReadCallBack(&serverHandle);

    client.start();
    server.start();

    string str = "hello " ~ to!string(i);
    client.sendTo(cast(ubyte[]) str);
    writeln("Edit source/app.d to start your project.");
    loop.run();
    server.close();
    client.close();
    */
}