diff --git a/CMakeLists.txt b/CMakeLists.txt index d1457feb..c8c90ed9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -136,6 +136,7 @@ set(libHeaders lib/riff.h lib/ebml.h lib/ebml_socketglue.h + lib/websocket.h ) ######################################## @@ -178,6 +179,7 @@ add_library (mist lib/riff.cpp lib/ebml.cpp lib/ebml_socketglue.cpp + lib/websocket.cpp ) if (NOT APPLE) set (LIBRT -lrt) diff --git a/lib/websocket.cpp b/lib/websocket.cpp new file mode 100644 index 00000000..6b8ef2a4 --- /dev/null +++ b/lib/websocket.cpp @@ -0,0 +1,166 @@ +#include "websocket.h" +#include "defines.h" +#include "encode.h" +#include "bitfields.h" +#include "timing.h" +#ifdef SSL +#include "mbedtls/sha1.h" +#endif + +namespace HTTP{ + + Websocket::Websocket(Socket::Connection &c, HTTP::Parser &h) : C(c), H(h){ + frameType = 0; + if (H.GetHeader("Connection").find("Upgrade") == std::string::npos){ + FAIL_MSG("Could not negotiate websocket, connection header incorrect (%s).", + H.GetHeader("Connection").c_str()); + C.close(); + return; + } + if (H.GetHeader("Upgrade") != "websocket"){ + FAIL_MSG("Could not negotiate websocket, upgrade header incorrect (%s).", + H.GetHeader("Upgrade").c_str()); + C.close(); + return; + } + if (H.GetHeader("Sec-WebSocket-Version") != "13"){ + FAIL_MSG("Could not negotiate websocket, version incorrect (%s).", + H.GetHeader("Sec-WebSocket-Version").c_str()); + C.close(); + return; + } + std::string client_key = H.GetHeader("Sec-WebSocket-Key"); + if (!client_key.size()){ + FAIL_MSG("Could not negotiate websocket, missing key!"); + C.close(); + return; + } + client_key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + H.Clean(); + H.setCORSHeaders(); + H.SetHeader("Upgrade", "websocket"); + H.SetHeader("Connection", "Upgrade"); +#ifdef SSL + mbedtls_sha1_context ctx; + unsigned char outdata[20]; + mbedtls_sha1_starts(&ctx); + mbedtls_sha1_update(&ctx, (const unsigned char*)client_key.data(), client_key.size()); + mbedtls_sha1_finish(&ctx, outdata); + H.SetHeader("Sec-WebSocket-Accept", Encodings::Base64::encode(std::string((const char*)outdata, 20))); +#endif + //H.SetHeader("Sec-WebSocket-Protocol", "json"); + H.SendResponse("101", "Websocket away!", C); + } + + /// Loops calling readFrame until the connection is closed, sleeping in between reads if needed. + bool Websocket::readLoop(){ + while (C){ + if (readFrame()){ + return true; + } + Util::sleep(500); + } + return false; + } + + /// Loops reading from the socket until either there is no more data ready or a whole frame was read. + bool Websocket::readFrame(){ + while(true){ + //Check if we can receive the minimum frame size (2 header bytes, 0 payload) + if (!C.Received().available(2)){ + if (C.spool()){continue;} + return false; + } + std::string head = C.Received().copy(2); + //Read masked bit and payload length + bool masked = head[1] & 0x80; + uint64_t payLen = head[1] & 0x7F; + uint32_t headSize = 2 + (masked?4:0) + (payLen==126?2:0) + (payLen==127?8:0); + if (headSize > 2){ + //Check if we can receive the whole header + if (!C.Received().available(headSize)){ + if (C.spool()){continue;} + return false; + } + //Read entire header, re-read real payload length + head = C.Received().copy(headSize); + if (payLen == 126){ + payLen = Bit::btohs(head.data()+2); + }else if (payLen == 127){ + payLen = Bit::btohll(head.data()+2); + } + } + //Check if we can receive the whole frame (header + payload) + if (!C.Received().available(headSize + payLen)){ + if (C.spool()){continue;} + return false; + } + C.Received().remove(headSize);//delete the header + std::string pl = C.Received().remove(payLen); + if (masked){ + //If masked, apply the mask to the payload + const char * mask = head.data() + headSize - 4;//mask is last 4 bytes of header + for (uint32_t i = 0; i < payLen; ++i){ + pl[i] ^= mask[i % 4]; + } + } + if ((head[0] & 0xF)){ + //Non-continuation + frameType = (head[0] & 0xF); + data.assign(pl.data(), pl.size()); + }else{ + //Continuation + data.append(pl.data(), pl.size()); + } + if (head[0] & 0x80){ + //FIN + switch (frameType){ + case 0x0://Continuation, should not happen + WARN_MSG("Received unknown websocket frame - ignoring"); + break; + case 0x8://Connection close + HIGH_MSG("Websocket close received"); + C.close(); + break; + case 0x9://Ping + HIGH_MSG("Websocket ping received"); + sendFrame(data, data.size(), 0xA);//send pong + break; + case 0xA://Pong + HIGH_MSG("Websocket pong received"); + break; + } + return true; + } + } + } + + void Websocket::sendFrame(const char * data, unsigned int len, unsigned int frameType){ + char header[10]; + header[0] = 0x80 + frameType;//FIN + frameType + if (len < 126){ + header[1] = len; + C.SendNow(header, 2); + }else{ + if (len <= 0xFFFF){ + header[1] = 126; + Bit::htobs(header+2, len); + C.SendNow(header, 4); + }else{ + header[1] = 127; + Bit::htobll(header+2, len); + C.SendNow(header, 10); + } + } + C.SendNow(data, len); + } + + void Websocket::sendFrame(const std::string & data){ + sendFrame(data.data(), data.size()); + } + + Websocket::operator bool() const{return C;} + +}// namespace HTTP + diff --git a/lib/websocket.h b/lib/websocket.h new file mode 100644 index 00000000..2c386b1e --- /dev/null +++ b/lib/websocket.h @@ -0,0 +1,22 @@ +#pragma once +#include "http_parser.h" +#include "socket.h" +#include "util.h" + +namespace HTTP{ + class Websocket{ + public: + Websocket(Socket::Connection &c, HTTP::Parser &h); + operator bool() const; + bool readFrame(); + bool readLoop(); + void sendFrame(const char * data, unsigned int len, unsigned int frameType = 1); + void sendFrame(const std::string & data); + Util::ResizeablePointer data; + uint8_t frameType; + private: + Socket::Connection &C; + HTTP::Parser &H; + }; +}// namespace HTTP +