diff --git a/CMakeLists.txt b/CMakeLists.txt index 263a062c..72a54e38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -824,4 +824,6 @@ target_link_libraries(bitwritertest mist) add_test(BitWriterTest COMMAND bitwritertest) add_executable(streamstatustest test/status.cpp ${BINARY_DIR}/mist/.headers) target_link_libraries(streamstatustest mist) +add_executable(websockettest test/websocket.cpp ${BINARY_DIR}/mist/.headers) +target_link_libraries(websockettest mist) diff --git a/lib/downloader.cpp b/lib/downloader.cpp index 346c3673..547e4fc7 100644 --- a/lib/downloader.cpp +++ b/lib/downloader.cpp @@ -12,6 +12,7 @@ namespace HTTP{ retryCount = 5; ssl = false; proxied = false; + sPtr = 0; char *p = getenv("http_proxy"); if (p){ proxyUrl = HTTP::URL(p); @@ -59,16 +60,28 @@ namespace HTTP{ /// Returns a reference to the internal HTTP class instance. Parser &Downloader::getHTTP(){return H;} - /// Returns a reference to the internal Socket::Connection class instance. - Socket::Connection &Downloader::getSocket(){return S;} - const Socket::Connection &Downloader::getSocket() const{return S;} + /// Returns a reference to the internal Socket::Connection class instance, or the override, if in use. + Socket::Connection &Downloader::getSocket(){ + if (sPtr){return *sPtr;} + return S; + } + + const Socket::Connection &Downloader::getSocket() const{ + if (sPtr){return *sPtr;} + return S; + } + + ///Sets an override to use the given socket + void Downloader::setSocket(Socket::Connection * socketPtr){ + sPtr = socketPtr; + } Downloader::~Downloader(){S.close();} /// Prepares a request for the given URL, does not send anything void Downloader::prepareRequest(const HTTP::URL &link, const std::string &method){ if (!canRequest(link)){return;} - bool needSSL = (link.protocol == "https"); + bool needSSL = (link.protocol == "https" || link.protocol == "wss"); H.Clean(); // Reconnect if needed if (!proxied || needSSL){ @@ -78,12 +91,12 @@ namespace HTTP{ connectedPort = link.getPort(); #ifdef SSL if (needSSL){ - S.open(connectedHost, connectedPort, true, true); + getSocket().open(connectedHost, connectedPort, true, true); }else{ - S.open(connectedHost, connectedPort, true); + getSocket().open(connectedHost, connectedPort, true); } #else - S.open(connectedHost, connectedPort, true); + getSocket().open(connectedHost, connectedPort, true); #endif } }else{ @@ -91,12 +104,12 @@ namespace HTTP{ getSocket().close(); connectedHost = proxyUrl.host; connectedPort = proxyUrl.getPort(); - S.open(connectedHost, connectedPort, true); + getSocket().open(connectedHost, connectedPort, true); } } ssl = needSSL; if (!getSocket()){ - H.method = S.getError(); + H.method = getSocket().getError(); return; // socket is closed } if (proxied && !ssl){ @@ -440,12 +453,12 @@ namespace HTTP{ bool Downloader::canRequest(const HTTP::URL &link){ if (!link.host.size()){return false;} - if (link.protocol != "http" && link.protocol != "https"){ + if (link.protocol != "http" && link.protocol != "https" && link.protocol != "ws" && link.protocol != "wss"){ FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); return false; } #ifndef SSL - if (link.protocol == "https"){ + if (link.protocol == "https" || link.protocol == "wss"){ FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); return false; } diff --git a/lib/downloader.h b/lib/downloader.h index 3956efad..494a1b7f 100644 --- a/lib/downloader.h +++ b/lib/downloader.h @@ -45,6 +45,7 @@ namespace HTTP{ Parser &getHTTP(); Socket::Connection &getSocket(); const Socket::Connection &getSocket() const; + void setSocket(Socket::Connection * socketPtr); uint32_t retryCount, dataTimeout; bool isProxied() const; const HTTP::URL &lastURL(); @@ -56,6 +57,7 @@ namespace HTTP{ uint32_t connectedPort; ///< Currently connected port number Parser H; ///< HTTP parser for downloader Socket::Connection S; ///< TCP socket for downloader + Socket::Connection * sPtr; ///< TCP socket override, when wanting to use an external socket bool ssl; ///< True if ssl is currently in use. std::string authStr; ///< Most recently seen WWW-Authenticate request std::string proxyAuthStr; ///< Most recently seen Proxy-Authenticate request diff --git a/lib/http_parser.cpp b/lib/http_parser.cpp index b4bdad23..f2aed4b0 100644 --- a/lib/http_parser.cpp +++ b/lib/http_parser.cpp @@ -629,6 +629,11 @@ bool HTTP::Parser::parse(std::string &HTTPbuffer, Util::DataCallback &cb){ } if (seenHeaders){ if (headerOnly){return true;} + //Check if we have a response code that may never have a body + if (url.size() && url[0] >= '0' && url[0] <= '9'){ + unsigned int code = atoi(url.data()); + if ((code >= 100 && code < 200) || code == 204 || code == 304){return true;} + } if (length > 0 && !getChunks){ unsigned int toappend = length - body.length(); diff --git a/lib/socket.cpp b/lib/socket.cpp index 3e8244da..ceb91174 100644 --- a/lib/socket.cpp +++ b/lib/socket.cpp @@ -795,15 +795,19 @@ void Socket::Connection::open(std::string host, int port, bool nonblock, bool wi int ret = 0; if ((ret = mbedtls_net_connect(server_fd, host.c_str(), JSON::Value(port).asString().c_str(), MBEDTLS_NET_PROTO_TCP)) != 0){ - lastErr = "mbedtls_net_connect failed"; - FAIL_MSG(" failed\n ! mbedtls_net_connect returned %d\n\n", ret); + char estr[200]; + mbedtls_strerror(ret, estr, 200); + lastErr = estr; + FAIL_MSG("SSL connect failed: %d: %s", ret, lastErr.c_str()); close(); return; } if ((ret = mbedtls_ssl_config_defaults(conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0){ - lastErr = "mbedtls_ssl_config_defaults failed"; - FAIL_MSG(" failed\n ! mbedtls_ssl_config_defaults returned %d\n\n", ret); + char estr[200]; + mbedtls_strerror(ret, estr, 200); + lastErr = estr; + FAIL_MSG("SSL config failed: %d: %s", ret, lastErr.c_str()); close(); return; } @@ -819,7 +823,10 @@ void Socket::Connection::open(std::string host, int port, bool nonblock, bool wi return; } if ((ret = mbedtls_ssl_set_hostname(ssl, host.c_str())) != 0){ - FAIL_MSG(" failed\n ! mbedtls_ssl_set_hostname returned %d\n\n", ret); + char estr[200]; + mbedtls_strerror(ret, estr, 200); + lastErr = estr; + FAIL_MSG("SSL setup error %d: %s", ret, lastErr.c_str()); close(); return; } diff --git a/lib/url.cpp b/lib/url.cpp index c30b80ce..09da6f10 100644 --- a/lib/url.cpp +++ b/lib/url.cpp @@ -147,10 +147,13 @@ uint32_t HTTP::URL::getPort() const{ uint32_t HTTP::URL::getDefaultPort() const{ if (protocol == "http"){return 80;} if (protocol == "https"){return 443;} + if (protocol == "ws"){return 80;} + if (protocol == "wss"){return 443;} if (protocol == "rtmp"){return 1935;} if (protocol == "rtmps"){return 443;} if (protocol == "dtsc"){return 4200;} if (protocol == "rtsp"){return 554;} + if (protocol == "srt"){return 8889;} return 0; } diff --git a/lib/websocket.cpp b/lib/websocket.cpp index b53fae96..a0980f0b 100644 --- a/lib/websocket.cpp +++ b/lib/websocket.cpp @@ -3,56 +3,117 @@ #include "encode.h" #include "timing.h" #include "websocket.h" +#include "downloader.h" #ifdef SSL #include "mbedtls/sha1.h" #endif +// Takes the data from a Sec-WebSocket-Key header, and returns the corresponding data for a Sec-WebSocket-Accept header +static std::string calculateKeyAccept(std::string client_key){ + client_key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + mbedtls_sha1_context ctx; + unsigned char outdata[20]; + mbedtls_sha1_starts(&ctx); + mbedtls_sha1_update(&ctx, (const unsigned char *)client_key.data(), client_key.size()); + mbedtls_sha1_finish(&ctx, outdata); + return Encodings::Base64::encode(std::string((const char *)outdata, 20)); +} + namespace HTTP{ - Websocket::Websocket(Socket::Connection &c, HTTP::Parser &h) : C(c), H(h){ + /// Uses the referenced Socket::Connection to make use of an already connected Websocket. + Websocket::Websocket(Socket::Connection &c, bool client) : C(c){ + maskOut = client; + } + + /// Uses the referenced Socket::Connection to make a new Websocket by connecting to the given URL. + Websocket::Websocket(Socket::Connection &c, const HTTP::URL & url, std::map * headers) : C(c){ + HTTP::Downloader d; + + //Ensure our passed socket gets used by the downloader class + d.setSocket(&C); + + //Generate a random nonce based on the current process ID + //Note: This is not cryptographically secure, nor intended to be. + //It does make it possible to trace which stream came from which PID, if needed. + char nonce[16]; + unsigned int state = getpid(); + for (size_t i = 0; i < 16; ++i){nonce[i] = rand_r(&state) % 255;} + std::string handshakeKey = Encodings::Base64::encode(std::string(nonce, 16)); + + //Prepare the headers + d.setHeader("Connection", "Upgrade"); + d.setHeader("Upgrade", "websocket"); + d.setHeader("Sec-WebSocket-Version", "13"); + d.setHeader("Sec-WebSocket-Key", handshakeKey); + if (headers && headers->size()){ + for (std::map::iterator it = headers->begin(); it != headers->end(); ++it){ + d.setHeader(it->first, it->second); + } + } + if (!d.get(url) || d.getStatusCode() != 101 || !d.getHeader("Sec-WebSocket-Accept").size()){ + FAIL_MSG("Could not connect websocket to %s", url.getUrl().c_str()); + d.getSocket().close(); + C = d.getSocket(); + return; + } + +#ifdef SSL + std::string handshakeAccept = calculateKeyAccept(handshakeKey); + if (d.getHeader("Sec-WebSocket-Accept") != handshakeAccept){ + FAIL_MSG("WebSocket handshake failure: expected accept parameter %s but received %s", handshakeAccept.c_str(), d.getHeader("Sec-WebSocket-Accept").c_str()); + d.getSocket().close(); + C = d.getSocket(); + return; + } +#endif + + MEDIUM_MSG("Connected to websocket %s", url.getUrl().c_str()); + maskOut = true; + } + + /// Takes an incoming HTTP::Parser request for a Websocket, and turns it into one. + Websocket::Websocket(Socket::Connection &c, HTTP::Parser &h) : C(c){ frameType = 0; - std::string connHeader = H.GetHeader("Connection"); + maskOut = false; + std::string connHeader = h.GetHeader("Connection"); Util::stringToLower(connHeader); if (connHeader.find("upgrade") == std::string::npos){ FAIL_MSG("Could not negotiate websocket, connection header incorrect (%s).", connHeader.c_str()); C.close(); return; } - std::string upgradeHeader = H.GetHeader("Upgrade"); + std::string upgradeHeader = h.GetHeader("Upgrade"); Util::stringToLower(upgradeHeader); if (upgradeHeader != "websocket"){ FAIL_MSG("Could not negotiate websocket, upgrade header incorrect (%s).", upgradeHeader.c_str()); C.close(); return; } - if (H.GetHeader("Sec-WebSocket-Version") != "13"){ + if (h.GetHeader("Sec-WebSocket-Version") != "13"){ FAIL_MSG("Could not negotiate websocket, version incorrect (%s).", - H.GetHeader("Sec-WebSocket-Version").c_str()); + h.GetHeader("Sec-WebSocket-Version").c_str()); C.close(); return; } - std::string client_key = H.GetHeader("Sec-WebSocket-Key"); +#ifdef SSL + std::string client_key = h.GetHeader("Sec-WebSocket-Key"); if (!client_key.size()){ FAIL_MSG("Could not negotiate websocket, missing key!"); C.close(); return; } - client_key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +#endif - H.Clean(); - H.setCORSHeaders(); - H.SetHeader("Upgrade", "websocket"); - H.SetHeader("Connection", "Upgrade"); + h.Clean(); + h.setCORSHeaders(); + h.SetHeader("Upgrade", "websocket"); + h.SetHeader("Connection", "Upgrade"); #ifdef SSL - mbedtls_sha1_context ctx; - unsigned char outdata[20]; - mbedtls_sha1_starts(&ctx); - mbedtls_sha1_update(&ctx, (const unsigned char *)client_key.data(), client_key.size()); - mbedtls_sha1_finish(&ctx, outdata); - H.SetHeader("Sec-WebSocket-Accept", Encodings::Base64::encode(std::string((const char *)outdata, 20))); + h.SetHeader("Sec-WebSocket-Accept", calculateKeyAccept(client_key)); #endif // H.SetHeader("Sec-WebSocket-Protocol", "json"); - H.SendResponse("101", "Websocket away!", C); + h.SendResponse("101", "Websocket away!", C); } /// Loops calling readFrame until the connection is closed, sleeping in between reads if needed. @@ -138,27 +199,47 @@ namespace HTTP{ } } - void Websocket::sendFrame(const char *data, unsigned int len, unsigned int frameType){ - char header[10]; + void Websocket::sendFrameHead(unsigned int len, unsigned int frameType){ header[0] = 0x80 + frameType; // FIN + frameType + headLen = 2; if (len < 126){ header[1] = len; - C.SendNow(header, 2); }else{ if (len <= 0xFFFF){ header[1] = 126; Bit::htobs(header + 2, len); - C.SendNow(header, 4); + headLen = 4; }else{ header[1] = 127; Bit::htobll(header + 2, len); - C.SendNow(header, 10); + headLen = 10; } } - C.SendNow(data, len); + if (maskOut){ + header[1] |= 128; + header[headLen++] = 0; + header[headLen++] = 0; + header[headLen++] = 0; + header[headLen++] = 0; + } + C.SendNow(header, headLen); + dataCtr = 0; } - void Websocket::sendFrame(const std::string &data){sendFrame(data.data(), data.size());} + void Websocket::sendFrameData(const char *data, unsigned int len){ + C.SendNow(data, len); + dataCtr += len; + } + + void Websocket::sendFrame(const char *data, unsigned int len, unsigned int frameType){ + sendFrameHead(len, frameType); + sendFrameData(data, len); + } + + void Websocket::sendFrame(const std::string &data){ + sendFrameHead(data.size()); + sendFrameData(data.data(), data.size()); + } Websocket::operator bool() const{return C;} diff --git a/lib/websocket.h b/lib/websocket.h index fb05f751..819f18c9 100644 --- a/lib/websocket.h +++ b/lib/websocket.h @@ -1,5 +1,6 @@ #pragma once #include "http_parser.h" +#include "url.h" #include "socket.h" #include "util.h" @@ -7,16 +8,23 @@ namespace HTTP{ class Websocket{ public: Websocket(Socket::Connection &c, HTTP::Parser &h); + Websocket(Socket::Connection &c, const HTTP::URL & url, std::map * headers = 0); + Websocket(Socket::Connection &c, bool client); operator bool() const; bool readFrame(); bool readLoop(); void sendFrame(const char *data, unsigned int len, unsigned int frameType = 1); + void sendFrameHead(unsigned int len, unsigned int frameType = 1); + void sendFrameData(const char *data, unsigned int len); void sendFrame(const std::string &data); Util::ResizeablePointer data; uint8_t frameType; private: + char header[14];///< Header used for currently sending frame, if any + size_t headLen; ///< Length of header used for currently sending frame + size_t dataCtr; ///< Tracks payload bytes sent since frame start + bool maskOut; ///< True if masking is used for output Socket::Connection &C; - HTTP::Parser &H; }; }// namespace HTTP diff --git a/test/websocket.cpp b/test/websocket.cpp new file mode 100644 index 00000000..43ad99b2 --- /dev/null +++ b/test/websocket.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include +#include + +int main(int argc, char **argv){ + Util::Config c(argv[0]); + + JSON::Value option; + option["arg_num"] = 1; + option["arg"] = "string"; + option["help"] = "URL to retrieve"; + c.addOption("url", option); + if (!(c.parseArgs(argc, argv))){return 1;} + + Util::redirectLogsIfNeeded(); + Socket::Connection C; + HTTP::Websocket ws(C, HTTP::URL(c.getString("url"))); + if (!ws){return 1;} + while (ws){ + if (!ws.readFrame()){ + Util::sleep(100); + continue; + } + switch (ws.frameType){ + case 1: + std::cout << "Text frame (" << ws.data.size() << "b):" << std::endl + << std::string(ws.data, ws.data.size()) << std::endl; + break; + case 2:{ + std::cout << "Binary frame (" << ws.data.size() << "b):" << std::endl; + size_t counter = 0; + for (size_t i = 0; i < ws.data.size(); ++i){ + std::cout << std::hex << std::setw(2) << std::setfill('0') << (int)(ws.data[i] & 0xff) << " "; + if ((counter) % 32 == 31){std::cout << std::endl;} + counter++; + } + std::cout << std::endl; + }break; + case 8: + std::cout << "Connection close frame" << std::endl; + C.close(); + break; + case 9: std::cout << "Ping frame" << std::endl; break; + case 10: std::cout << "Pong frame" << std::endl; break; + default: std::cout << "Unknown frame (" << (int)ws.frameType << ")" << std::endl; break; + } + } + return 0; +}