From e90a3d60c0a3e6a72413adb81041e582212822d4 Mon Sep 17 00:00:00 2001 From: Thulinma Date: Wed, 14 Aug 2019 11:00:48 +0200 Subject: [PATCH] Socket lib restyle and fixes backported from 3.0 --- lib/socket.cpp | 132 +++++++++++++++++++++++++------------------------ lib/socket.h | 120 ++++++++++++++++++++++---------------------- 2 files changed, 128 insertions(+), 124 deletions(-) diff --git a/lib/socket.cpp b/lib/socket.cpp index 87979b6f..72d8245d 100644 --- a/lib/socket.cpp +++ b/lib/socket.cpp @@ -2,8 +2,8 @@ /// A handy Socket wrapper library. /// Written by Jaron Vietor in 2010 for DDVTech -#include "socket.h" #include "defines.h" +#include "socket.h" #include "timing.h" #include #include @@ -37,8 +37,7 @@ static std::string getIPv6BinAddr(const struct sockaddr_in6 &remoteaddr){ char tmpBuffer[17] = "\000\000\000\000\000\000\000\000\000\000\377\377\000\000\000\000"; switch (remoteaddr.sin6_family){ case AF_INET: - memcpy(tmpBuffer + 12, &(reinterpret_cast(&remoteaddr)->sin_addr.s_addr), - 4); + memcpy(tmpBuffer + 12, &(reinterpret_cast(&remoteaddr)->sin_addr.s_addr), 4); break; case AF_INET6: memcpy(tmpBuffer, &(remoteaddr.sin6_addr.s6_addr), 16); break; default: return ""; break; @@ -58,7 +57,7 @@ bool Socket::isLocalhost(const std::string &remotehost){ return false; } -///Checks if the given file descriptor is actually socket or not. +/// Checks if the given file descriptor is actually socket or not. bool Socket::checkTrueSocket(int sock){ struct stat sBuf; if (sock != -1 && !fstat(sock, &sBuf)){return S_ISSOCK(sBuf.st_mode);} @@ -188,15 +187,25 @@ void Socket::hostBytesToStr(const char *bytes, size_t len, std::string &target){ target = tmpstr; break; case 16: + if (memcmp(bytes, "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000", 15) == 0){ + if (bytes[15] == 0){ + target = "::"; + return; + } + char tmpstr[6]; + snprintf(tmpstr, 6, "::%hhu", bytes[15]); + target = tmpstr; + return; + } if (memcmp(bytes, "\000\000\000\000\000\000\000\000\000\000\377\377", 12) == 0){ char tmpstr[16]; snprintf(tmpstr, 16, "%hhu.%hhu.%hhu.%hhu", bytes[12], bytes[13], bytes[14], bytes[15]); target = tmpstr; }else{ char tmpstr[40]; - snprintf(tmpstr, 40, "%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x", bytes[0], bytes[1], - bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], - bytes[14], bytes[15]); + snprintf(tmpstr, 40, "%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x:%.2x%.2x", + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]); target = tmpstr; } break; @@ -237,8 +246,7 @@ unsigned int Socket::Buffer::bytesToSplit(){ unsigned int i = 0; for (std::deque::reverse_iterator it = data.rbegin(); it != data.rend(); ++it){ i += (*it).size(); - if ((*it).size() >= splitter.size() && - (*it).substr((*it).size() - splitter.size()) == splitter){ + if ((*it).size() >= splitter.size() && (*it).substr((*it).size() - splitter.size()) == splitter){ return i; } } @@ -293,8 +301,8 @@ void Socket::Buffer::append(const char *newdata, const unsigned int newdatasize) } } if (data.size() > 5000){ - WARN_MSG("Warning: After %d new bytes, buffer has %d parts containing over %u bytes!", newdatasize, (int)data.size(), - bytes(9000)); + WARN_MSG("Warning: After %d new bytes, buffer has %d parts containing over %u bytes!", + newdatasize, (int)data.size(), bytes(9000)); } } @@ -386,13 +394,11 @@ void Socket::Connection::setBoundAddr(){ } struct sockaddr_in6 tmpaddr; socklen_t len = sizeof(tmpaddr); - if (!getsockname(sSend, (sockaddr*)&tmpaddr, &len)){ + if (!getsockname(sSend, (sockaddr *)&tmpaddr, &len)){ static char addrconv[INET6_ADDRSTRLEN]; if (tmpaddr.sin6_family == AF_INET6){ boundaddr = inet_ntop(AF_INET6, &(tmpaddr.sin6_addr), addrconv, INET6_ADDRSTRLEN); - if (boundaddr.substr(0, 7) == "::ffff:"){ - boundaddr = boundaddr.substr(7); - } + if (boundaddr.substr(0, 7) == "::ffff:"){boundaddr = boundaddr.substr(7);} HIGH_MSG("Local IPv6 addr [%s]", boundaddr.c_str()); } if (tmpaddr.sin6_family == AF_INET){ @@ -402,13 +408,12 @@ void Socket::Connection::setBoundAddr(){ } } -//Cleans up the socket by dropping the connection. -//Does not call close because it calls shutdown, which would destroy any copies of this socket too. +// Cleans up the socket by dropping the connection. +// Does not call close because it calls shutdown, which would destroy any copies of this socket too. Socket::Connection::~Connection(){ drop(); } - /// Create a new base socket. This is a basic constructor for converting any valid socket to a /// Socket::Connection. \param sockNo Integer representing the socket to convert. Socket::Connection::Connection(int sockNo){ @@ -452,7 +457,7 @@ void Socket::Connection::clear(){ isTrueSocket = false; up = 0; down = 0; - conntime = Util::epoch(); + conntime = Util::bootSecs(); Error = false; Blocking = false; skipCount = 0; @@ -548,9 +553,7 @@ void Socket::Connection::drop(){ #ifdef SSL if (sslConnected){ DONTEVEN_MSG("SSL close"); - if (ssl){ - mbedtls_ssl_close_notify(ssl); - } + if (ssl){mbedtls_ssl_close_notify(ssl);} if (server_fd){ mbedtls_net_free(server_fd); delete server_fd; @@ -653,7 +656,7 @@ void Socket::Connection::open(std::string address, bool nonblock){ } #ifdef SSL -///Local-only function for debugging SSL sockets +/// Local-only function for debugging SSL sockets static void my_debug(void *ctx, int level, const char *file, int line, const char *str){ ((void)level); fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str); @@ -689,15 +692,15 @@ void Socket::Connection::open(std::string host, int port, bool nonblock, bool wi mbedtls_ctr_drbg_init(ctr_drbg); mbedtls_entropy_init(entropy); DONTEVEN_MSG("SSL init"); - if (mbedtls_ctr_drbg_seed(ctr_drbg, mbedtls_entropy_func, entropy, (const unsigned char *)"meow", - 4) != 0){ + if (mbedtls_ctr_drbg_seed(ctr_drbg, mbedtls_entropy_func, entropy, (const unsigned char *)"meow", 4) != 0){ FAIL_MSG("SSL socket init failed"); close(); return; } DONTEVEN_MSG("SSL connect"); int ret = 0; - if ((ret = mbedtls_net_connect(server_fd, host.c_str(), JSON::Value(port).asString().c_str(), MBEDTLS_NET_PROTO_TCP)) != 0){ + if ((ret = mbedtls_net_connect(server_fd, host.c_str(), JSON::Value(port).asString().c_str(), + MBEDTLS_NET_PROTO_TCP)) != 0){ FAIL_MSG(" failed\n ! mbedtls_net_connect returned %d\n\n", ret); close(); return; @@ -762,7 +765,10 @@ void Socket::Connection::open(std::string host, int port, bool nonblock, bool wi for (rp = result; rp != NULL; rp = rp->ai_next){ sSend = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); if (sSend < 0){continue;} - if (connect(sSend, rp->ai_addr, rp->ai_addrlen) == 0){break;} + if (connect(sSend, rp->ai_addr, rp->ai_addrlen) == 0){ + remoteaddr = *((sockaddr_in6 *)rp->ai_addr); + break; + } remotehost += strerror(errno); ::close(sSend); } @@ -811,13 +817,6 @@ uint64_t Socket::Connection::dataDown(){ return down; } -/// Returns a std::string of stats, ended by a newline. -/// Requires the current connector name as an argument. -std::string Socket::Connection::getStats(std::string C){ - return "S " + getHost() + " " + C + " " + uint2string(Util::epoch() - conntime) + " " + - uint2string(up) + " " + uint2string(down) + "\n"; -} - /// Updates the downbuffer internal variable. /// Returns true if new data was received, false otherwise. bool Socket::Connection::spool(){ @@ -962,7 +961,10 @@ int Socket::Connection::iread(void *buffer, int len, int flags){ r = mbedtls_ssl_read(ssl, (unsigned char *)buffer, len); if (r < 0){ switch (errno){ - case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: close(); return 0; break; + case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: + close(); + return 0; + break; case MBEDTLS_ERR_SSL_WANT_WRITE: return 0; break; case MBEDTLS_ERR_SSL_WANT_READ: return 0; break; case EWOULDBLOCK: return 0; break; @@ -1092,17 +1094,18 @@ Socket::Connection::operator bool() const{ return connected(); } -//Copy constructor -Socket::Connection::Connection(const Connection& rhs){ +// Copy constructor +Socket::Connection::Connection(const Connection &rhs){ clear(); if (!rhs){return;} #if DEBUG >= DLVL_DEVEL - HIGH_MSG("Copying %s socket", rhs.sslConnected?"SSL":"regular"); + HIGH_MSG("Copying %s socket", rhs.sslConnected ? "SSL" : "regular"); #endif conntime = rhs.conntime; isTrueSocket = rhs.isTrueSocket; remotehost = rhs.remotehost; boundaddr = rhs.boundaddr; + remoteaddr = rhs.remoteaddr; up = rhs.up; down = rhs.down; downbuffer = rhs.downbuffer; @@ -1119,18 +1122,19 @@ Socket::Connection::Connection(const Connection& rhs){ #endif } -//Assignment constructor -Socket::Connection& Socket::Connection::operator=(const Socket::Connection& rhs){ +// Assignment constructor +Socket::Connection &Socket::Connection::operator=(const Socket::Connection &rhs){ drop(); clear(); if (!rhs){return *this;} #if DEBUG >= DLVL_DEVEL - HIGH_MSG("Assigning %s socket", rhs.sslConnected?"SSL":"regular"); + HIGH_MSG("Assigning %s socket", rhs.sslConnected ? "SSL" : "regular"); #endif conntime = rhs.conntime; isTrueSocket = rhs.isTrueSocket; remotehost = rhs.remotehost; boundaddr = rhs.boundaddr; + remoteaddr = rhs.remoteaddr; up = rhs.up; down = rhs.down; downbuffer = rhs.downbuffer; @@ -1160,7 +1164,6 @@ bool Socket::Connection::isLocal(){ return Socket::isLocal(remotehost); } - /// Create a new base Server. The socket is never connected, and a placeholder for later /// connections. Socket::Server::Server(){ @@ -1378,8 +1381,7 @@ Socket::Connection Socket::Server::accept(bool nonblock){ HIGH_MSG("IPv6 addr [%s]", tmp.remotehost.c_str()); } if (tmpaddr.sin6_family == AF_INET){ - tmp.remotehost = - inet_ntop(AF_INET, &(((sockaddr_in *)&tmpaddr)->sin_addr), addrconv, INET6_ADDRSTRLEN); + tmp.remotehost = inet_ntop(AF_INET, &(((sockaddr_in *)&tmpaddr)->sin_addr), addrconv, INET6_ADDRSTRLEN); HIGH_MSG("IPv4 addr [%s]", tmp.remotehost.c_str()); } if (tmpaddr.sin6_family == AF_UNIX){ @@ -1443,6 +1445,9 @@ int Socket::Server::getSocket(){ /// If both fail, prints an DLVL_FAIL debug message. /// \param nonblock Whether the socket should be nonblocking. Socket::UDPConnection::UDPConnection(bool nonblock){ + boundPort = 0; + boundAddr = ""; + boundMulti = ""; family = AF_INET6; sock = socket(AF_INET6, SOCK_DGRAM, 0); if (sock == -1){ @@ -1546,11 +1551,14 @@ void Socket::UDPConnection::SetDestination(std::string destIp, uint32_t port){ if (!destAddr){return;} memcpy(destAddr, rp->ai_addr, rp->ai_addrlen); if (family != rp->ai_family){ - INFO_MSG("Socket is wrong type (%s), re-opening as %s", addrFam(family), - addrFam(rp->ai_family)); + INFO_MSG("Socket is wrong type (%s), re-opening as %s", addrFam(family), addrFam(rp->ai_family)); close(); family = rp->ai_family; sock = socket(family, SOCK_DGRAM, 0); + if (boundPort){ + INFO_MSG("Rebinding to %s:%d %s", boundAddr.c_str(), boundPort, boundMulti.c_str()); + bind(boundPort, boundAddr, boundMulti); + } } HIGH_MSG("Set UDP destination: %s:%d (%s)", destIp.c_str(), port, addrFam(family)); freeaddrinfo(result); @@ -1574,16 +1582,14 @@ void Socket::UDPConnection::GetDestination(std::string &destIp, uint32_t &port){ char addr_str[INET6_ADDRSTRLEN + 1]; addr_str[INET6_ADDRSTRLEN] = 0; // set last byte to zero, to prevent walking out of the array if (((struct sockaddr_in *)destAddr)->sin_family == AF_INET6){ - if (inet_ntop(AF_INET6, &(((struct sockaddr_in6 *)destAddr)->sin6_addr), addr_str, - INET6_ADDRSTRLEN) != 0){ + if (inet_ntop(AF_INET6, &(((struct sockaddr_in6 *)destAddr)->sin6_addr), addr_str, INET6_ADDRSTRLEN) != 0){ destIp = addr_str; port = ntohs(((struct sockaddr_in6 *)destAddr)->sin6_port); return; } } if (((struct sockaddr_in *)destAddr)->sin_family == AF_INET){ - if (inet_ntop(AF_INET, &(((struct sockaddr_in *)destAddr)->sin_addr), addr_str, - INET6_ADDRSTRLEN) != 0){ + if (inet_ntop(AF_INET, &(((struct sockaddr_in *)destAddr)->sin_addr), addr_str, INET6_ADDRSTRLEN) != 0){ destIp = addr_str; port = ntohs(((struct sockaddr_in *)destAddr)->sin_port); return; @@ -1647,8 +1653,7 @@ void Socket::UDPConnection::SendNow(const char *sdata, size_t len){ /// \arg multicastInterfaces Comma-separated list of interfaces to listen on for multicast packets. /// Optional, left out means automatically chosen by kernel. \return Actually bound port number, or /// zero on error. -uint16_t Socket::UDPConnection::bind(int port, std::string iface, - const std::string &multicastInterfaces){ +uint16_t Socket::UDPConnection::bind(int port, std::string iface, const std::string &multicastInterfaces){ close(); // we open a new socket for each attempt int addr_ret; bool multicast = false; @@ -1722,6 +1727,9 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, portNo = ntohs(((struct sockaddr_in *)&fin_addr)->sin_port); } } + boundAddr = iface; + boundMulti = multicastInterfaces; + boundPort = portNo; INFO_MSG("UDP bind success on %s:%u (%s)", human_addr, portNo, addrFam(rp->ai_family)); break; } @@ -1769,8 +1777,7 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, }else{ size_t nxtPos = std::string::npos; bool atLeastOne = false; - for (size_t loc = 0; loc != std::string::npos; - loc = (nxtPos == std::string::npos ? nxtPos : nxtPos + 1)){ + for (size_t loc = 0; loc != std::string::npos; loc = (nxtPos == std::string::npos ? nxtPos : nxtPos + 1)){ nxtPos = multicastInterfaces.find(',', loc); std::string curIface = multicastInterfaces.substr(loc, (nxtPos == std::string::npos ? nxtPos : nxtPos - loc)); @@ -1785,25 +1792,22 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, if (family == AF_INET6){ INFO_MSG("Registering for IPv6 multicast on interface %s", curIface.c_str()); if ((addr_ret = getaddrinfo(curIface.c_str(), 0, &hints, &reslocal)) != 0){ - WARN_MSG("Unable to resolve IPv6 interface address %s: %s", curIface.c_str(), - gai_strerror(addr_ret)); + WARN_MSG("Unable to resolve IPv6 interface address %s: %s", curIface.c_str(), gai_strerror(addr_ret)); continue; } memcpy(&mreq6.ipv6mr_multiaddr, &((sockaddr_in6 *)resmulti->ai_addr)->sin6_addr, sizeof(mreq6.ipv6mr_multiaddr)); mreq6.ipv6mr_interface = ((sockaddr_in6 *)reslocal->ai_addr)->sin6_scope_id; if (setsockopt(sock, IPPROTO_IPV6, IPV6_JOIN_GROUP, (char *)&mreq6, sizeof(mreq6)) != 0){ - FAIL_MSG("Unable to register for IPv6 multicast on interface %s (%u): %s", - curIface.c_str(), ((sockaddr_in6 *)reslocal->ai_addr)->sin6_scope_id, - strerror(errno)); + FAIL_MSG("Unable to register for IPv6 multicast on interface %s (%u): %s", curIface.c_str(), + ((sockaddr_in6 *)reslocal->ai_addr)->sin6_scope_id, strerror(errno)); }else{ atLeastOne = true; } }else{ INFO_MSG("Registering for IPv4 multicast on interface %s", curIface.c_str()); if ((addr_ret = getaddrinfo(curIface.c_str(), 0, &hints, &reslocal)) != 0){ - WARN_MSG("Unable to resolve IPv4 interface address %s: %s", curIface.c_str(), - gai_strerror(addr_ret)); + WARN_MSG("Unable to resolve IPv4 interface address %s: %s", curIface.c_str(), gai_strerror(addr_ret)); continue; } mreq4.imr_multiaddr = ((sockaddr_in *)resmulti->ai_addr)->sin_addr; @@ -1815,9 +1819,7 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, atLeastOne = true; } } - if (!atLeastOne){ - close(); - } + if (!atLeastOne){close();} freeaddrinfo(reslocal); // free resolved interface addr }// loop over all interfaces } @@ -1831,6 +1833,7 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, /// If a packet is received, it will be placed in the "data" member, with it's length in "data_len". /// \return True if a packet was received, false otherwise. bool Socket::UDPConnection::Receive(){ + if (sock == -1){return false;} #ifdef __CYGWIN__ if (data_size != SOCKETSIZE){ data = (char *)realloc(data, SOCKETSIZE); @@ -1867,4 +1870,3 @@ bool Socket::UDPConnection::Receive(){ int Socket::UDPConnection::getSock(){ return sock; } - diff --git a/lib/socket.h b/lib/socket.h index d6ac4249..880f6113 100644 --- a/lib/socket.h +++ b/lib/socket.h @@ -17,12 +17,12 @@ #include #ifdef SSL -#include "mbedtls/net.h" -#include "mbedtls/ssl.h" -#include "mbedtls/entropy.h" #include "mbedtls/ctr_drbg.h" #include "mbedtls/debug.h" +#include "mbedtls/entropy.h" #include "mbedtls/error.h" +#include "mbedtls/net.h" +#include "mbedtls/ssl.h" #endif // for being friendly with Socket::Connection down below @@ -38,9 +38,9 @@ namespace Socket{ bool matchIPv6Addr(const std::string &A, const std::string &B, uint8_t prefix); std::string getBinForms(std::string addr); /// Returns true if given human-readable address (address, not hostname) is a local address. - bool isLocal(const std::string & host); + bool isLocal(const std::string &host); /// Returns true if given human-readable hostname is a local address. - bool isLocalhost(const std::string & host); + bool isLocalhost(const std::string &host); bool checkTrueSocket(int sock); /// A buffer made out of std::string objects that can be efficiently read from and written to. @@ -49,7 +49,7 @@ namespace Socket{ std::deque data; public: - std::string splitter;///