Custom godot sockets for ENet now support DTLS.

Non-DTLS implementation uses plain NetSocket for performance as before.
This commit is contained in:
Fabio Alessandrelli 2020-01-13 13:10:12 +01:00
parent 35a9f0fe64
commit aea3c4d93e
2 changed files with 321 additions and 56 deletions

View file

@ -578,6 +578,8 @@ ENET_API void enet_host_channel_limit (ENetHost *, size_t);
ENET_API void enet_host_bandwidth_limit (ENetHost *, enet_uint32, enet_uint32); ENET_API void enet_host_bandwidth_limit (ENetHost *, enet_uint32, enet_uint32);
extern void enet_host_bandwidth_throttle (ENetHost *); extern void enet_host_bandwidth_throttle (ENetHost *);
extern enet_uint32 enet_host_random_seed (void); extern enet_uint32 enet_host_random_seed (void);
ENET_API void enet_host_dtls_server_setup (ENetHost *, void *, void *);
ENET_API void enet_host_dtls_client_setup (ENetHost *, void *, uint8_t, const char *);
ENET_API int enet_peer_send (ENetPeer *, enet_uint8, ENetPacket *); ENET_API int enet_peer_send (ENetPeer *, enet_uint8, ENetPacket *);
ENET_API ENetPacket * enet_peer_receive (ENetPeer *, enet_uint8 * channelID); ENET_API ENetPacket * enet_peer_receive (ENetPeer *, enet_uint8 * channelID);

View file

@ -32,13 +32,313 @@
@brief ENet Godot specific functions @brief ENet Godot specific functions
*/ */
#include "core/io/dtls_server.h"
#include "core/io/ip.h" #include "core/io/ip.h"
#include "core/io/net_socket.h" #include "core/io/net_socket.h"
#include "core/io/packet_peer_dtls.h"
#include "core/io/udp_server.h"
#include "core/os/os.h" #include "core/os/os.h"
// This must be last for windows to compile (tested with MinGW) // This must be last for windows to compile (tested with MinGW)
#include "enet/enet.h" #include "enet/enet.h"
/// Abstract ENet interface for UDP/DTLS.
class ENetGodotSocket {
public:
virtual Error bind(IP_Address p_ip, uint16_t p_port) = 0;
virtual Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) = 0;
virtual Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) = 0;
virtual int set_option(ENetSocketOption p_option, int p_value) = 0;
virtual void close() = 0;
virtual ~ENetGodotSocket(){};
};
class ENetDTLSClient;
class ENetDTLSServer;
/// NetSocket interface
class ENetUDP : public ENetGodotSocket {
friend class ENetDTLSClient;
friend class ENetDTLSServer;
private:
Ref<NetSocket> sock;
IP_Address address;
uint16_t port;
bool bound;
public:
ENetUDP() {
sock = Ref<NetSocket>(NetSocket::create());
IP::Type ip_type = IP::TYPE_ANY;
bound = false;
sock->open(NetSocket::TYPE_UDP, ip_type);
}
~ENetUDP() {
sock->close();
}
Error bind(IP_Address p_ip, uint16_t p_port) {
address = p_ip;
port = p_port;
bound = true;
return sock->bind(address, port);
}
Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) {
return sock->sendto(p_buffer, p_len, r_sent, p_ip, p_port);
}
Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) {
Error err = sock->poll(NetSocket::POLL_TYPE_IN, 0);
if (err != OK)
return err;
return sock->recvfrom(p_buffer, p_len, r_read, r_ip, r_port);
}
int set_option(ENetSocketOption p_option, int p_value) {
switch (p_option) {
case ENET_SOCKOPT_NONBLOCK: {
sock->set_blocking_enabled(p_value ? false : true);
return 0;
} break;
case ENET_SOCKOPT_BROADCAST: {
sock->set_broadcasting_enabled(p_value ? true : false);
return 0;
} break;
case ENET_SOCKOPT_REUSEADDR: {
sock->set_reuse_address_enabled(p_value ? true : false);
return 0;
} break;
case ENET_SOCKOPT_RCVBUF: {
return -1;
} break;
case ENET_SOCKOPT_SNDBUF: {
return -1;
} break;
case ENET_SOCKOPT_RCVTIMEO: {
return -1;
} break;
case ENET_SOCKOPT_SNDTIMEO: {
return -1;
} break;
case ENET_SOCKOPT_NODELAY: {
sock->set_tcp_no_delay_enabled(p_value ? true : false);
return 0;
} break;
}
return -1;
}
void close() {
sock->close();
}
};
/// DTLS Client ENet interface
class ENetDTLSClient : public ENetGodotSocket {
bool connected;
Ref<PacketPeerUDP> udp;
Ref<PacketPeerDTLS> dtls;
bool verify;
String for_hostname;
Ref<X509Certificate> cert;
public:
ENetDTLSClient(ENetUDP *p_base, Ref<X509Certificate> p_cert, bool p_verify, String p_for_hostname) {
verify = p_verify;
for_hostname = p_for_hostname;
cert = p_cert;
udp.instance();
dtls = Ref<PacketPeerDTLS>(PacketPeerDTLS::create());
p_base->close();
if (p_base->bound) {
bind(p_base->address, p_base->port);
}
connected = false;
}
~ENetDTLSClient() {
close();
}
Error bind(IP_Address p_ip, uint16_t p_port) {
return udp->listen(p_port, p_ip);
}
Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) {
if (!connected) {
udp->connect_to_host(p_ip, p_port);
dtls->connect_to_peer(udp, verify, for_hostname, cert);
connected = true;
}
dtls->poll();
if (dtls->get_status() == PacketPeerDTLS::STATUS_HANDSHAKING)
return ERR_BUSY;
else if (dtls->get_status() != PacketPeerDTLS::STATUS_CONNECTED)
return FAILED;
r_sent = p_len;
return dtls->put_packet(p_buffer, p_len);
}
Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) {
dtls->poll();
if (dtls->get_status() == PacketPeerDTLS::STATUS_HANDSHAKING)
return ERR_BUSY;
if (dtls->get_status() != PacketPeerDTLS::STATUS_CONNECTED)
return FAILED;
int pc = dtls->get_available_packet_count();
if (pc == 0)
return ERR_BUSY;
else if (pc < 0)
return FAILED;
const uint8_t *buffer;
Error err = dtls->get_packet(&buffer, r_read);
ERR_FAIL_COND_V(err != OK, err);
ERR_FAIL_COND_V(p_len < r_read, ERR_OUT_OF_MEMORY);
copymem(p_buffer, buffer, r_read);
r_ip = udp->get_packet_address();
r_port = udp->get_packet_port();
return err;
}
int set_option(ENetSocketOption p_option, int p_value) {
return -1;
}
void close() {
dtls->disconnect_from_peer();
udp->close();
}
};
/// DTLSServer - ENet interface
class ENetDTLSServer : public ENetGodotSocket {
Ref<DTLSServer> server;
Ref<UDPServer> udp_server;
Map<String, Ref<PacketPeerDTLS> > peers;
int last_service;
public:
ENetDTLSServer(ENetUDP *p_base, Ref<CryptoKey> p_key, Ref<X509Certificate> p_cert) {
last_service = 0;
udp_server.instance();
p_base->close();
if (p_base->bound) {
bind(p_base->address, p_base->port);
}
server = Ref<DTLSServer>(DTLSServer::create());
server->setup(p_key, p_cert);
}
~ENetDTLSServer() {
close();
}
Error bind(IP_Address p_ip, uint16_t p_port) {
return udp_server->listen(p_port, p_ip);
}
Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) {
String key = String(p_ip) + ":" + itos(p_port);
ERR_FAIL_COND_V(!peers.has(key), ERR_UNAVAILABLE);
Ref<PacketPeerDTLS> peer = peers[key];
Error err = peer->put_packet(p_buffer, p_len);
if (err == OK)
r_sent = p_len;
else if (err == ERR_BUSY)
r_sent = 0;
else
r_sent = -1;
return err;
}
Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) {
// TODO limits? Maybe we can better enforce allowed connections!
if (udp_server->is_connection_available()) {
Ref<PacketPeerUDP> udp = udp_server->take_connection();
IP_Address peer_ip = udp->get_packet_address();
int peer_port = udp->get_packet_port();
Ref<PacketPeerDTLS> peer = server->take_connection(udp);
PacketPeerDTLS::Status status = peer->get_status();
if (status == PacketPeerDTLS::STATUS_HANDSHAKING || status == PacketPeerDTLS::STATUS_CONNECTED) {
String key = String(peer_ip) + ":" + itos(peer_port);
peers[key] = peer;
}
}
List<String> remove;
Error err = ERR_BUSY;
// TODO this needs to be fair!
for (Map<String, Ref<PacketPeerDTLS> >::Element *E = peers.front(); E; E = E->next()) {
Ref<PacketPeerDTLS> peer = E->get();
peer->poll();
if (peer->get_status() == PacketPeerDTLS::STATUS_HANDSHAKING)
continue;
else if (peer->get_status() != PacketPeerDTLS::STATUS_CONNECTED) {
// Peer disconnected, removing it.
remove.push_back(E->key());
continue;
}
if (peer->get_available_packet_count() > 0) {
const uint8_t *buffer;
err = peer->get_packet(&buffer, r_read);
if (err != OK || p_len < r_read) {
// Something wrong with this peer, removing it.
remove.push_back(E->key());
err = FAILED;
continue;
}
Vector<String> s = E->key().rsplit(":", false, 1);
ERR_CONTINUE(s.size() != 2); // BUG!
copymem(p_buffer, buffer, r_read);
r_ip = s[0];
r_port = s[1].to_int();
break; // err = OK
}
}
// Remove disconnected peers from map.
for (List<String>::Element *E = remove.front(); E; E = E->next()) {
peers.erase(E->get());
}
return err; // OK, ERR_BUSY, or possibly an error.
}
int set_option(ENetSocketOption p_option, int p_value) {
return -1;
}
void close() {
for (Map<String, Ref<PacketPeerDTLS> >::Element *E = peers.front(); E; E = E->next()) {
E->get()->disconnect_from_peer();
}
peers.clear();
udp_server->stop();
server->stop();
}
};
static enet_uint32 timeBase = 0; static enet_uint32 timeBase = 0;
int enet_initialize(void) { int enet_initialize(void) {
@ -92,13 +392,23 @@ int enet_address_get_host(const ENetAddress *address, char *name, size_t nameLen
ENetSocket enet_socket_create(ENetSocketType type) { ENetSocket enet_socket_create(ENetSocketType type) {
NetSocket *socket = NetSocket::create(); ENetUDP *socket = memnew(ENetUDP);
IP::Type ip_type = IP::TYPE_ANY;
socket->open(NetSocket::TYPE_UDP, ip_type);
return socket; return socket;
} }
void enet_host_dtls_server_setup(ENetHost *host, void *p_key, void *p_cert) {
ENetUDP *sock = (ENetUDP *)host->socket;
host->socket = memnew(ENetDTLSServer(sock, Ref<CryptoKey>((CryptoKey *)p_key), Ref<X509Certificate>((X509Certificate *)p_cert)));
memdelete(sock);
}
void enet_host_dtls_client_setup(ENetHost *host, void *p_cert, uint8_t p_verify, const char *p_for_hostname) {
ENetUDP *sock = (ENetUDP *)host->socket;
host->socket = memnew(ENetDTLSClient(sock, Ref<X509Certificate>((X509Certificate *)p_cert), p_verify, String(p_for_hostname)));
memdelete(sock);
}
int enet_socket_bind(ENetSocket socket, const ENetAddress *address) { int enet_socket_bind(ENetSocket socket, const ENetAddress *address) {
IP_Address ip; IP_Address ip;
@ -108,7 +418,7 @@ int enet_socket_bind(ENetSocket socket, const ENetAddress *address) {
ip.set_ipv6(address->host); ip.set_ipv6(address->host);
} }
NetSocket *sock = (NetSocket *)socket; ENetGodotSocket *sock = (ENetGodotSocket *)socket;
if (sock->bind(ip, address->port) != OK) { if (sock->bind(ip, address->port) != OK) {
return -1; return -1;
} }
@ -116,7 +426,7 @@ int enet_socket_bind(ENetSocket socket, const ENetAddress *address) {
} }
void enet_socket_destroy(ENetSocket socket) { void enet_socket_destroy(ENetSocket socket) {
NetSocket *sock = (NetSocket *)socket; ENetGodotSocket *sock = (ENetGodotSocket *)socket;
sock->close(); sock->close();
memdelete(sock); memdelete(sock);
} }
@ -125,7 +435,7 @@ int enet_socket_send(ENetSocket socket, const ENetAddress *address, const ENetBu
ERR_FAIL_COND_V(address == NULL, -1); ERR_FAIL_COND_V(address == NULL, -1);
NetSocket *sock = (NetSocket *)socket; ENetGodotSocket *sock = (ENetGodotSocket *)socket;
IP_Address dest; IP_Address dest;
Error err; Error err;
size_t i = 0; size_t i = 0;
@ -167,15 +477,7 @@ int enet_socket_receive(ENetSocket socket, ENetAddress *address, ENetBuffer *buf
ERR_FAIL_COND_V(bufferCount != 1, -1); ERR_FAIL_COND_V(bufferCount != 1, -1);
NetSocket *sock = (NetSocket *)socket; ENetGodotSocket *sock = (ENetGodotSocket *)socket;
Error ret = sock->poll(NetSocket::POLL_TYPE_IN, 0);
if (ret == ERR_BUSY)
return 0;
if (ret != OK)
return -1;
int read; int read;
IP_Address ip; IP_Address ip;
@ -215,47 +517,8 @@ int enet_socket_listen(ENetSocket socket, int backlog) {
int enet_socket_set_option(ENetSocket socket, ENetSocketOption option, int value) { int enet_socket_set_option(ENetSocket socket, ENetSocketOption option, int value) {
NetSocket *sock = (NetSocket *)socket; ENetGodotSocket *sock = (ENetGodotSocket *)socket;
return sock->set_option(option, value);
switch (option) {
case ENET_SOCKOPT_NONBLOCK: {
sock->set_blocking_enabled(value ? false : true);
return 0;
} break;
case ENET_SOCKOPT_BROADCAST: {
sock->set_broadcasting_enabled(value ? true : false);
return 0;
} break;
case ENET_SOCKOPT_REUSEADDR: {
sock->set_reuse_address_enabled(value ? true : false);
return 0;
} break;
case ENET_SOCKOPT_RCVBUF: {
return -1;
} break;
case ENET_SOCKOPT_SNDBUF: {
return -1;
} break;
case ENET_SOCKOPT_RCVTIMEO: {
return -1;
} break;
case ENET_SOCKOPT_SNDTIMEO: {
return -1;
} break;
case ENET_SOCKOPT_NODELAY: {
sock->set_tcp_no_delay_enabled(value ? true : false);
return 0;
} break;
}
return -1;
} }
int enet_socket_get_option(ENetSocket socket, ENetSocketOption option, int *value) { int enet_socket_get_option(ENetSocket socket, ENetSocketOption option, int *value) {