From 3987cfec3faf9b038e5b0fc740cb3f2c608e8411 Mon Sep 17 00:00:00 2001 From: Thulinma Date: Fri, 29 Dec 2023 00:58:03 +0100 Subject: [PATCH] Support for WebRTC data tracks (output only, for now), rewrite of dTLS integration (now part of socket lib), support for multi-path WebRTC connections --- lib/dtls_srtp_handshake.cpp | 395 --------------- lib/dtls_srtp_handshake.h | 62 --- lib/meson.build | 3 +- lib/sdp_media.cpp | 57 ++- lib/sdp_media.h | 4 + lib/socket.cpp | 645 +++++++++++++++++++++--- lib/socket.h | 51 +- meson.build | 2 + src/output/meson.build | 1 + src/output/output.cpp | 6 +- src/output/output_webrtc.cpp | 800 +++++++++++++++++++++--------- src/output/output_webrtc.h | 60 ++- src/output/output_webrtc_srtp.cpp | 9 + src/output/output_webrtc_srtp.h | 2 + src/session.cpp | 12 +- subprojects/usrsctp.wrap | 5 + 16 files changed, 1303 insertions(+), 811 deletions(-) delete mode 100644 lib/dtls_srtp_handshake.cpp delete mode 100644 lib/dtls_srtp_handshake.h create mode 100644 subprojects/usrsctp.wrap diff --git a/lib/dtls_srtp_handshake.cpp b/lib/dtls_srtp_handshake.cpp deleted file mode 100644 index ea092d13..00000000 --- a/lib/dtls_srtp_handshake.cpp +++ /dev/null @@ -1,395 +0,0 @@ -#include "defines.h" -#include "dtls_srtp_handshake.h" -#include -#include - -/* Write mbedtls into a log file. */ -#define LOG_TO_FILE 0 -#if LOG_TO_FILE -#include -#endif - -/* ----------------------------------------- */ - -static void print_mbedtls_error(int r); -static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str); -static int on_mbedtls_wants_to_read(void *user, unsigned char *buf, - size_t len); /* Called when mbedtls wants to read data from e.g. a socket. */ -static int on_mbedtls_wants_to_write(void *user, const unsigned char *buf, - size_t len); /* Called when mbedtls wants to write data to e.g. a socket. */ - -/* ----------------------------------------- */ - -DTLSSRTPHandshake::DTLSSRTPHandshake() : cert(NULL), key(NULL), write_callback(NULL){ - memset((void *)&entropy_ctx, 0x00, sizeof(entropy_ctx)); - memset((void *)&rand_ctx, 0x00, sizeof(rand_ctx)); - memset((void *)&ssl_ctx, 0x00, sizeof(ssl_ctx)); - memset((void *)&ssl_conf, 0x00, sizeof(ssl_conf)); - memset((void *)&cookie_ctx, 0x00, sizeof(cookie_ctx)); - memset((void *)&timer_ctx, 0x00, sizeof(timer_ctx)); -} - -int DTLSSRTPHandshake::init(mbedtls_x509_crt *certificate, mbedtls_pk_context *privateKey, - int (*writeCallback)(const uint8_t *data, int *nbytes)){ - - int r = 0; - mbedtls_ssl_srtp_profile srtp_profiles[] ={MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80, - MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32}; - - if (!writeCallback){ - FAIL_MSG("No writeCallack function given."); - r = -3; - goto error; - } - - if (!certificate){ - FAIL_MSG("Given certificate is null."); - r = -5; - goto error; - } - - if (!privateKey){ - FAIL_MSG("Given key is null."); - r = -10; - goto error; - } - - cert = certificate; - key = privateKey; - - /* init the contexts */ - mbedtls_entropy_init(&entropy_ctx); - mbedtls_ctr_drbg_init(&rand_ctx); - mbedtls_ssl_init(&ssl_ctx); - mbedtls_ssl_config_init(&ssl_conf); - mbedtls_ssl_cookie_init(&cookie_ctx); - - /* seed and setup the random number generator */ - r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx, - (const unsigned char *)"mist-srtp", 9); - if (0 != r){ - print_mbedtls_error(r); - r = -20; - goto error; - } - - /* load defaults into our ssl_conf */ - r = mbedtls_ssl_config_defaults(&ssl_conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM, - MBEDTLS_SSL_PRESET_DEFAULT); - if (0 != r){ - print_mbedtls_error(r); - r = -30; - goto error; - } - - mbedtls_ssl_conf_authmode(&ssl_conf, MBEDTLS_SSL_VERIFY_NONE); - mbedtls_ssl_conf_rng(&ssl_conf, mbedtls_ctr_drbg_random, &rand_ctx); - mbedtls_ssl_conf_dbg(&ssl_conf, print_mbedtls_debug_message, stdout); - mbedtls_ssl_conf_ca_chain(&ssl_conf, cert, NULL); - mbedtls_ssl_conf_cert_profile(&ssl_conf, &mbedtls_x509_crt_profile_default); - mbedtls_debug_set_threshold(10); - - /* enable SRTP */ - r = mbedtls_ssl_conf_dtls_srtp_protection_profiles(&ssl_conf, srtp_profiles, - sizeof(srtp_profiles) / sizeof(srtp_profiles[0])); - if (0 != r){ - print_mbedtls_error(r); - r = -40; - goto error; - } - - /* cert certificate chain + key, so we can verify the client-hello signed data */ - r = mbedtls_ssl_conf_own_cert(&ssl_conf, cert, key); - if (0 != r){ - print_mbedtls_error(r); - r = -50; - goto error; - } - - /* cookie setup (e.g. to prevent ddos). */ - r = mbedtls_ssl_cookie_setup(&cookie_ctx, mbedtls_ctr_drbg_random, &rand_ctx); - if (0 != r){ - print_mbedtls_error(r); - r = -60; - goto error; - } - - /* register callbacks for dtls cookies (server only). */ - mbedtls_ssl_conf_dtls_cookies(&ssl_conf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &cookie_ctx); - - /* setup the ssl context for use. note that ssl_conf will be referenced internall by the context and therefore should be kept around. */ - r = mbedtls_ssl_setup(&ssl_ctx, &ssl_conf); - if (0 != r){ - print_mbedtls_error(r); - r = -70; - goto error; - } - - /* set bio handlers */ - mbedtls_ssl_set_bio(&ssl_ctx, (void *)this, on_mbedtls_wants_to_write, on_mbedtls_wants_to_read, NULL); - - /* set temp id, just adds some exta randomness */ - { - std::string remote_id = "mist"; - r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char *)remote_id.c_str(), - remote_id.size()); - if (0 != r){ - print_mbedtls_error(r); - r = -80; - goto error; - } - } - - /* set timer callbacks */ - mbedtls_ssl_set_timer_cb(&ssl_ctx, &timer_ctx, mbedtls_timing_set_delay, mbedtls_timing_get_delay); - - write_callback = writeCallback; - -error: - - if (r < 0){shutdown();} - - return r; -} - -int DTLSSRTPHandshake::shutdown(){ - - /* cleanup the refs from the settings. */ - cert = NULL; - key = NULL; - buffer.clear(); - cipher.clear(); - remote_key.clear(); - remote_salt.clear(); - local_key.clear(); - local_salt.clear(); - - /* free our contexts; we do not free the `settings.cert` and `settings.key` as they are owned by the user of this class. */ - mbedtls_entropy_free(&entropy_ctx); - mbedtls_ctr_drbg_free(&rand_ctx); - mbedtls_ssl_free(&ssl_ctx); - mbedtls_ssl_config_free(&ssl_conf); - mbedtls_ssl_cookie_free(&cookie_ctx); - - return 0; -} - -/* ----------------------------------------- */ - -int DTLSSRTPHandshake::parse(const uint8_t *data, size_t nbytes){ - - if (NULL == data){ - ERROR_MSG("Given `data` is NULL."); - return -1; - } - - if (0 == nbytes){ - ERROR_MSG("Given nbytes is 0."); - return -2; - } - - if (MBEDTLS_SSL_HANDSHAKE_OVER == ssl_ctx.state){ - ERROR_MSG("Already finished the handshake."); - return -3; - } - - /* copy incoming data into a temporary buffer which is read via our `bio` read function. */ - int r = 0; - std::copy(data, data + nbytes, std::back_inserter(buffer)); - - do{ - - r = mbedtls_ssl_handshake(&ssl_ctx); - - switch (r){ - /* 0 = handshake done. */ - case 0:{ - if (0 != extractKeyingMaterial()){ - ERROR_MSG("Failed to extract keying material after handshake was done."); - return -2; - } - return 0; - } - /* see the dtls server example; this is used to prevent certain attacks (ddos) */ - case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:{ - if (0 != resetSession()){ - ERROR_MSG( - "Failed to reset the session which is necessary when we need to verify the HELLO."); - return -3; - } - break; - } - case MBEDTLS_ERR_SSL_WANT_READ:{ - DONTEVEN_MSG( - "mbedtls wants a bit more data before it can continue parsing the DTLS handshake."); - break; - } - default:{ - ERROR_MSG("A serious mbedtls error occured."); - print_mbedtls_error(r); - return -2; - } - } - }while (MBEDTLS_ERR_SSL_WANT_WRITE == r); - - return 0; -} - -/* ----------------------------------------- */ - -int DTLSSRTPHandshake::resetSession(){ - - std::string remote_id = "mist"; /* @todo for now we hardcoded this... */ - int r = 0; - - r = mbedtls_ssl_session_reset(&ssl_ctx); - if (0 != r){ - print_mbedtls_error(r); - return -1; - } - - r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char *)remote_id.c_str(), - remote_id.size()); - if (0 != r){ - print_mbedtls_error(r); - return -2; - } - - buffer.clear(); - - return 0; -} - -/* - master key is 128 bits => 16 bytes. - master salt is 112 bits => 14 bytes -*/ -int DTLSSRTPHandshake::extractKeyingMaterial(){ - - int r = 0; - uint8_t keying_material[MBEDTLS_DTLS_SRTP_MAX_KEY_MATERIAL_LENGTH] ={}; - size_t keying_material_len = sizeof(keying_material); - - r = mbedtls_ssl_get_dtls_srtp_key_material(&ssl_ctx, keying_material, &keying_material_len); - if (0 != r){ - print_mbedtls_error(r); - return -1; - } - - /* @todo following code is for server mode only */ - mbedtls_ssl_srtp_profile srtp_profile = mbedtls_ssl_get_dtls_srtp_protection_profile(&ssl_ctx); - switch (srtp_profile){ - case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80:{ - cipher = "SRTP_AES128_CM_SHA1_80"; - break; - } - case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32:{ - cipher = "SRTP_AES128_CM_SHA1_32"; - break; - } - default:{ - ERROR_MSG("Unhandled SRTP profile, cannot extract keying material."); - return -6; - } - } - - remote_key.assign((char *)(&keying_material[0]) + 0, 16); - local_key.assign((char *)(&keying_material[0]) + 16, 16); - remote_salt.assign((char *)(&keying_material[0]) + 32, 14); - local_salt.assign((char *)(&keying_material[0]) + 46, 14); - - DONTEVEN_MSG("Extracted the DTLS SRTP keying material with cipher %s.", cipher.c_str()); - DONTEVEN_MSG("Remote DTLS SRTP key size is %zu.", remote_key.size()); - DONTEVEN_MSG("Remote DTLS SRTP salt size is %zu.", remote_salt.size()); - DONTEVEN_MSG("Local DTLS SRTP key size is %zu.", local_key.size()); - DONTEVEN_MSG("Local DTLS SRTP salt size is %zu.", local_salt.size()); - - return 0; -} - -/* ----------------------------------------- */ - -/* - - This function is called by mbedtls whenever it wants to read - some data. The documentation states the following: "For DTLS, - you need to provide either a non-NULL f_recv_timeout - callback, or a f_recv that doesn't block." As this - implementation is completely decoupled from any I/O and uses - a "push" model instead of a "pull" model we have to copy new - input bytes into a temporary buffer (see parse), but we act - as if we were using a non-blocking socket, which means: - - - we return MBETLS_ERR_SSL_WANT_READ when there is no data left to read - - when there is data in our temporary buffer, we read from that - -*/ -static int on_mbedtls_wants_to_read(void *user, unsigned char *buf, size_t len){ - - DTLSSRTPHandshake *hs = static_cast(user); - if (NULL == hs){ - ERROR_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake."); - return -1; - } - - /* figure out how much we can read. */ - if (hs->buffer.size() == 0){return MBEDTLS_ERR_SSL_WANT_READ;} - - size_t nbytes = hs->buffer.size(); - if (nbytes > len){nbytes = len;} - - /* "read" into the given buffer. */ - memcpy(buf, &hs->buffer[0], nbytes); - hs->buffer.erase(hs->buffer.begin(), hs->buffer.begin() + nbytes); - - return (int)nbytes; -} - -static int on_mbedtls_wants_to_write(void *user, const unsigned char *buf, size_t len){ - - DTLSSRTPHandshake *hs = static_cast(user); - if (!hs){ - FAIL_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake."); - return -1; - } - - if (!hs->write_callback){ - FAIL_MSG("The `write_callback` member is NULL."); - return -2; - } - - int nwritten = (int)len; - if (0 != hs->write_callback(buf, &nwritten)){ - FAIL_MSG("Failed to write some DTLS handshake data."); - return -3; - } - - if (nwritten != (int)len){ - FAIL_MSG("The DTLS-SRTP handshake listener MUST write all the data."); - return -4; - } - - return nwritten; -} - -/* ----------------------------------------- */ - -static void print_mbedtls_error(int r){ - char buf[1024] ={}; - mbedtls_strerror(r, buf, sizeof(buf)); - ERROR_MSG("mbedtls error: %s", buf); -} - -static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str){ - DONTEVEN_MSG("%s:%04d: %.*s", file, line, (int)strlen(str) - 1, str); - -#if LOG_TO_FILE - static std::ofstream ofs; - if (!ofs.is_open()){ofs.open("mbedtls.log", std::ios::out);} - if (!ofs.is_open()){return;} - ofs << str; - ofs.flush(); -#endif -} - -/* ---------------------------------------- */ diff --git a/lib/dtls_srtp_handshake.h b/lib/dtls_srtp_handshake.h deleted file mode 100644 index 7250167e..00000000 --- a/lib/dtls_srtp_handshake.h +++ /dev/null @@ -1,62 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -/* ----------------------------------------- */ - -class DTLSSRTPHandshake{ -public: - DTLSSRTPHandshake(); - int init(mbedtls_x509_crt *certificate, mbedtls_pk_context *privateKey, - int (*writeCallback)(const uint8_t *data, - int *nbytes)); // writeCallback should return 0 on succes < 0 on error. - // nbytes holds the number of bytes to be sent and needs - // to be set to the number of bytes actually sent. - int shutdown(); - int parse(const uint8_t *data, size_t nbytes); - bool hasKeyingMaterial(); - -private: - int extractKeyingMaterial(); - int resetSession(); - -private: - mbedtls_x509_crt *cert; /* Certificate, we do not own the key. Make sure it's kept alive during the livetime of this class instance. */ - mbedtls_pk_context *key; /* Private key, we do not own the key. Make sure it's kept alive during the livetime of this class instance. */ - mbedtls_entropy_context entropy_ctx; - mbedtls_ctr_drbg_context rand_ctx; - mbedtls_ssl_context ssl_ctx; - mbedtls_ssl_config ssl_conf; - mbedtls_ssl_cookie_ctx cookie_ctx; - mbedtls_timing_delay_context timer_ctx; - -public: - int (*write_callback)(const uint8_t *data, int *nbytes); - std::deque buffer; /* Accessed from BIO callbback. We copy the bytes you pass into `parse()` into this - temporary buffer which is read by a trigger to `mbedlts_ssl_handshake()`. */ - std::string cipher; /* selected SRTP cipher. */ - std::string remote_key; - std::string remote_salt; - std::string local_key; - std::string local_salt; -}; - -/* ----------------------------------------- */ - -inline bool DTLSSRTPHandshake::hasKeyingMaterial(){ - return (0 != remote_key.size() && 0 != remote_salt.size() && 0 != local_key.size() && - 0 != local_salt.size()); -} - -/* ----------------------------------------- */ diff --git a/lib/meson.build b/lib/meson.build index a9708ab6..424cb090 100644 --- a/lib/meson.build +++ b/lib/meson.build @@ -12,7 +12,6 @@ headers = [ 'comms.h', 'config.h', 'defines.h', - 'dtls_srtp_handshake.h', 'dtsc.h', 'encryption.h', 'flv_tag.h', @@ -69,7 +68,7 @@ install_headers(headers, subdir: 'mist') extra_code = [] if usessl - extra_code += ['dtls_srtp_handshake.cpp', 'stun.cpp', 'certificate.cpp', 'encryption.cpp',] + extra_code += ['stun.cpp', 'certificate.cpp', 'encryption.cpp',] endif libmist = library('mist', diff --git a/lib/sdp_media.cpp b/lib/sdp_media.cpp index 2ab03a08..3dd24fdd 100644 --- a/lib/sdp_media.cpp +++ b/lib/sdp_media.cpp @@ -35,6 +35,8 @@ namespace SDP{ return "AAC"; }else if (codec == "OPUS"){ return "opus"; + }else if (codec == "WEBRTC-DATACHANNEL"){ + return "JSON"; }else if (codec == "ULPFEC"){ return ""; }else if (codec == "RED"){ @@ -67,6 +69,10 @@ namespace SDP{ return "MPA"; }else if (codec == "AAC"){ return "MPEG4-GENERIC"; + }else if (codec == "JSON"){ + return "WEBRTC-DATACHANNEL"; + }else if (codec == "subtitle"){ + return "WEBRTC-DATACHANNEL"; }else if (codec == "opus"){ return "OPUS"; }else if (codec == "ULPFEC"){ @@ -277,6 +283,8 @@ namespace SDP{ type = "audio"; }else if (words[0] == "m=video"){ type = "video"; + }else if (words[0] == "m=application"){ + type = "meta"; }else{ ERROR_MSG("Unhandled media type: `%s`.", words[0].c_str()); return false; @@ -289,6 +297,7 @@ namespace SDP{ for (size_t i = 3; i < words.size(); ++i){ SDP::MediaFormat format; format.payloadType = JSON::Value(words[i]).asInt(); + if (words[i] == "webrtc-datachannel"){format.encodingName = "WEBRTC-DATACHANNEL";} formats[format.payloadType] = format; if (!payloadTypes.empty()){payloadTypes += " ";} payloadTypes += words[i]; @@ -711,17 +720,11 @@ namespace SDP{ static bool sdp_get_name_value_from_varval(const std::string &str, std::string &var, std::string &value){ if (str.empty()){ - ERROR_MSG("Cannot get `name` and `value` from string because the given string is empty. " - "String is: `%s`", - str.c_str()); return false; } size_t pos = str.find("="); if (pos == std::string::npos){ - WARN_MSG("Cannot get `name` and `value` from string becuase it doesn't contain a `=` sign. " - "String is: `%s`. Returning the string as is.", - str.c_str()); value = str; return true; } @@ -776,7 +779,7 @@ namespace SDP{ } Answer::Answer() - : isAudioEnabled(false), isVideoEnabled(false), candidatePort(0), + : isAudioEnabled(false), isVideoEnabled(false), isMetaEnabled(false), candidatePort(0), videoLossPrevention(SDP_LOSS_PREVENTION_NONE){} bool Answer::parseOffer(const std::string &sdp){ @@ -817,6 +820,15 @@ namespace SDP{ return true; } + bool Answer::enableMeta(const std::string &codecName){ + if (!enableMedia("meta", codecName, answerMetaMedia, answerMetaFormat)){ + DONTEVEN_MSG("Not enabling meta."); + return false; + } + isMetaEnabled = true; + return true; + } + void Answer::setCandidate(const std::string &ip, uint16_t port){ if (ip.empty()){WARN_MSG("Given candidate IP is empty. It's fine if you want to unset it.");} candidateIP = ip; @@ -934,7 +946,7 @@ namespace SDP{ bool isEnabled = false; std::vector supportedPayloadTypes; - if (type != "audio" && type != "video"){continue;} + if (type != "audio" && type != "video" && type != "meta"){continue;} // port = 9 (default), port = 0 (disable this media) if (type == "audio"){ @@ -947,6 +959,10 @@ namespace SDP{ fmtMedia = &answerVideoFormat; fmtRED = media->getFormatForEncodingName("RED"); fmtULPFEC = media->getFormatForEncodingName("ULPFEC"); + }else if (type == "meta"){ + isEnabled = isMetaEnabled; + media = &answerMetaMedia; + fmtMedia = &answerMetaFormat; } if (!media){ @@ -975,10 +991,17 @@ namespace SDP{ } std::string payloadTypes = ss.str(); + std::string protocol = "UDP/TLS/RTP/SAVPF"; + if (type == "meta"){ + protocol = "UDP/DTLS/SCTP"; + payloadTypes = "webrtc-datachannel"; + type = "application"; + } + if (isEnabled){ - addLine("m=%s 9 UDP/TLS/RTP/SAVPF %s", type.c_str(), payloadTypes.c_str()); + addLine("m=%s 9 %s %s", type.c_str(), protocol.c_str(), payloadTypes.c_str()); }else{ - addLine("m=%s %u UDP/TLS/RTP/SAVPF %s", type.c_str(), 0, mediaOffer.payloadTypes.c_str()); + addLine("m=%s %u %s %s", type.c_str(), 0, protocol.c_str(), mediaOffer.payloadTypes.c_str()); } addLine("c=IN IP4 0.0.0.0"); @@ -996,9 +1019,14 @@ namespace SDP{ addLine("a=fingerprint:sha-256 %s", fingerprint.c_str()); addLine("a=ice-ufrag:%s", fmtMedia->iceUFrag.c_str()); addLine("a=ice-pwd:%s", fmtMedia->icePwd.c_str()); - addLine("a=rtcp-mux"); - addLine("a=rtcp-rsize"); - addLine("a=%s", fmtMedia->rtpmap.c_str()); + if (type == "application"){ + addLine("a=sctp-port:5000"); + addLine("a=max-message-size:262144"); + }else{ + addLine("a=rtcp-mux"); + addLine("a=rtcp-rsize"); + addLine("a=%s", fmtMedia->rtpmap.c_str()); + } // BEGIN FEC/RTX: testing with just FEC or RTX if ((videoLossPrevention & SDP_LOSS_PREVENTION_ULPFEC) && fmtRED && fmtULPFEC){ @@ -1136,14 +1164,11 @@ namespace SDP{ return false; } - INFO_MSG("Enabling media for codec: %s", format->encodingName.c_str()); - outMedia = *media; outFormat = *format; outFormat.rtcpFormats.clear(); outFormat.icePwd = generateIcePwd(); outFormat.iceUFrag = generateIceUFrag(); - return true; } diff --git a/lib/sdp_media.h b/lib/sdp_media.h index 88886ab0..e3b96711 100644 --- a/lib/sdp_media.h +++ b/lib/sdp_media.h @@ -167,6 +167,7 @@ namespace SDP{ bool hasAudio(); ///< Check if the offer has audio. bool enableVideo(const std::string &codecName); bool enableAudio(const std::string &codecName); + bool enableMeta(const std::string &codecName); void setCandidate(const std::string &ip, uint16_t port); void setFingerprint(const std::string &fingerprintSha); ///< Set the SHA265 that represents the ///< certificate that is used with DTLS. @@ -189,10 +190,13 @@ namespace SDP{ SDP::Session sdpOffer; SDP::Media answerVideoMedia; SDP::Media answerAudioMedia; + SDP::Media answerMetaMedia; SDP::MediaFormat answerVideoFormat; SDP::MediaFormat answerAudioFormat; + SDP::MediaFormat answerMetaFormat; bool isAudioEnabled; bool isVideoEnabled; + bool isMetaEnabled; std::string candidateIP; ///< We use rtcp-mux and BUNDLE; so only one candidate necessary. uint16_t candidatePort; ///< We use rtcp-mux and BUNDLE; so only one candidate necessary. std::string fingerprint; diff --git a/lib/socket.cpp b/lib/socket.cpp index 1b4eb4f3..f4bcc883 100644 --- a/lib/socket.cpp +++ b/lib/socket.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #define BUFFER_BLOCKSIZE 4096 // set buffer blocksize to 4KiB @@ -1603,35 +1604,197 @@ int Socket::Server::getSocket(){ return sock; } + +static int dTLS_recv(void *s, unsigned char *buf, size_t len){ + return ((Socket::UDPConnection*)s)->dTLSRead(buf, len); +} + +static int dTLS_send(void *s, const unsigned char *buf, size_t len){ + return ((Socket::UDPConnection*)s)->dTLSWrite(buf, len); +} + + /// Create a new UDP Socket. /// Will attempt to create an IPv6 UDP socket, on fail try a IPV4 UDP socket. /// If both fail, prints an DLVL_FAIL debug message. /// \param nonblock Whether the socket should be nonblocking. Socket::UDPConnection::UDPConnection(bool nonblock){ + init(nonblock); +}// Socket::UDPConnection UDP Contructor + +void Socket::UDPConnection::init(bool _nonblock, int _family){ lastPace = 0; boundPort = 0; - family = AF_INET6; - sock = socket(AF_INET6, SOCK_DGRAM, 0); - if (sock == -1){ + family = _family; + hasDTLS = false; + isConnected = false; + wasEncrypted = false; + pretendReceive = false; + sock = socket(family, SOCK_DGRAM, 0); + if (sock == -1 && family == AF_INET6){ sock = socket(AF_INET, SOCK_DGRAM, 0); family = AF_INET; } if (sock == -1){ FAIL_MSG("Could not create UDP socket: %s", strerror(errno)); }else{ - if (nonblock){setBlocking(!nonblock);} + isBlocking = !_nonblock; + if (_nonblock){setBlocking(!_nonblock);} checkRecvBuf(); } + + { + // Allow address re-use + int on = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + } + up = 0; down = 0; destAddr = 0; destAddr_size = 0; + recvAddr = 0; + recvAddr_size = 0; + hasReceiveData = false; #ifdef __CYGWIN__ data.allocate(SOCKETSIZE); #else data.allocate(2048); #endif -}// Socket::UDPConnection UDP Contructor +} + +void Socket::UDPConnection::initDTLS(mbedtls_x509_crt *cert, mbedtls_pk_context *key){ + hasDTLS = true; + nextDTLSRead = 0; + nextDTLSReadLen = 0; + + int r = 0; + char mbedtls_msg[1024]; + + // Null out the contexts before use + memset((void *)&entropy_ctx, 0x00, sizeof(entropy_ctx)); + memset((void *)&rand_ctx, 0x00, sizeof(rand_ctx)); + memset((void *)&ssl_ctx, 0x00, sizeof(ssl_ctx)); + memset((void *)&ssl_conf, 0x00, sizeof(ssl_conf)); + memset((void *)&cookie_ctx, 0x00, sizeof(cookie_ctx)); + memset((void *)&timer_ctx, 0x00, sizeof(timer_ctx)); + // Initialize contexts + mbedtls_entropy_init(&entropy_ctx); + mbedtls_ctr_drbg_init(&rand_ctx); + mbedtls_ssl_init(&ssl_ctx); + mbedtls_ssl_config_init(&ssl_conf); + mbedtls_ssl_cookie_init(&cookie_ctx); + + /* seed and setup the random number generator */ + r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx, (const unsigned char *)"mist-srtp", 9); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + FAIL_MSG("dTLS could not init drbg seed: %s", mbedtls_msg); + return; + } + + /* load defaults into our ssl_conf */ + r = mbedtls_ssl_config_defaults(&ssl_conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + FAIL_MSG("dTLS could not set defaults: %s", mbedtls_msg); + return; + } + + mbedtls_ssl_conf_authmode(&ssl_conf, MBEDTLS_SSL_VERIFY_NONE); + mbedtls_ssl_conf_rng(&ssl_conf, mbedtls_ctr_drbg_random, &rand_ctx); + mbedtls_ssl_conf_ca_chain(&ssl_conf, cert, NULL); + mbedtls_ssl_conf_cert_profile(&ssl_conf, &mbedtls_x509_crt_profile_default); + //mbedtls_ssl_conf_dbg(&ssl_conf, print_mbedtls_debug_message, stdout); + //mbedtls_debug_set_threshold(10); + + // enable SRTP support (non-fatal on error) + mbedtls_ssl_srtp_profile srtpPro[] ={MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80, MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32}; + r = mbedtls_ssl_conf_dtls_srtp_protection_profiles(&ssl_conf, srtpPro, sizeof(srtpPro) / sizeof(srtpPro[0])); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + WARN_MSG("dTLS could not set SRTP profiles: %s", mbedtls_msg); + } + + /* cert certificate chain + key, so we can verify the client-hello signed data */ + r = mbedtls_ssl_conf_own_cert(&ssl_conf, cert, key); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + FAIL_MSG("dTLS could not set certificate: %s", mbedtls_msg); + return; + } + + // cookie setup (to prevent ddos, server-only) + r = mbedtls_ssl_cookie_setup(&cookie_ctx, mbedtls_ctr_drbg_random, &rand_ctx); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + FAIL_MSG("dTLS could not set SSL cookie: %s", mbedtls_msg); + return; + } + mbedtls_ssl_conf_dtls_cookies(&ssl_conf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &cookie_ctx); + + // setup the ssl context + r = mbedtls_ssl_setup(&ssl_ctx, &ssl_conf); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + FAIL_MSG("dTLS could not setup: %s", mbedtls_msg); + return; + } + + // set input/output callbacks + mbedtls_ssl_set_bio(&ssl_ctx, (void *)this, dTLS_send, dTLS_recv, NULL); + mbedtls_ssl_set_timer_cb(&ssl_ctx, &timer_ctx, mbedtls_timing_set_delay, mbedtls_timing_get_delay); + + // set transport ID (non-fatal on error) + r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char *)"mist", 4); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + WARN_MSG("dTLS could not set transport ID: %s", mbedtls_msg); + } +} + +void Socket::UDPConnection::deinitDTLS(){ + if (hasDTLS){ + mbedtls_entropy_free(&entropy_ctx); + mbedtls_ctr_drbg_free(&rand_ctx); + mbedtls_ssl_free(&ssl_ctx); + mbedtls_ssl_config_free(&ssl_conf); + mbedtls_ssl_cookie_free(&cookie_ctx); + hasDTLS = true; + } +} + +int Socket::UDPConnection::dTLSRead(unsigned char *buf, size_t _len){ + if (!nextDTLSReadLen){return MBEDTLS_ERR_SSL_WANT_READ;} + size_t len = _len; + if (len > nextDTLSReadLen){len = nextDTLSReadLen;} + memcpy(buf, nextDTLSRead, len); + nextDTLSReadLen = 0; + return len; +} + +int Socket::UDPConnection::dTLSWrite(const unsigned char *buf, size_t len){ + sendPaced((const char *)buf, len, false); + return len; +} + +void Socket::UDPConnection::dTLSReset(){ + char mbedtls_msg[1024]; + int r = mbedtls_ssl_session_reset(&ssl_ctx); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + FAIL_MSG("dTLS could not reset session: %s", mbedtls_msg); + return; + } + + // set transport ID (non-fatal on error) + r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char *)"mist", 4); + if (r){ + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + WARN_MSG("dTLS could not set transport ID: %s", mbedtls_msg); + } +} ///Checks if the UDP receive buffer is at least 1 mbyte, attempts to increase and warns user through log message on failure. void Socket::UDPConnection::checkRecvBuf(){ @@ -1681,27 +1844,23 @@ void Socket::UDPConnection::checkRecvBuf(){ /// Copies a UDP socket, re-allocating local copies of any needed structures. /// The data/data_size/data_len variables are *not* copied over. Socket::UDPConnection::UDPConnection(const UDPConnection &o){ - lastPace = 0; - boundPort = 0; - family = AF_INET6; - sock = socket(AF_INET6, SOCK_DGRAM, 0); - if (sock == -1){ - sock = socket(AF_INET, SOCK_DGRAM, 0); - family = AF_INET; - } - if (sock == -1){FAIL_MSG("Could not create UDP socket: %s", strerror(errno));} - checkRecvBuf(); - up = 0; - down = 0; + init(!o.isBlocking, o.family); + INFO_MSG("Copied socket of type %s", addrFam(o.family)); if (o.destAddr && o.destAddr_size){ destAddr = malloc(o.destAddr_size); destAddr_size = o.destAddr_size; if (destAddr){memcpy(destAddr, o.destAddr, o.destAddr_size);} - }else{ - destAddr = 0; - destAddr_size = 0; } - data.allocate(2048); + if (o.recvAddr && o.recvAddr_size){ + recvAddr = malloc(o.recvAddr_size); + recvAddr_size = o.recvAddr_size; + if (recvAddr){memcpy(recvAddr, o.recvAddr, o.recvAddr_size);} + } + if (o.data.size()){ + data.assign(o.data, o.data.size()); + pretendReceive = true; + } + hasReceiveData = o.hasReceiveData; } /// Close the UDP socket @@ -1720,8 +1879,35 @@ Socket::UDPConnection::~UDPConnection(){ free(destAddr); destAddr = 0; } + if (recvAddr){ + free(recvAddr); + recvAddr = 0; + } + deinitDTLS(); } + +bool Socket::UDPConnection::operator==(const Socket::UDPConnection& b) const{ + // UDP sockets are equal if they refer to the same underlying socket or are both closed + if (sock == b.sock){return true;} + // If either is closed (and the other is not), not equal. + if (sock == -1 || b.sock == -1){return false;} + size_t recvSize = recvAddr_size; + if (b.recvAddr_size < recvSize){recvSize = b.recvAddr_size;} + size_t destSize = destAddr_size; + if (b.destAddr_size < destSize){destSize = b.destAddr_size;} + // They are equal if they hold the same local and remote address. + if (recvSize && destSize && destAddr && b.destAddr && recvAddr && b.recvAddr){ + if (!memcmp(recvAddr, b.recvAddr, recvSize) && !memcmp(destAddr, b.destAddr, destSize)){ + return true; + } + } + // All other cases, not equal + return false; +} + +Socket::UDPConnection::operator bool() const{return sock != -1;} + // Sets socket family type (to IPV4 or IPV6) (AF_INET=2, AF_INET6=10) void Socket::UDPConnection::setSocketFamily(int AF_TYPE){\ INFO_MSG("Switching UDP socket from %s to %s", addrFam(family), addrFam(AF_TYPE)); @@ -1742,6 +1928,22 @@ void Socket::UDPConnection::allocateDestination(){ ((struct sockaddr_in *)destAddr)->sin_family = AF_UNSPEC; } } + if (recvAddr && recvAddr_size < sizeof(sockaddr_in6)){ + free(recvAddr); + recvAddr = 0; + } + if (!recvAddr){ + recvAddr = malloc(sizeof(sockaddr_in6)); + if (recvAddr){ + recvAddr_size = sizeof(sockaddr_in6); + memset(recvAddr, 0, sizeof(sockaddr_in6)); + ((struct sockaddr_in *)recvAddr)->sin_family = AF_UNSPEC; + } + const int opt = 1; + if (setsockopt(sock, IPPROTO_IP, IP_PKTINFO, &opt, sizeof(opt))){ + WARN_MSG("Could not set PKTINFO to 1!"); + } + } } /// Stores the properties of the receiving end of this UDP socket. @@ -1788,6 +1990,11 @@ void Socket::UDPConnection::SetDestination(std::string destIp, uint32_t port){ close(); family = rp->ai_family; sock = socket(family, SOCK_DGRAM, 0); + { + // Allow address re-use + int on = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + } checkRecvBuf(); if (boundPort){ INFO_MSG("Rebinding to %s:%d %s", boundAddr.c_str(), boundPort, boundMulti.c_str()); @@ -1839,6 +2046,35 @@ void Socket::UDPConnection::GetDestination(std::string &destIp, uint32_t &port){ FAIL_MSG("Could not get destination for UDP socket"); }// Socket::UDPConnection GetDestination +/// Gets the properties of the receiving end of the local UDP socket. +/// This will be the sending end for all SendNow calls. +void Socket::UDPConnection::GetLocalDestination(std::string &destIp, uint32_t &port){ + if (!recvAddr || !recvAddr_size){ + destIp = ""; + port = 0; + return; + } + char addr_str[INET6_ADDRSTRLEN + 1]; + addr_str[INET6_ADDRSTRLEN] = 0; // set last byte to zero, to prevent walking out of the array + if (((struct sockaddr_in *)recvAddr)->sin_family == AF_INET6){ + if (inet_ntop(AF_INET6, &(((struct sockaddr_in6 *)recvAddr)->sin6_addr), addr_str, INET6_ADDRSTRLEN) != 0){ + destIp = addr_str; + port = ntohs(((struct sockaddr_in6 *)recvAddr)->sin6_port); + return; + } + } + if (((struct sockaddr_in *)recvAddr)->sin_family == AF_INET){ + if (inet_ntop(AF_INET, &(((struct sockaddr_in *)recvAddr)->sin_addr), addr_str, INET6_ADDRSTRLEN) != 0){ + destIp = addr_str; + port = ntohs(((struct sockaddr_in *)recvAddr)->sin_port); + return; + } + } + destIp = ""; + port = 0; + FAIL_MSG("Could not get destination for UDP socket"); +}// Socket::UDPConnection GetDestination + /// Gets the properties of the receiving end of this UDP socket. /// This will be the receiving end for all SendNow calls. std::string Socket::UDPConnection::getBinDestination(){ @@ -1864,7 +2100,10 @@ uint32_t Socket::UDPConnection::getDestPort() const{ /// Sets the socket to be blocking if the parameters is true. /// Sets the socket to be non-blocking otherwise. void Socket::UDPConnection::setBlocking(bool blocking){ - if (sock >= 0){setFDBlocking(sock, blocking);} + if (sock >= 0){ + setFDBlocking(sock, blocking); + isBlocking = blocking; + } } /// Sends a UDP datagram using the buffer sdata. @@ -1885,64 +2124,146 @@ void Socket::UDPConnection::SendNow(const char *sdata){ /// Does not do anything if len < 1. /// Prints an DLVL_FAIL level debug message if sending failed. void Socket::UDPConnection::SendNow(const char *sdata, size_t len){ - if (len < 1){return;} - int r = sendto(sock, sdata, len, 0, (sockaddr *)destAddr, destAddr_size); - if (r > 0){ - up += r; + SendNow(sdata, len, (sockaddr*)destAddr, destAddr_size); +} + +/// Sends a UDP datagram using the buffer sdata of length len. +/// Does not do anything if len < 1. +/// Prints an DLVL_FAIL level debug message if sending failed. +void Socket::UDPConnection::SendNow(const char *sdata, size_t len, sockaddr * dAddr, size_t dAddrLen){ + if (len < 1 || sock == -1){return;} + if (isConnected){ + int r = send(sock, sdata, len, 0); + if (r > 0){ + up += r; + }else{ + if (errno == EDESTADDRREQ){ + close(); + return; + } + FAIL_MSG("Could not send UDP data through %d: %s", sock, strerror(errno)); + } + return; + } + if (hasReceiveData && recvAddr){ + msghdr mHdr; + char msg_control[0x100]; + iovec iovec; + iovec.iov_base = (void*)sdata; + iovec.iov_len = len; + mHdr.msg_name = dAddr; + mHdr.msg_namelen = dAddrLen; + mHdr.msg_iov = &iovec; + mHdr.msg_iovlen = 1; + mHdr.msg_control = msg_control; + mHdr.msg_controllen = sizeof(msg_control); + mHdr.msg_flags = 0; + int cmsg_space = 0; + cmsghdr * cmsg = CMSG_FIRSTHDR(&mHdr); + cmsg->cmsg_level = IPPROTO_IP; + cmsg->cmsg_type = IP_PKTINFO; + + struct in_pktinfo in_pktinfo; + memcpy(&(in_pktinfo.ipi_spec_dst), &(((sockaddr_in*)recvAddr)->sin_family), sizeof(in_pktinfo.ipi_spec_dst)); + in_pktinfo.ipi_ifindex = recvInterface; + cmsg->cmsg_len = CMSG_LEN(sizeof(in_pktinfo)); + *(struct in_pktinfo*)CMSG_DATA(cmsg) = in_pktinfo; + cmsg_space += CMSG_SPACE(sizeof(in_pktinfo)); + mHdr.msg_controllen = cmsg_space; + + int r = sendmsg(sock, &mHdr, 0); + if (r > 0){ + up += r; + }else{ + FAIL_MSG("Could not send UDP data through %d: %s", sock, strerror(errno)); + } + return; }else{ - FAIL_MSG("Could not send UDP data through %d: %s", sock, strerror(errno)); + int r = sendto(sock, sdata, len, 0, dAddr, dAddrLen); + if (r > 0){ + up += r; + }else{ + FAIL_MSG("Could not send UDP data through %d: %s", sock, strerror(errno)); + } } } /// Queues sdata, len for sending over this socket. /// If there has been enough time since the last packet, sends immediately. /// Warning: never call sendPaced for the same socket from a different thread! -void Socket::UDPConnection::sendPaced(const char *sdata, size_t len){ - if (!paceQueue.size() && (!lastPace || Util::getMicros(lastPace) > 10000)){ - SendNow(sdata, len); - lastPace = Util::getMicros(); +/// Note: Only actually encrypts if initDTLS was called in the past. +void Socket::UDPConnection::sendPaced(const char *sdata, size_t len, bool encrypt){ + if (hasDTLS && encrypt){ + if (ssl_ctx.state != MBEDTLS_SSL_HANDSHAKE_OVER){ + WARN_MSG("Attempting to write encrypted data before handshake completed! Data was thrown away."); + return; + } + int write = mbedtls_ssl_write(&ssl_ctx, (unsigned char*)sdata, len); + if (write <= 0){WARN_MSG("Could not write DTLS packet!");} }else{ - paceQueue.push_back(Util::ResizeablePointer()); - paceQueue.back().assign(sdata, len); - // Try to send a packet, if time allows - //sendPaced(0); + if (!paceQueue.size() && (!lastPace || Util::getMicros(lastPace) > 10000)){ + SendNow(sdata, len); + lastPace = Util::getMicros(); + }else{ + paceQueue.push_back(Util::ResizeablePointer()); + paceQueue.back().assign(sdata, len); + // Try to send a packet, if time allows + //sendPaced(0); + } } } +// Gets time in microseconds until next sendPaced call would send something +size_t Socket::UDPConnection::timeToNextPace(uint64_t uTime){ + size_t qSize = paceQueue.size(); + if (!qSize){return std::string::npos;} // No queue? No time. Return highest possible value. + if (!uTime){uTime = Util::getMicros();} + uint64_t paceWait = uTime - lastPace; // Time we've waited so far already + + // Target clearing the queue in 25ms at most. + uint64_t targetTime = 25000 / qSize; + // If this slows us to below 1 packet per 5ms, go that speed instead. + if (targetTime > 5000){targetTime = 5000;} + // If the wait is over, send now. + if (paceWait >= targetTime){return 0;} + // Return remaining wait time + return targetTime - paceWait; +} + /// Spends uSendWindow microseconds either sending paced packets or sleeping, whichever is more appropriate /// Warning: never call sendPaced for the same socket from a different thread! void Socket::UDPConnection::sendPaced(uint64_t uSendWindow){ uint64_t currPace = Util::getMicros(); + uint64_t uTime = currPace; do{ - uint64_t uTime = Util::getMicros(); - uint64_t sleepTime = uTime - currPace; - if (sleepTime > uSendWindow){ - sleepTime = 0; - }else{ - sleepTime = uSendWindow - sleepTime; - } - uint64_t paceWait = uTime - lastPace; - size_t qSize = paceQueue.size(); - // If the queue is complete, wait out the remainder of the time - if (!qSize){ - Util::usleep(sleepTime); - return; - } - // Otherwise, target clearing the queue in 25ms at most. - uint64_t targetTime = 25000 / qSize; - // If this slows us to below 1 packet per 5ms, go that speed instead. - if (targetTime > 5000){targetTime = 5000;} - // If the wait is over, send now. - if (paceWait >= targetTime){ + uint64_t sleepTime = uSendWindow - (uTime - currPace); + uint64_t nextPace = timeToNextPace(uTime); + if (sleepTime > nextPace){sleepTime = nextPace;} + + // Not sleeping? Send now! + if (!sleepTime){ SendNow(*paceQueue.begin(), paceQueue.begin()->size()); paceQueue.pop_front(); lastPace = uTime; continue; } - // Otherwise, wait for the smaller of remaining wait time or remaining send window time. - if (targetTime - paceWait < sleepTime){sleepTime = targetTime - paceWait;} - Util::usleep(sleepTime); - }while(Util::getMicros(currPace) < uSendWindow); + + { + // Use select to wait until a packet arrives or until the next packet should be sent + fd_set rfds; + struct timeval T; + T.tv_sec = sleepTime / 1000000; + T.tv_usec = sleepTime % 1000000; + // Watch configured FD's for input + FD_ZERO(&rfds); + int maxFD = getSock(); + FD_SET(maxFD, &rfds); + int r = select(maxFD + 1, &rfds, NULL, NULL, &T); + // If we can read the socket, immediately return and stop waiting + if (r > 0){return;} + } + uTime = Util::getMicros(); + }while(uTime - currPace < uSendWindow); } std::string Socket::UDPConnection::getBoundAddress(){ @@ -1995,6 +2316,11 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, const std::str for (rp = addr_result; rp != NULL; rp = rp->ai_next){ sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); if (sock == -1){continue;} + { + // Allow address re-use + int on = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); + } if (rp->ai_family == AF_INET6){ const int optval = 0; if (setsockopt(sock, SOL_SOCKET, IPV6_V6ONLY, &optval, sizeof(optval)) < 0){ @@ -2046,7 +2372,7 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, const std::str boundAddr = iface; boundMulti = multicastInterfaces; boundPort = portNo; - INFO_MSG("UDP bind success on %s:%u (%s)", human_addr, portNo, addrFam(rp->ai_family)); + INFO_MSG("UDP bind success %d on %s:%u (%s)", sock, human_addr, portNo, addrFam(rp->ai_family)); break; } if (err_str.size()){err_str += ", ";} @@ -2144,21 +2470,135 @@ uint16_t Socket::UDPConnection::bind(int port, std::string iface, const std::str return portNo; } +bool Socket::UDPConnection::connect(){ + if (!recvAddr || !recvAddr_size || !destAddr || !destAddr_size){ + WARN_MSG("Attempting to connect a UDP socket without local and/or remote address!"); + return false; + } + + { + std::string destIp; + uint32_t port = 0; + char addr_str[INET6_ADDRSTRLEN + 1]; + if (((struct sockaddr_in *)recvAddr)->sin_family == AF_INET6){ + if (inet_ntop(AF_INET6, &(((struct sockaddr_in6 *)recvAddr)->sin6_addr), addr_str, INET6_ADDRSTRLEN) != 0){ + destIp = addr_str; + port = ntohs(((struct sockaddr_in6 *)recvAddr)->sin6_port); + } + } + if (((struct sockaddr_in *)recvAddr)->sin_family == AF_INET){ + if (inet_ntop(AF_INET, &(((struct sockaddr_in *)recvAddr)->sin_addr), addr_str, INET6_ADDRSTRLEN) != 0){ + destIp = addr_str; + port = ntohs(((struct sockaddr_in *)recvAddr)->sin_port); + } + } + int ret = ::bind(sock, (const struct sockaddr*)recvAddr, recvAddr_size); + if (!ret){ + INFO_MSG("Bound socket %d to %s:%" PRIu32, sock, destIp.c_str(), port); + }else{ + FAIL_MSG("Failed to bind socket %d (%s) %s:%" PRIu32 ": %s", sock, addrFam(((struct sockaddr_in *)recvAddr)->sin_family), destIp.c_str(), port, strerror(errno)); + std::ofstream bleh("/tmp/socket_recv"); + bleh.write((const char*)recvAddr, recvAddr_size); + bleh.write((const char*)destAddr, destAddr_size); + bleh.close(); + return false; + } + } + + { + std::string destIp; + uint32_t port; + char addr_str[INET6_ADDRSTRLEN + 1]; + if (((struct sockaddr_in *)destAddr)->sin_family == AF_INET6){ + if (inet_ntop(AF_INET6, &(((struct sockaddr_in6 *)destAddr)->sin6_addr), addr_str, INET6_ADDRSTRLEN) != 0){ + destIp = addr_str; + port = ntohs(((struct sockaddr_in6 *)destAddr)->sin6_port); + } + } + if (((struct sockaddr_in *)destAddr)->sin_family == AF_INET){ + if (inet_ntop(AF_INET, &(((struct sockaddr_in *)destAddr)->sin_addr), addr_str, INET6_ADDRSTRLEN) != 0){ + destIp = addr_str; + port = ntohs(((struct sockaddr_in *)destAddr)->sin_port); + } + } + int ret = ::connect(sock, (const struct sockaddr*)destAddr, destAddr_size); + if (!ret){ + INFO_MSG("Connected socket to %s:%" PRIu32, destIp.c_str(), port); + }else{ + FAIL_MSG("Failed to connect socket to %s:%" PRIu32 ": %s", destIp.c_str(), port, strerror(errno)); + return false; + } + } + isConnected = true; + return true; +} + + /// Attempt to receive a UDP packet. /// This will automatically allocate or resize the internal data buffer if needed. /// If a packet is received, it will be placed in the "data" member, with it's length in "data_len". /// \return True if a packet was received, false otherwise. bool Socket::UDPConnection::Receive(){ + if (pretendReceive){ + pretendReceive = false; + return onData(); + } if (sock == -1){return false;} data.truncate(0); + if (isConnected){ + int r = recv(sock, data, data.rsize(), MSG_TRUNC | MSG_DONTWAIT); + if (r == -1){ + if (errno != EAGAIN){ + INFO_MSG("UDP receive: %d (%s)", errno, strerror(errno)); + if (errno == ECONNREFUSED){close();} + } + return false; + } + if (r > 0){ + data.append(0, r); + down += r; + if (data.rsize() < (unsigned int)r){ + INFO_MSG("Doubling UDP socket buffer from %" PRIu32 " to %" PRIu32, data.rsize(), data.rsize()*2); + data.allocate(data.rsize()*2); + } + return onData(); + } + return false; + } sockaddr_in6 addr; socklen_t destsize = sizeof(addr); - int r = recvfrom(sock, data, data.rsize(), MSG_TRUNC | MSG_DONTWAIT, (sockaddr *)&addr, &destsize); + //int r = recvfrom(sock, data, data.rsize(), MSG_TRUNC | MSG_DONTWAIT, (sockaddr *)&addr, &destsize); + msghdr mHdr; + memset(&mHdr, 0, sizeof(mHdr)); + char ctrl[0x100]; + iovec dBufs; + dBufs.iov_base = data; + dBufs.iov_len = data.rsize(); + mHdr.msg_name = &addr; + mHdr.msg_namelen = destsize; + mHdr.msg_control = ctrl; + mHdr.msg_controllen = 0x100; + mHdr.msg_iov = &dBufs; + mHdr.msg_iovlen = 1; + int r = recvmsg(sock, &mHdr, MSG_TRUNC | MSG_DONTWAIT); + destsize = mHdr.msg_namelen; if (r == -1){ if (errno != EAGAIN){INFO_MSG("UDP receive: %d (%s)", errno, strerror(errno));} return false; } if (destAddr && destsize && destAddr_size >= destsize){memcpy(destAddr, &addr, destsize);} + if (recvAddr){ + for ( struct cmsghdr *cmsg = CMSG_FIRSTHDR(&mHdr); cmsg != NULL; cmsg = CMSG_NXTHDR(&mHdr, cmsg)){ + if (cmsg->cmsg_level != IPPROTO_IP || cmsg->cmsg_type != IP_PKTINFO){continue;} + struct in_pktinfo* pi = (in_pktinfo*)CMSG_DATA(cmsg); + struct sockaddr_in * recvCast = (sockaddr_in*)recvAddr; + recvCast->sin_family = family; + recvCast->sin_port = htons(boundPort); + memcpy(&(recvCast->sin_addr), &(pi->ipi_spec_dst), sizeof(pi->ipi_spec_dst)); + recvInterface = pi->ipi_ifindex; + hasReceiveData = true; + } + } data.append(0, r); down += r; //Handle UDP packets that are too large @@ -2166,7 +2606,88 @@ bool Socket::UDPConnection::Receive(){ INFO_MSG("Doubling UDP socket buffer from %" PRIu32 " to %" PRIu32, data.rsize(), data.rsize()*2); data.allocate(data.rsize()*2); } - return (r > 0); + return onData(); +} + +bool Socket::UDPConnection::onData(){ + wasEncrypted = false; + if (!data.size()){return false;} + uint8_t fb = 0; + int r = data.size(); + if (r){fb = (uint8_t)data[0];} + if (r && hasDTLS && fb > 19 && fb < 64){ + if (nextDTLSReadLen){ + INFO_MSG("Overwriting %zu bytes of unread dTLS data!", nextDTLSReadLen); + } + nextDTLSRead = data; + nextDTLSReadLen = data.size(); + // Complete dTLS handshake if needed + if (ssl_ctx.state != MBEDTLS_SSL_HANDSHAKE_OVER){ + do{ + r = mbedtls_ssl_handshake(&ssl_ctx); + switch (r){ + case 0:{ // Handshake complete + INFO_MSG("dTLS handshake complete!"); + int extrRes = 0; + uint8_t keying_material[MBEDTLS_DTLS_SRTP_MAX_KEY_MATERIAL_LENGTH]; + size_t keying_material_len = sizeof(keying_material); + extrRes = mbedtls_ssl_get_dtls_srtp_key_material(&ssl_ctx, keying_material, &keying_material_len); + if (extrRes){ + char mbedtls_msg[1024]; + mbedtls_strerror(extrRes, mbedtls_msg, sizeof(mbedtls_msg)); + WARN_MSG("dTLS could not extract keying material: %s", mbedtls_msg); + return Receive(); + } + + mbedtls_ssl_srtp_profile srtp_profile = mbedtls_ssl_get_dtls_srtp_protection_profile(&ssl_ctx); + switch (srtp_profile){ + case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80:{ + cipher = "SRTP_AES128_CM_SHA1_80"; + break; + } + case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32:{ + cipher = "SRTP_AES128_CM_SHA1_32"; + break; + } + default:{ + WARN_MSG("Unhandled SRTP profile, cannot extract keying material."); + return Receive(); + } + } + remote_key.assign((char *)(&keying_material[0]) + 0, 16); + local_key.assign((char *)(&keying_material[0]) + 16, 16); + remote_salt.assign((char *)(&keying_material[0]) + 32, 14); + local_salt.assign((char *)(&keying_material[0]) + 46, 14); + return Receive(); // No application-level data to read + } + case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:{ + dTLSReset(); + return Receive(); // No application-level data to read + } + case MBEDTLS_ERR_SSL_WANT_READ:{ + return Receive(); // No application-level data to read + } + default:{ + char mbedtls_msg[1024]; + mbedtls_strerror(r, mbedtls_msg, sizeof(mbedtls_msg)); + WARN_MSG("dTLS could not handshake: %s", mbedtls_msg); + return Receive(); // No application-level data to read + } + } + }while (r == MBEDTLS_ERR_SSL_WANT_WRITE); + }else{ + int read = mbedtls_ssl_read(&ssl_ctx, (unsigned char *)(char*)data, data.size()); + if (read <= 0){ + // Non-encrypted read (encrypted read fail) + return true; + } + // Encrypted read success + wasEncrypted = true; + data.truncate(read); + return true; + } + } + return r > 0; } int Socket::UDPConnection::getSock(){ diff --git a/lib/socket.h b/lib/socket.h index 8edb852e..f73f93cf 100644 --- a/lib/socket.h +++ b/lib/socket.h @@ -18,12 +18,14 @@ #include "util.h" #ifdef SSL -#include "mbedtls/ctr_drbg.h" -#include "mbedtls/debug.h" -#include "mbedtls/entropy.h" -#include "mbedtls/error.h" -#include "mbedtls/net.h" -#include "mbedtls/ssl.h" +#include +#include +#include +#include +#include +#include +#include +#include #endif #include "util.h" @@ -196,10 +198,13 @@ namespace Socket{ class UDPConnection{ private: + void init(bool nonblock, int family = AF_INET6); int sock; ///< Internally saved socket number. std::string remotehost; ///< Stores remote host address void *destAddr; ///< Destination address pointer. unsigned int destAddr_size; ///< Size of the destination address pointer. + void *recvAddr; ///< Destination address pointer. + unsigned int recvAddr_size; ///< Size of the destination address pointer. unsigned int up; ///< Amount of bytes transferred up. unsigned int down; ///< Amount of bytes transferred down. int family; ///< Current socket address family @@ -208,19 +213,46 @@ namespace Socket{ void checkRecvBuf(); std::deque paceQueue; uint64_t lastPace; + int recvInterface; + bool hasReceiveData; + bool isBlocking; + bool isConnected; + bool pretendReceive; ///< If true, will pretend to have just received the current data buffer on new Receive() call + bool onData(); + + // dTLS-related members + bool hasDTLS; ///< True if dTLS is enabled + void * nextDTLSRead; + size_t nextDTLSReadLen; + mbedtls_entropy_context entropy_ctx; + mbedtls_ctr_drbg_context rand_ctx; + mbedtls_ssl_context ssl_ctx; + mbedtls_ssl_config ssl_conf; + mbedtls_ssl_cookie_ctx cookie_ctx; + mbedtls_timing_delay_context timer_ctx; public: Util::ResizeablePointer data; UDPConnection(const UDPConnection &o); UDPConnection(bool nonblock = false); ~UDPConnection(); + bool operator==(const UDPConnection& b) const; + operator bool() const; + void initDTLS(mbedtls_x509_crt *cert, mbedtls_pk_context *key); + void deinitDTLS(); + int dTLSRead(unsigned char *buf, size_t len); + int dTLSWrite(const unsigned char *buf, size_t len); + void dTLSReset(); + bool wasEncrypted; void close(); int getSock(); uint16_t bind(int port, std::string iface = "", const std::string &multicastAddress = ""); + bool connect(); void setBlocking(bool blocking); void allocateDestination(); void SetDestination(std::string hostname, uint32_t port); void GetDestination(std::string &hostname, uint32_t &port); + void GetLocalDestination(std::string &hostname, uint32_t &port); std::string getBinDestination(); const void * getDestAddr(){return destAddr;} size_t getDestAddrLen(){return destAddr_size;} @@ -230,8 +262,13 @@ namespace Socket{ void SendNow(const std::string &data); void SendNow(const char *data); void SendNow(const char *data, size_t len); - void sendPaced(const char * data, size_t len); + void SendNow(const char *sdata, size_t len, sockaddr * dAddr, size_t dAddrLen); + void sendPaced(const char * data, size_t len, bool encrypt = true); void sendPaced(uint64_t uSendWindow); + size_t timeToNextPace(uint64_t uTime = 0); void setSocketFamily(int AF_TYPE); + + // dTLS-related public members + std::string cipher, remote_key, local_key, remote_salt, local_salt; }; }// namespace Socket diff --git a/meson.build b/meson.build index b120ec67..03ff3690 100644 --- a/meson.build +++ b/meson.build @@ -129,6 +129,8 @@ if usessl mist_deps += [mbedtls, mbedx509, mbedcrypto] mist_deps += dependency('libsrtp2', default_options: ['tests=disabled'], fallback: ['libsrtp2', 'libsrtp2_dep']) + + usrsctp_dep = dependency('usrsctp', fallback: ['usrsctp', 'usrsctp_dep']) endif libsrt = false diff --git a/src/output/meson.build b/src/output/meson.build index fe781353..36d55599 100644 --- a/src/output/meson.build +++ b/src/output/meson.build @@ -91,6 +91,7 @@ foreach output : outputs endif if extra.contains('srtp') sources += files('output_webrtc_srtp.cpp', 'output_webrtc_srtp.h') + deps += usrsctp_dep endif if extra.contains('embed') sources += embed_tgts diff --git a/src/output/output.cpp b/src/output/output.cpp index 22daf69d..76772af3 100644 --- a/src/output/output.cpp +++ b/src/output/output.cpp @@ -735,18 +735,22 @@ namespace Mist{ std::set validTracks = M.getValidTracks(); if (!validTracks.size()){return 0;} uint64_t start = 0xFFFFFFFFFFFFFFFFull; + uint64_t nonMetaStart = 0xFFFFFFFFFFFFFFFFull; if (userSelect.size()){ for (std::map::iterator it = userSelect.begin(); it != userSelect.end(); it++){ if (M.trackValid(it->first) && start > M.getFirstms(it->first)){ start = M.getFirstms(it->first); } + if (M.trackValid(it->first) && M.getType(it->first) != "meta" && nonMetaStart > M.getFirstms(it->first)){ + nonMetaStart = M.getFirstms(it->first); + } } }else{ for (std::set::iterator it = validTracks.begin(); it != validTracks.end(); it++){ if (start > M.getFirstms(*it)){start = M.getFirstms(*it);} } } - return start; + return nonMetaStart != 0xFFFFFFFFFFFFFFFFull ? nonMetaStart: start; } /// Return the end time of the selected tracks, or 0 if unknown or live. diff --git a/src/output/output_webrtc.cpp b/src/output/output_webrtc.cpp index 70a0004c..4f38f27b 100644 --- a/src/output/output_webrtc.cpp +++ b/src/output/output_webrtc.cpp @@ -8,6 +8,7 @@ #include #include // ifaddr, listing ip addresses. #include +#include /* This file handles both input and output, and can operate in WHIP/WHEP as well as WebSocket signaling mode. In case of WHIP/WHEP: the Socket is closed after signaling has happened and the keepGoing function @@ -18,13 +19,15 @@ When handling WebRTC Input a second thread is started dedicated to UDP traffic. namespace Mist{ + bool doDTLS = true; + bool volkswagenMode = false; + OutWebRTC *classPointer = 0; /* ------------------------------------------------ */ static uint32_t generateSSRC(); static void webRTCInputOutputThreadFunc(void *arg); - static int onDTLSHandshakeWantsToWriteCallback(const uint8_t *data, int *nbytes); static void onDTSCConverterHasPacketCallback(const DTSC::Packet &pkt); static void onDTSCConverterHasInitDataCallback(const uint64_t track, const std::string &initData); static void onRTPSorterHasPacketCallback(const uint64_t track, @@ -33,6 +36,77 @@ namespace Mist{ static void onRTPPacketizerHasDataCallback(void *socket, const char *data, size_t len, uint8_t channel); static void onRTPPacketizerHasRTCPDataCallback(void *socket, const char *data, size_t nbytes, uint8_t channel); + static int sctp_recv_cb(struct socket *s, union sctp_sockstore addr, void *data, size_t datalen, struct sctp_rcvinfo rcv, int flags, void *ulp_info){ + if (data) { + if (!(flags & MSG_NOTIFICATION)){ + ((OutWebRTC*)(addr.sconn.sconn_addr))->onSCTP((const char*)data, datalen, ntohs(rcv.rcv_sid), ntohl(rcv.rcv_ppid)); + } + free(data); + }else{ + usrsctp_close(s); + } + return 1; + } + + int sctp_send_cb(void *addr, void *buf, size_t length, uint8_t tos, uint8_t set_df){ + ((OutWebRTC*)addr)->sendSCTPPacket((const char*)buf, length); + return 0; + } + + void sctp_debug_cb(const char * format, ...){ + char msg[1024]; + va_list args; + va_start(args, format); + vsnprintf(msg, 1024, format, args); + va_end(args); + INFO_MSG("sctp: %s", msg); + } + + WebRTCSocket::WebRTCSocket(){ + udpSock = 0; + if (volkswagenMode){ + srtpWriter.init("SRTP_AES128_CM_SHA1_80", "volkswagen modus", "volkswagenmode"); + } + } + + size_t WebRTCSocket::sendRTCP(const char * data, size_t len){ + if (doDTLS){ + dataBuffer.allocate(len + 256); + dataBuffer.assign(data, len); + int rtcpPacketSize = len; + if (srtpWriter.protectRtcp((uint8_t *)(void *)dataBuffer, &rtcpPacketSize) != 0){ + ERROR_MSG("Failed to protect the RTCP message."); + return 0; + } + udpSock->sendPaced(dataBuffer, rtcpPacketSize, false); + return rtcpPacketSize; + } + + udpSock->sendPaced(data, len, false); + + if (volkswagenMode){ + dataBuffer.allocate(len + 256); + dataBuffer.assign(data, len); + int rtcpPacketSize = len; + srtpWriter.protectRtcp((uint8_t *)(void *)dataBuffer, &rtcpPacketSize); + } + return len; + } + + size_t WebRTCSocket:: ackNACK(uint32_t pSSRC, uint16_t seq){ + if (!outBuffers.count(pSSRC)){ + WARN_MSG("Could not answer NACK for %" PRIu32 ": we don't know this track", pSSRC); + return 0; + } + nackBuffer &nb = outBuffers[pSSRC]; + if (!nb.isBuffered(seq)){ + HIGH_MSG("Could not answer NACK for %" PRIu32 " #%" PRIu16 ": packet not buffered", pSSRC, seq); + return 0; + } + udpSock->sendPaced(nb.getData(seq), nb.getSize(seq), false); + return nb.getSize(seq); + } + /* ------------------------------------------------ */ WebRTCTrack::WebRTCTrack(){ @@ -58,6 +132,8 @@ namespace Mist{ /* ------------------------------------------------ */ OutWebRTC::OutWebRTC(Socket::Connection &myConn) : HTTPOutput(myConn){ + sctpInited = false; + sctpConnected = false; noSignalling = false; totalPkts = 0; totalLoss = 0; @@ -73,6 +149,7 @@ namespace Mist{ vidTrack = INVALID_TRACK_ID; prevVidTrack = INVALID_TRACK_ID; audTrack = INVALID_TRACK_ID; + metaTrack = INVALID_TRACK_ID; firstKey = true; repeatInit = true; wsCmds = true; @@ -90,9 +167,9 @@ namespace Mist{ videoConstraint = videoBitrate; RTP::MAX_SEND = 1350 - 28; didReceiveKeyFrame = false; - doDTLS = true; - volkswagenMode = false; syncedNTPClock = false; + lastMediaSocket = 0; + lastMetaSocket = 0; JSON::Value & certOpt = config->getOption("cert", true); @@ -137,10 +214,6 @@ namespace Mist{ } } - if (dtlsHandshake.init(&cert.cert, &cert.key, onDTLSHandshakeWantsToWriteCallback) != 0){ - onFail("Failed to initialize the dtls-srtp handshake helper.", true); - return; - } sdpAnswer.setFingerprint(cert.getFingerprintSha256()); classPointer = this; @@ -155,12 +228,6 @@ namespace Mist{ delete webRTCInputOutputThread; webRTCInputOutputThread = NULL; } - - if (srtpReader.shutdown() != 0){FAIL_MSG("Failed to cleanly shutdown the srtp reader.");} - if (srtpWriter.shutdown() != 0){FAIL_MSG("Failed to cleanly shutdown the srtp writer.");} - if (dtlsHandshake.shutdown() != 0){ - FAIL_MSG("Failed to cleanly shutdown the dtls handshake."); - } } // Initialize the WebRTC output. This is where we define what @@ -183,6 +250,8 @@ namespace Mist{ capa["codecs"][0u][1u].append("opus"); capa["codecs"][0u][1u].append("ALAW"); capa["codecs"][0u][1u].append("ULAW"); + capa["codecs"][0u][2u].append("JSON"); + capa["codecs"][0u][2u].append("subtitle"); capa["methods"][0u]["handler"] = "ws"; capa["methods"][0u]["type"] = "webrtc"; capa["methods"][0u]["hrn"] = "WebRTC with WebSocket signalling"; @@ -329,7 +398,7 @@ namespace Mist{ void OutWebRTC::requestHandler(){ if (noSignalling){ // For WHEP, make sure we keep listening for packets while waiting for new data to come in for sending - if (parseData && !handleWebRTCInputOutput()){udp.sendPaced(10);} + if (parseData && !handleWebRTCInputOutput()){sendPaced(10);} //After 10s of no packets, abort if (Util::bootMS() > lastRecv + 10000){ Util::logExitReason(ER_CLEAN_INACTIVE, "received no data for 10+ seconds"); @@ -341,16 +410,25 @@ namespace Mist{ } void OutWebRTC::respondHTTP(const HTTP::Parser & req, bool headersOnly){ + // Generic header/parameter handling + HTTPOutput::respondHTTP(req, headersOnly); + INFO_MSG("HTTP: %s", req.method.c_str()); // Check for WHIP payload - if (req.method == "OPTIONS"){ + if (headersOnly){ H.setCORSHeaders(); H.StartResponse("200", "All good", req, myConn); H.Chunkify(0, 0, myConn); + return; } if (req.method == "POST"){ + INFO_MSG("POST"); if (req.GetHeader("Content-Type") == "application/sdp"){ SDP::Session sdpParser; const std::string &offerStr = req.body; + if (config && config->hasOption("packetlog") && config->getBool("packetlog")){ + std::string fileName = "/tmp/wrtcpackets_"+JSON::Value(getpid()).asString(); + packetLog.open(fileName.c_str()); + } if (packetLog.is_open()){ packetLog << "[" << Util::bootMS() << "]" << offerStr << std::endl << std::endl; } @@ -488,7 +566,6 @@ namespace Mist{ INFO_MSG("Disabling encryption"); }else if(command["encrypt"].asString() == "placebo" || command["encrypt"].asString() == "volkswagen"){ INFO_MSG("Entering volkswagen mode: encrypt data, but send plaintext for easier analysis"); - srtpWriter.init("SRTP_AES128_CM_SHA1_80", "volkswagen modus", "volkswagenmode"); volkswagenMode = true; }else{ doDTLS = true; @@ -650,6 +727,7 @@ namespace Mist{ std::string videoCodec; std::string audioCodec; + std::string metaCodec; capa["codecs"][0u][0u].null(); capa["codecs"][0u][1u].null(); @@ -664,6 +742,11 @@ namespace Mist{ audioCodec = M.getCodec(it->first); capa["codecs"][0u][1u].append(audioCodec); } + if (M.getType(it->first) == "meta"){ + metaTrack = it->first; + metaCodec = M.getCodec(it->first); + capa["codecs"][0u][2u].append(std::string("+") + metaCodec); + } } sdpAnswer.setDirection("sendonly"); @@ -699,8 +782,17 @@ namespace Mist{ } } - // this is necessary so that we can get the remote IP when creating STUN replies. - udp.allocateDestination(); + // setup meta WebRTC Track + if (metaTrack != INVALID_TRACK_ID){ + if (sdpAnswer.enableMeta(M.getCodec(metaTrack))){ + WebRTCTrack &mTrack = webrtcTracks[metaTrack]; + if (!createWebRTCTrackFromAnswer(sdpAnswer.answerMetaMedia, sdpAnswer.answerMetaFormat, mTrack)){ + FAIL_MSG("Failed to create the WebRTCTrack for the selected metadata."); + webrtcTracks.erase(metaTrack); + return false; + } + } + } // we set parseData to `true` to start the data flow. Is also // used to break out of our loop in `onHTTP()`. @@ -728,19 +820,6 @@ namespace Mist{ HTTPOutput::handleWebsocketIdle(); } - bool OutWebRTC::onFinish(){ - if (parseData){ - JSON::Value commandResult; - commandResult["type"] = "on_stop"; - commandResult["current"] = currentTime(); - commandResult["begin"] = startTime(); - commandResult["end"] = endTime(); - webSock->sendFrame(commandResult.toString()); - parseData = false; - } - return true; - } - // Creates a WebRTCTrack for the given `SDP::Media` and // `SDP::MediaFormat`. The `SDP::MediaFormat` must contain the // `icePwd` and `iceUFrag` which are used when we're handling @@ -753,8 +832,8 @@ namespace Mist{ // peer expect us to send data. bool OutWebRTC::createWebRTCTrackFromAnswer(const SDP::Media &mediaAnswer, const SDP::MediaFormat &formatAnswer, WebRTCTrack &result){ - if (formatAnswer.payloadType == SDP_PAYLOAD_TYPE_NONE){ - FAIL_MSG("Cannot create a WebRTCTrack, the given SDP::MediaFormat has no `payloadType` set."); + if (formatAnswer.payloadType == SDP_PAYLOAD_TYPE_NONE && formatAnswer.encodingName != "WEBRTC-DATACHANNEL"){ + FAIL_MSG("Cannot create a WebRTCTrack, the given %s SDP::MediaFormat has no `payloadType` set.", formatAnswer.encodingName.c_str()); return false; } @@ -818,6 +897,19 @@ namespace Mist{ audioCodec++; } } + + const char *metaCodecPreference[] ={"JSON", "subtitle", NULL}; + const char **metaCodec = metaCodecPreference; + SDP::Media *metaMediaOffer = sdpSession.getMediaForType("meta"); + if (metaMediaOffer){ + INFO_MSG("Has meta offer!"); + while (*metaCodec){ + if (sdpSession.getMediaFormatByEncodingName("meta", *metaCodec)){ + capa["codecs"][0u][2u].append(std::string("+") + *metaCodec); + } + metaCodec++; + } + } } // This function is called to handle an offer from a peer that wants to push data towards us. @@ -960,7 +1052,7 @@ namespace Mist{ //If a bind host has been put in as override, use it if (config && config->hasOption("bindhost") && config->getString("bindhost").size()){ bindAddr = config->getString("bindhost"); - udpPort = udp.bind(port, bindAddr); + udpPort = mainSocket.bind(port, bindAddr); if (!udpPort){ WARN_MSG("UDP bind address not valid - ignoring setting and using best guess instead"); bindAddr.clear(); @@ -972,12 +1064,12 @@ namespace Mist{ if (!bindAddr.size()){ bindAddr = Socket::resolveHostToBestExternalAddrGuess(externalAddr, AF_INET, myConn.getBoundAddress()); if (!bindAddr.size()){ - WARN_MSG("UDP bind to best guess failed - using same address as incoming connection as a last resort"); + INFO_MSG("UDP bind to best guess failed - using same address as incoming connection as a last resort"); bindAddr.clear(); }else{ - udpPort = udp.bind(port, bindAddr); + udpPort = mainSocket.bind(port, bindAddr); if (!udpPort){ - WARN_MSG("UDP bind to best guess failed - using same address as incoming connection as a last resort"); + INFO_MSG("UDP bind to best guess failed - using same address as incoming connection as a last resort"); bindAddr.clear(); }else{ INFO_MSG("Bound to public UDP bind address derived from hostname"); @@ -986,7 +1078,7 @@ namespace Mist{ } if (!bindAddr.size()){ bindAddr = myConn.getBoundAddress(); - udpPort = udp.bind(port, bindAddr); + udpPort = mainSocket.bind(port, bindAddr); if (!udpPort){ FAIL_MSG("UDP bind to connected address failed - we're out of options here, I'm afraid..."); bindAddr.clear(); @@ -995,7 +1087,11 @@ namespace Mist{ } } - Util::Procs::socketList.insert(udp.getSock()); + Util::Procs::socketList.insert(mainSocket.getSock()); + + // this is necessary so that we can get the remote IP when creating STUN replies. + mainSocket.allocateDestination(); + if (config && config->hasOption("pubhost") && config->getString("pubhost").size()){ bindAddr = config->getString("pubhost"); } @@ -1009,9 +1105,8 @@ namespace Mist{ // function. The `webRTCInputOutputThreadFunc()` is basically empty // and all work for the thread is done here. void OutWebRTC::handleWebRTCInputOutputFromThread(){ - udp.allocateDestination(); while (keepGoing()){ - if (!handleWebRTCInputOutput()){udp.sendPaced(10);} + if (!handleWebRTCInputOutput()){sendPaced(10);} } } @@ -1024,6 +1119,111 @@ namespace Mist{ statComm.setTime(now - myConn.connTime()); } + bool OutWebRTC::handleUDPSocket(Socket::UDPConnection & sock){ + bool hadPack = false; + while(sock.Receive()){ + std::string remoteIP, localIP; + uint32_t remotePort, localPort; + sock.GetDestination(remoteIP, remotePort); + sock.GetLocalDestination(localIP, localPort); + + // Check if we already have a socket handling this exact connection, if so, don't create a new one + bool existsAlready = false; + for (std::map::iterator it = sockets.begin(); it != sockets.end(); it++){ + if (!*(it->second.udpSock)){ + int sockNo = it->first; + sockets.erase(sockNo); + rtpSockets.erase(sockNo); + sctpSockets.erase(sockNo); + it = sockets.begin(); + if (!sockets.size()){break;} + } + if (*(it->second.udpSock) == sock){ + existsAlready = true; + INFO_MSG("Duplicate socket, not spawning another, inserting packet instead"); + break; + } + } + if (existsAlready){continue;} + + // No existing socket? Create a new one specifically for this exact connection + Socket::UDPConnection * s = new Socket::UDPConnection(sock); + if (s->connect()){ + s->initDTLS(&(cert.cert), &(cert.key)); + sockets[s->getSock()].udpSock = s; + Util::Procs::socketList.insert(s->getSock()); + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Connected new socket " << s->getSock() << " for: " << localIP << ":" << localPort << " <-> " << remoteIP << ":" << remotePort << std::endl;} + }else{ + delete s; + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Failed to connect new socket for: " << remoteIP << ":" << remotePort << std::endl;} + } + hadPack = true; + } + return hadPack; + } + + bool OutWebRTC::handleUDPSocket(WebRTCSocket & wSock){ + bool hadPack = false; + while(wSock.udpSock->Receive()){ + hadPack = true; + myConn.addDown(wSock.udpSock->data.size()); + + if (wSock.udpSock->data.size() && wSock.udpSock->wasEncrypted){ + lastRecv = Util::bootMS(); + if (packetLog.is_open()){ + packetLog << "[" << Util::bootMS() << "]" << "SCTP packet (" << wSock.udpSock->data.size() << "b): " << std::endl; + char * buffer = usrsctp_dumppacket(wSock.udpSock->data, wSock.udpSock->data.size(), SCTP_DUMP_INBOUND); + packetLog << buffer; + usrsctp_freedumpbuffer(buffer); + } + if (!sctpSockets.count(wSock.udpSock->getSock())){ + int s = wSock.udpSock->getSock(); + rtpSockets.erase(s); + sctpSockets.insert(s); + } + if (!sctpInited){ + INFO_MSG("Initializing SCTP library"); + usrsctp_init(0, sctp_send_cb, sctp_debug_cb); + usrsctp_register_address((void *)this); + sctp_sock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, sctp_recv_cb, NULL, 0, NULL); + struct sockaddr_conn sconn; + memset(&sconn, 0, sizeof(struct sockaddr_conn)); + sconn.sconn_family = AF_CONN; +#ifdef HAVE_SCONN_LEN + sconn.sconn_len = sizeof(struct sockaddr_conn); +#endif + sconn.sconn_port = htons(5000); + sconn.sconn_addr = (void *)this; + usrsctp_bind(sctp_sock, (struct sockaddr *)&sconn, sizeof(struct sockaddr_conn)); + usrsctp_listen(sctp_sock, 1); + sctpInited = true; + } + usrsctp_conninput(this, wSock.udpSock->data, wSock.udpSock->data.size(), 0); + //usrsctp_accept(sctp_sock, 0, 0); + continue; + } + + uint8_t fb = (uint8_t)wSock.udpSock->data[0]; + if (fb > 127 && fb < 192){ + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Packet " << (int)fb << ": RTP/RTCP" << std::endl;} + handleReceivedRTPOrRTCPPacket(wSock); + }else if (fb > 19 && fb < 64){ + if (packetLog.is_open()){ + std::string remoteIP; + uint32_t remotePort; + wSock.udpSock->GetDestination(remoteIP, remotePort); + packetLog << "[" << Util::bootMS() << "]" << "DTLS (" << remoteIP << ":" << remotePort << ") - Non-application-level data" << std::endl; + } + }else if (fb < 2){ + handleReceivedSTUNPacket(wSock); + }else{ + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Packet " << (int)fb << ": Unknown" << std::endl;} + FAIL_MSG("Unhandled WebRTC data. Type: %02X", fb); + } + } + return hadPack; + } + // Checks if there is data on our UDP socket. The data can be // STUN, DTLS, SRTP or SRTCP. When we're receiving media from // the browser (e.g. from webcam) this function is called from @@ -1032,24 +1232,30 @@ namespace Mist{ bool OutWebRTC::handleWebRTCInputOutput(){ bool hadPack = false; - while (udp.Receive()){ - hadPack = true; - myConn.addDown(udp.data.size()); + hadPack |= handleUDPSocket(mainSocket); - uint8_t fb = (uint8_t)udp.data[0]; - if (fb > 127 && fb < 192){ - if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Packet " << (int)fb << ": RTP/RTCP" << std::endl;} - handleReceivedRTPOrRTCPPacket(); - }else if (fb > 19 && fb < 64){ - if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Packet " << (int)fb << ": DTLS" << std::endl;} - handleReceivedDTLSPacket(); - }else if (fb < 2){ - if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Packet " << (int)fb << ": STUN" << std::endl;} - handleReceivedSTUNPacket(); - }else{ - if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Packet " << (int)fb << ": Unknown" << std::endl;} - FAIL_MSG("Unhandled WebRTC data. Type: %02X", fb); + for (std::map::iterator it = sockets.begin(); it != sockets.end(); it++){ + bool wasInited = it->second.udpSock->cipher.size(); + hadPack |= handleUDPSocket(it->second); + if (!wasInited && it->second.udpSock->cipher.size()){ + if (it->second.srtpReader.init(it->second.udpSock->cipher, it->second.udpSock->remote_key, it->second.udpSock->remote_salt) != 0){ + FAIL_MSG("Failed to initialize the SRTP reader."); + } + if (it->second.srtpWriter.init(it->second.udpSock->cipher, it->second.udpSock->local_key, it->second.udpSock->local_salt) != 0){ + FAIL_MSG("Failed to initialize the SRTP writer."); + } + rtpSockets.insert(it->first); + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "SRTP reader/writer " << it->first << " initialized" << std::endl;} + } + + if (!*(it->second.udpSock)){ + int sockNo = it->first; + sockets.erase(sockNo); + rtpSockets.erase(sockNo); + sctpSockets.erase(sockNo); + it = sockets.begin(); + if (!sockets.size()){break;} } } @@ -1079,20 +1285,22 @@ namespace Mist{ } } - if (udp.getSock() == -1){onFail("UDP socket closed", true);} + if (mainSocket.getSock() == -1){onFail("UDP socket closed", true);} return hadPack; } - void OutWebRTC::handleReceivedSTUNPacket(){ + void OutWebRTC::handleReceivedSTUNPacket(WebRTCSocket &wSock){ size_t nparsed = 0; StunMessage stun_msg; - if (stunReader.parse((uint8_t *)(char*)udp.data, udp.data.size(), nparsed, stun_msg) != 0){ + if (stunReader.parse((uint8_t *)(char*)wSock.udpSock->data, wSock.udpSock->data.size(), nparsed, stun_msg) != 0){ + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "STUN: (unparsable)" << std::endl;} FAIL_MSG("Failed to parse a stun message."); return; } if (stun_msg.type != STUN_MSG_TYPE_BINDING_REQUEST){ + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "STUN: (non-binding request, ignored)" << std::endl;} INFO_MSG("We only handle STUN binding requests as we're an ice-lite implementation."); return; } @@ -1131,10 +1339,10 @@ namespace Mist{ } lastRecv = Util::bootMS(); - std::string remoteIP = ""; - uint32_t remotePort = 0; - udp.GetDestination(remoteIP, remotePort); - if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "STUN: Bound to " << remoteIP << ":" << remotePort << std::endl;} + std::string remoteIP; + uint32_t remotePort; + wSock.udpSock->GetDestination(remoteIP, remotePort); + if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "STUN: " << remoteIP << ":" << remotePort << std::endl;} // create the binding success response stun_msg.removeAttributes(); @@ -1147,63 +1355,33 @@ namespace Mist{ stun_writer.writeFingerprint(); stun_writer.end(); - udp.sendPaced((const char *)stun_writer.getBufferPtr(), stun_writer.getBufferSize()); + wSock.udpSock->SendNow((const char *)stun_writer.getBufferPtr(), stun_writer.getBufferSize()); myConn.addUp(stun_writer.getBufferSize()); } - void OutWebRTC::handleReceivedDTLSPacket(){ - - if (dtlsHandshake.hasKeyingMaterial()){ - DONTEVEN_MSG("Not feeding data into the handshake .. already done."); - return; - } - - if (dtlsHandshake.parse((const uint8_t *)(const char*)udp.data, udp.data.size()) != 0){ - FAIL_MSG("Failed to parse a DTLS packet."); - return; - } - lastRecv = Util::bootMS(); - - if (!dtlsHandshake.hasKeyingMaterial()){ - if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "DTLS: No keying material (yet)" << std::endl;} - return; - } - - if (srtpReader.init(dtlsHandshake.cipher, dtlsHandshake.remote_key, dtlsHandshake.remote_salt) != 0){ - FAIL_MSG("Failed to initialize the SRTP reader."); - return; - } - - if (srtpWriter.init(dtlsHandshake.cipher, dtlsHandshake.local_key, dtlsHandshake.local_salt) != 0){ - FAIL_MSG("Failed to initialize the SRTP writer."); - return; - } - if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "DTLS: Keying material success" << std::endl;} - } - void OutWebRTC::ackNACK(uint32_t pSSRC, uint16_t seq){ - totalRetrans++; - if (!outBuffers.count(pSSRC)){ - WARN_MSG("Could not answer NACK for %" PRIu32 ": we don't know this track", pSSRC); - return; + + for (std::set::iterator it = rtpSockets.begin(); it != rtpSockets.end(); ++it){ + if (!*(sockets[*it].udpSock)){continue;} + size_t sent = sockets[*it].ackNACK(pSSRC, seq); + if (sent){ + totalRetrans++; + myConn.addUp(sent); + } } - nackBuffer &nb = outBuffers[pSSRC]; - if (!nb.isBuffered(seq)){ - HIGH_MSG("Could not answer NACK for %" PRIu32 " #%" PRIu16 ": packet not buffered", pSSRC, seq); - return; - } - udp.sendPaced(nb.getData(seq), nb.getSize(seq)); - myConn.addUp(nb.getSize(seq)); HIGH_MSG("Answered NACK for %" PRIu32 " #%" PRIu16, pSSRC, seq); } - void OutWebRTC::handleReceivedRTPOrRTCPPacket(){ + void OutWebRTC::handleReceivedRTPOrRTCPPacket(WebRTCSocket &wSock){ - uint8_t pt = udp.data[1] & 0x7F; + // Mark this socket as an (S)RTP socket, if not already marked + if (!rtpSockets.count(wSock.udpSock->getSock())){rtpSockets.insert(wSock.udpSock->getSock());} + + uint8_t pt = wSock.udpSock->data[1] & 0x7F; if ((pt < 64) || (pt >= 96)){ - RTP::Packet rtp_pkt((const char *)udp.data, (unsigned int)udp.data.size()); + RTP::Packet rtp_pkt((const char *)wSock.udpSock->data, (unsigned int)wSock.udpSock->data.size()); uint16_t currSeqNum = rtp_pkt.getSequence(); size_t idx = M.trackIDToIndex(rtp_pkt.getPayloadType(), getpid()); @@ -1225,21 +1403,21 @@ namespace Mist{ WebRTCTrack &rtcTrack = webrtcTracks[idx]; // Decrypt the SRTP to RTP - int len = udp.data.size(); - if (srtpReader.unprotectRtp((uint8_t *)(char*)udp.data, &len) != 0){ + int len = wSock.udpSock->data.size(); + if (wSock.srtpReader.unprotectRtp((uint8_t *)(char*)wSock.udpSock->data, &len) != 0){ if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "RTP decrypt failure" << std::endl;} return; } if (!len){return;} lastRecv = Util::bootMS(); - RTP::Packet unprotPack(udp.data, len); + RTP::Packet unprotPack(wSock.udpSock->data, len); DONTEVEN_MSG("%s", unprotPack.toString().c_str()); rtcTrack.gotPacket(unprotPack.getTimeStamp()); if (rtp_pkt.getPayloadType() == rtcTrack.REDPayloadType || rtp_pkt.getPayloadType() == rtcTrack.ULPFECPayloadType){ if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "RED packet " << rtp_pkt.getPayloadType() << " #" << currSeqNum << std::endl;} - rtcTrack.sorter.addREDPacket(udp.data, len, rtcTrack.payloadType, rtcTrack.REDPayloadType, + rtcTrack.sorter.addREDPacket(wSock.udpSock->data, len, rtcTrack.payloadType, rtcTrack.REDPayloadType, rtcTrack.ULPFECPayloadType); }else{ if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Basic packet " << rtp_pkt.getPayloadType() << " #" << currSeqNum << std::endl;} @@ -1258,9 +1436,9 @@ namespace Mist{ }else{ //Decrypt feedback packet - int len = udp.data.size(); + int len = wSock.udpSock->data.size(); if (doDTLS){ - if (srtpReader.unprotectRtcp((uint8_t *)(char*)udp.data, &len) != 0){ + if (wSock.srtpReader.unprotectRtcp((uint8_t *)(char*)wSock.udpSock->data, &len) != 0){ if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "RTCP decrypt failure" << std::endl;} return; } @@ -1268,13 +1446,13 @@ namespace Mist{ } lastRecv = Util::bootMS(); - uint8_t fmt = udp.data[0] & 0x1F; + uint8_t fmt = wSock.udpSock->data[0] & 0x1F; if (pt == 77 || pt == 65){ //77/65 = nack if (fmt == 1){ - uint32_t pSSRC = Bit::btohl(udp.data + 8); - uint16_t seq = Bit::btohs(udp.data + 12); - uint16_t bitmask = Bit::btohs(udp.data + 14); + uint32_t pSSRC = Bit::btohl(wSock.udpSock->data + 8); + uint16_t seq = Bit::btohs(wSock.udpSock->data + 12); + uint16_t bitmask = Bit::btohs(wSock.udpSock->data + 14); ackNACK(pSSRC, seq); size_t missed = 1; if (bitmask & 1){ackNACK(pSSRC, seq + 1); missed++;} @@ -1309,15 +1487,15 @@ namespace Mist{ } }else if (pt == 72){ //72 = sender report - uint32_t SSRC = Bit::btohl(udp.data + 4); + uint32_t SSRC = Bit::btohl(wSock.udpSock->data + 4); std::map::iterator it; for (it = webrtcTracks.begin(); it != webrtcTracks.end(); ++it){ if (it->second.SSRC == SSRC){ it->second.sorter.lastBootMS = Util::bootMS(); - it->second.sorter.lastNTP = Bit::btohl(udp.data+10); - uint64_t ntpTime = Bit::btohll(udp.data + 8); - uint32_t rtpTime = Bit::btohl(udp.data + 16); - uint32_t packets = Bit::btohl(udp.data + 20); + it->second.sorter.lastNTP = Bit::btohl(wSock.udpSock->data+10); + uint64_t ntpTime = Bit::btohll(wSock.udpSock->data + 8); + uint32_t rtpTime = Bit::btohl(wSock.udpSock->data + 16); + uint32_t packets = Bit::btohl(wSock.udpSock->data + 20); if (packets > it->second.lastPktCount){ //counter went up; check if it was less than half the range if ((packets - it->second.lastPktCount) <= 0x7FFFFFFF){ @@ -1338,7 +1516,7 @@ namespace Mist{ //The else case is a no-op: //If it went down outside those ranges, this is an older packet we should just ignore } - uint32_t bytes = Bit::btohl(udp.data + 24); + uint32_t bytes = Bit::btohl(wSock.udpSock->data + 24); HIGH_MSG("Received sender report for track %s (%" PRIu32 " pkts, %" PRIu32 "b) time: %" PRIu32 " RTP = %" PRIu64 " NTP", it->second.rtpToDTSC.codec.c_str(), packets, bytes, rtpTime, ntpTime); if (rtpTime && ntpTime){ //msDiff is the amount of millis our current NTP time is ahead of the sync moment NTP time @@ -1355,10 +1533,10 @@ namespace Mist{ } } }else if (pt == 73){ - //73 = receiver report: https://datatracker.ietf.org/doc/html/rfc3550#section-6.4.2 + //73 = 201 = receiver report: https://datatracker.ietf.org/doc/html/rfc3550#section-6.4.2 //Packet may contain more than one report - char * ptr = udp.data + 8; - while (ptr + 24 <= udp.data + udp.data.size()){ + char * ptr = wSock.udpSock->data + 8; + while (ptr + 24 <= wSock.udpSock->data + wSock.udpSock->data.size()){ //Update the counter for this ssrc uint32_t ssrc = Bit::btoh24(ptr); lostPackets[ssrc] = Bit::btoh24(ptr + 5); @@ -1370,6 +1548,39 @@ namespace Mist{ for (std::map::iterator it = lostPackets.begin(); it != lostPackets.end(); ++it){ totalLoss += it->second; } + }else if (pt == 74){ + // 74 = 202 = SDES: https://www.ietf.org/rfc/rfc1889.html#section-6.4 + Util::ResizeablePointer & p = wSock.udpSock->data; + // Check padding bit + if (p[0] & 0x20){ + // Padding count is stored in the last octet + size_t padding = p[p.size() - 1]; + if (padding > p.size()){padding = p.size();} + p.truncate(p.size() - padding); + } + size_t offset = 4; + while (offset + 5 <= p.size()){ + uint32_t ssrc = Bit::btohl((char*)p+offset); + offset += 4; + while (offset + 2 <= p.size()){ + uint8_t type = p[offset]; + if (!type){ + ++offset; + break; + } + uint8_t len = p[offset+1]; + if (offset+2+len <= p.size()){ + std::string val(offset+2, len); + // Ignore blank SDES messages + if (len){ + INFO_MSG("SDES for %" PRIu32 ": type %" PRIu8 " = %s", ssrc, type, val.c_str()); + } + } + offset += len +2; + } + // Ensure alignment + if (offset % 4){offset += 4 - (offset % 4);} + } }else{ if (packetLog.is_open()){packetLog << "[" << Util::bootMS() << "]" << "Unknown payload type: " << pt << std::endl;} WARN_MSG("Unknown RTP feedback payload type: %u", pt); @@ -1379,10 +1590,139 @@ namespace Mist{ /* ------------------------------------------------ */ - int OutWebRTC::onDTLSHandshakeWantsToWrite(const uint8_t *data, int *nbytes){ - udp.sendPaced((const char *)data, (size_t)*nbytes); - myConn.addUp(*nbytes); - return 0; + void OutWebRTC::sendSCTPPacket(const char * data, size_t len){ + for (std::set::iterator it = sctpSockets.begin(); it != sctpSockets.end(); ++it){ + if (!*(sockets[*it].udpSock)){continue;} + sockets[*it].udpSock->sendPaced(data, len); + } + } + + /// Use select to wait until a packet arrives or until the next packet should be sent + void OutWebRTC::sendPaced(uint64_t uSendWindow){ + uint64_t currPace = Util::getMicros(); + uint64_t uTime = currPace; + do{ + uint64_t sleepTime = uSendWindow - (uTime - currPace); + + fd_set rfds; + FD_ZERO(&rfds); + int maxFD = mainSocket.getSock(); + FD_SET(maxFD, &rfds); + + for (std::map::iterator it = sockets.begin(); it != sockets.end(); it++){ + if (!*(it->second.udpSock)){continue;} + uint64_t nextPace = it->second.udpSock->timeToNextPace(uTime); + // Not sleeping? Send now! + if (!nextPace){ + it->second.udpSock->sendPaced(0); + nextPace = it->second.udpSock->timeToNextPace(uTime); + } + if (sleepTime > nextPace){sleepTime = nextPace;} + + int s = it->second.udpSock->getSock(); + FD_SET(s, &rfds); + if (maxFD < s){maxFD = s;} + } + + struct timeval T; + T.tv_sec = sleepTime / 1000000; + T.tv_usec = sleepTime % 1000000; + int r = select(maxFD + 1, &rfds, NULL, NULL, &T); + // If we can read the socket, immediately return and stop waiting + if (r > 0){return;} + + uTime = Util::getMicros(); + }while(uTime - currPace < uSendWindow); + } + + void OutWebRTC::onSCTP(const char * data, size_t len, uint16_t stream, uint32_t ppid){ + if (!sctpConnected){ + // We have to call accept (at least) once, otherwise the SCTP library considers our socket not connected + // Accept blocks if there is no peer, so we do this as soon as the first message is received, which means we have a peer. + sctp_sock = usrsctp_accept(sctp_sock, 0, 0); + sctpConnected = true; + } + if (ppid == 50){ + // DCEP message. Spec: https://www.rfc-editor.org/rfc/rfc8832.html + if (data[0] == 3){ + uint8_t chanType = data[1]; + uint32_t reliParam = Bit::btohl(data+4); + uint16_t lblLen = Bit::btohs(data+8); + uint16_t proLen = Bit::btohs(data+10); + std::string chanTypeStr; + switch (chanType){ + case 0x00: chanTypeStr = "reliable"; break; + case 0x80: chanTypeStr = "reliable, unordered"; break; + case 0x01: chanTypeStr = "max " + JSON::Value(reliParam).asString() + " retrans"; break; + case 0x81: chanTypeStr = "max " + JSON::Value(reliParam).asString() + " retrans, unordered"; break; + case 0x02: chanTypeStr = "max " + JSON::Value(reliParam).asString() + " millis"; break; + case 0x82: chanTypeStr = "max " + JSON::Value(reliParam).asString() + " millis, unordered"; break; + } + std::string label(data+12, lblLen); + std::string protocol(data+12+lblLen, proLen); + INFO_MSG("New data channel %" PRIu16 ": %s/%s (%s)", stream, label.c_str(), protocol.c_str(), chanTypeStr.c_str()); + + sctp_sndinfo sndinfo; + sndinfo.snd_sid = stream; + sndinfo.snd_flags = SCTP_EOR; + sndinfo.snd_ppid = htonl(50); + sndinfo.snd_context = 0; + sndinfo.snd_assoc_id = 0; + int ret = usrsctp_sendv(sctp_sock, "2", 1, NULL, 0, (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO, 0); + if (ret < 0){ + WARN_MSG("Could not send data channel ACK, error: %s", strerror(errno)); + }else{ + if ((protocol == "JSON" || label == "JSON" || protocol == "*" || label == "*") && !dataChannels.count("JSON")){ + dataChannels["JSON"] = stream; + while (queuedJSON.size()){ + sctp_sndinfo sndinfo; + sndinfo.snd_sid = stream; + sndinfo.snd_flags = SCTP_EOR; + sndinfo.snd_ppid = htonl(51); + sndinfo.snd_context = 0; + sndinfo.snd_assoc_id = 0; + int ret = usrsctp_sendv(sctp_sock, queuedJSON.begin()->data(), queuedJSON.begin()->size(), NULL, 0, (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO, 0); + if (ret < 0){ + WARN_MSG("Could not send data channel message: %s", strerror(errno)); + } + queuedJSON.pop_front(); + } + } + if ((protocol == "subtitle" || label == "subtitle" || protocol == "*" || label == "*") && !dataChannels.count("subtitle")){ + dataChannels["subtitle"] = stream; + } + } + if (packetLog.is_open()){ + packetLog << "Data channel " << stream << " opened: " << label << "/" << protocol << " (" << chanTypeStr << ")" << std::endl; + } + }else if (data[0] == 2){ + INFO_MSG("Data channel acknowledged by remote"); + if (packetLog.is_open()){ + packetLog << "Data channel " << stream << " acknowledged by remote" << std::endl; + } + }else{ + WARN_MSG("Received invalid DCEP message!"); + return; + } + }else if (ppid == 51){ + std::string txt(data, len); + INFO_MSG("Received text: %s", txt.c_str()); + if (packetLog.is_open()){ + packetLog << "Received SCTP text (data stream " << stream << "): " << txt << std::endl; + } + }else{ + INFO_MSG("Received unknown PPID datachannel message: %" PRIu32, ppid); + if (packetLog.is_open()){ + packetLog << "Received SCTP data (" << len << "b):" << std::endl; + for (unsigned int i = 0; i < len; ++i){ + if (!(i % 32)){packetLog << std::endl;} + packetLog << std::hex << std::setw(2) << std::setfill('0') + << (unsigned int)(data[i]) << " "; + if ((i % 4) == 3){packetLog << " ";} + } + packetLog << std::dec << std::endl; + } + } } void OutWebRTC::onDTSCConverterHasPacket(const DTSC::Packet &pkt){ @@ -1465,35 +1805,34 @@ namespace Mist{ void OutWebRTC::onRTPPacketizerHasRTPPacket(const char *data, size_t nbytes){ rtpOutBuffer.allocate(nbytes + 256); - rtpOutBuffer.assign(data, nbytes); - int protectedSize = nbytes; - - if (doDTLS){ - if (srtpWriter.protectRtp((uint8_t *)(void *)rtpOutBuffer, &protectedSize) != 0){ - ERROR_MSG("Failed to protect the RTP message."); - return; + for (std::set::iterator it = rtpSockets.begin(); it != rtpSockets.end(); ++it){ + if (!*(sockets[*it].udpSock)){continue;} + rtpOutBuffer.assign(data, nbytes); + int protectedSize = nbytes; + if (doDTLS){ + if (sockets[*it].srtpWriter.protectRtp((uint8_t *)(void *)rtpOutBuffer, &protectedSize) != 0){ + ERROR_MSG("Failed to protect the RTP message."); + return; + } } - } - udp.sendPaced(rtpOutBuffer, (size_t)protectedSize); - - RTP::Packet tmpPkt(rtpOutBuffer, protectedSize); - uint32_t pSSRC = tmpPkt.getSSRC(); - uint16_t seq = tmpPkt.getSequence(); - outBuffers[pSSRC].assign(seq, rtpOutBuffer, protectedSize); - myConn.addUp(protectedSize); - totalPkts++; - - if (volkswagenMode){ - if (srtpWriter.protectRtp((uint8_t *)(void *)rtpOutBuffer, &protectedSize) != 0){ - ERROR_MSG("Failed to protect the RTP message."); - return; + + sockets[*it].udpSock->sendPaced(rtpOutBuffer, (size_t)protectedSize, false); + myConn.addUp(protectedSize); + RTP::Packet tmpPkt(rtpOutBuffer, protectedSize); + uint32_t pSSRC = tmpPkt.getSSRC(); + uint16_t seq = tmpPkt.getSequence(); + if (packetLog.is_open()){ + packetLog << "[" << Util::bootMS() << "]" << "Sending RTP packet #" << seq << " to socket " << sockets[*it].udpSock->getSock() << std::endl; } + sockets[*it].outBuffers[pSSRC].assign(seq, rtpOutBuffer, protectedSize); + totalPkts++; + + if (volkswagenMode){sockets[*it].srtpWriter.protectRtp((uint8_t *)(void *)rtpOutBuffer, &protectedSize);} } } void OutWebRTC::onRTPPacketizerHasRTCPPacket(const char *data, uint32_t nbytes){ - if (nbytes > 2048){ FAIL_MSG("The received RTCP packet is too big to handle."); return; @@ -1502,45 +1841,27 @@ namespace Mist{ FAIL_MSG("Invalid RTCP packet given."); return; } - - rtpOutBuffer.allocate(nbytes + 256); - rtpOutBuffer.assign(data, nbytes); - int rtcpPacketSize = nbytes; - - if (doDTLS){ - if (srtpWriter.protectRtcp((uint8_t *)(void *)rtpOutBuffer, &rtcpPacketSize) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } + for (std::set::iterator it = rtpSockets.begin(); it != rtpSockets.end(); ++it){ + if (!*(sockets[*it].udpSock)){continue;} + myConn.addUp(sockets[*it].sendRTCP(data, nbytes)); } - - udp.sendPaced(rtpOutBuffer, rtcpPacketSize); - myConn.addUp(rtcpPacketSize); - - if (volkswagenMode){ - if (srtpWriter.protectRtcp((uint8_t *)(void *)rtpOutBuffer, &rtcpPacketSize) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } - } - } void OutWebRTC::sendNext(){ HTTPOutput::sendNext(); // first make sure that we complete the DTLS handshake. if(doDTLS){ - while (keepGoing() && !dtlsHandshake.hasKeyingMaterial()){ - if (!handleWebRTCInputOutput()){udp.sendPaced(10);}else{udp.sendPaced(0);} + while (keepGoing() && !rtpSockets.size()){ + if (!handleWebRTCInputOutput()){sendPaced(10);}else{sendPaced(0);} if (lastRecv < Util::bootMS() - 10000){ - WARN_MSG("Killing idle connection in handshake phase"); + INFO_MSG("Killing idle connection in handshake phase"); onFail("idle connection in handshake phase", false); return; } } } if (lastRecv < Util::bootMS() - 10000){ - WARN_MSG("Killing idle connection"); + INFO_MSG("Killing idle connection"); onFail("idle connection", false); return; } @@ -1571,6 +1892,55 @@ namespace Mist{ size_t dataLen = 0; thisPacket.getString("data", dataPointer, dataLen); + + if (M.getType(thisIdx) == "meta"){ + JSON::Value jPack; + if (M.getCodec(thisIdx) == "JSON"){ + if (dataLen == 0 || (dataLen == 1 && dataPointer[0] == ' ')){return;} + jPack["data"] = JSON::fromString(dataPointer, dataLen); + jPack["time"] = thisTime; + jPack["track"] = (uint64_t)thisIdx; + }else if (M.getCodec(thisIdx) == "subtitle"){ + //Ignore blank subtitles + if (dataLen == 0 || (dataLen == 1 && dataPointer[0] == ' ')){return;} + + //Get duration, or calculate if missing + uint64_t duration = thisPacket.getInt("duration"); + if (!duration){duration = dataLen * 75 + 800;} + + //Build JSON data to transmit + jPack["duration"] = duration; + jPack["time"] = thisTime; + jPack["track"] = (uint64_t)thisIdx; + jPack["data"] = std::string(dataPointer, dataLen); + }else{ + jPack = thisPacket.toJSON(); + jPack.removeMember("bpos"); + jPack["generic_converter_used"] = true; + } + std::string packed = jPack.toString(); + + if (dataChannels.count(M.getCodec(thisIdx))){ + sctp_sndinfo sndinfo; + sndinfo.snd_sid = dataChannels[M.getCodec(thisIdx)]; + sndinfo.snd_flags = SCTP_EOR; + sndinfo.snd_ppid = htonl(51); + sndinfo.snd_context = 0; + sndinfo.snd_assoc_id = 0; + int ret = usrsctp_sendv(sctp_sock, packed.data(), packed.size(), NULL, 0, (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO, 0); + if (ret < 0){ + WARN_MSG("Could not send data channel message: %s", strerror(errno)); + } + }else{ + if (M.getCodec(thisIdx) == "JSON"){ + queuedJSON.push_back(packed); + }else{ + WARN_MSG("I don't have a data channel for %s data!", M.getCodec(thisIdx).c_str()); + } + } + return; + } + // make sure the webrtcTracks were setup correctly for output. uint32_t tid = thisIdx; @@ -1631,7 +2001,7 @@ namespace Mist{ if (repeatInit && isKeyFrame){sendSPSPPS(thisIdx, rtcTrack);} } - rtcTrack.rtpPacketizer.sendData(&udp, onRTPPacketizerHasDataCallback, dataPointer, dataLen, + rtcTrack.rtpPacketizer.sendData(0, onRTPPacketizerHasDataCallback, dataPointer, dataLen, rtcTrack.payloadType, M.getCodec(thisIdx)); //Trigger a re-send of the Sender Report for every track every ~250ms @@ -1644,7 +2014,7 @@ namespace Mist{ //If this track hasn't sent yet, actually sent if (mustSendSR.count(thisIdx)){ mustSendSR.erase(thisIdx); - rtcTrack.rtpPacketizer.sendRTCP_SR((void *)&udp, 0, onRTPPacketizerHasRTCPDataCallback); + rtcTrack.rtpPacketizer.sendRTCP_SR(0, 0, onRTPPacketizerHasRTCPDataCallback); } } @@ -1740,21 +2110,9 @@ namespace Mist{ size_t trailer_space = SRTP_MAX_TRAILER_LEN + 4; for (size_t i = 0; i < trailer_space; ++i){buffer.push_back(0x00);} - if (doDTLS){ - if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } - } - - udp.sendPaced((const char *)&buffer[0], buffer_size_in_bytes); - myConn.addUp(buffer_size_in_bytes); - - if (volkswagenMode){ - if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } + for (std::set::iterator it = rtpSockets.begin(); it != rtpSockets.end(); ++it){ + if (!*(sockets[*it].udpSock)){continue;} + myConn.addUp(sockets[*it].sendRTCP((const char *)&buffer[0], buffer_size_in_bytes)); } } @@ -1781,21 +2139,9 @@ namespace Mist{ size_t trailer_space = SRTP_MAX_TRAILER_LEN + 4; for (size_t i = 0; i < trailer_space; ++i){buffer.push_back(0x00);} - if (doDTLS){ - if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } - } - - udp.sendPaced((const char *)&buffer[0], buffer_size_in_bytes); - myConn.addUp(buffer_size_in_bytes); - - if (volkswagenMode){ - if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } + for (std::set::iterator it = rtpSockets.begin(); it != rtpSockets.end(); ++it){ + if (!*(sockets[*it].udpSock)){continue;} + myConn.addUp(sockets[*it].sendRTCP((const char *)&buffer[0], buffer_size_in_bytes)); } } @@ -1832,21 +2178,9 @@ namespace Mist{ size_t trailer_space = SRTP_MAX_TRAILER_LEN + 4; for (size_t i = 0; i < trailer_space; ++i){buffer.push_back(0x00);} - if (doDTLS){ - if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } - } - - udp.sendPaced((const char *)&buffer[0], buffer_size_in_bytes); - myConn.addUp(buffer_size_in_bytes); - - if (volkswagenMode){ - if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0){ - ERROR_MSG("Failed to protect the RTCP message."); - return; - } + for (std::set::iterator it = rtpSockets.begin(); it != rtpSockets.end(); ++it){ + if (!*(sockets[*it].udpSock)){continue;} + myConn.addUp(sockets[*it].sendRTCP((const char *)&buffer[0], buffer_size_in_bytes)); } } @@ -1908,7 +2242,7 @@ namespace Mist{ if (packetLog.is_open()){ packetLog << "[" << Util::bootMS() << "] Receiver Report (" << rtcTrack.rtpToDTSC.codec << "): " << stats_lossperc << " percent loss, " << rtcTrack.sorter.lostTotal << " total lost, " << stats_jitter << " ms jitter" << std::endl; } - ((RTP::FECPacket *)&(rtcTrack.rtpPacketizer))->sendRTCP_RR(rtcTrack.sorter, SSRC, rtcTrack.SSRC, (void *)&udp, onRTPPacketizerHasRTCPDataCallback, (uint32_t)rtcTrack.jitter); + ((RTP::FECPacket *)&(rtcTrack.rtpPacketizer))->sendRTCP_RR(rtcTrack.sorter, SSRC, rtcTrack.SSRC, 0, onRTPPacketizerHasRTCPDataCallback, (uint32_t)rtcTrack.jitter); } void OutWebRTC::sendSPSPPS(size_t dtscIdx, WebRTCTrack &rtcTrack){ @@ -1936,7 +2270,7 @@ namespace Mist{ *(uint32_t *)&buf[0] = htonl(len); std::copy(avcc.getSPS(i), avcc.getSPS(i) + avcc.getSPSLen(i), std::back_inserter(buf)); - rtcTrack.rtpPacketizer.sendData(&udp, onRTPPacketizerHasDataCallback, &buf[0], buf.size(), + rtcTrack.rtpPacketizer.sendData(0, onRTPPacketizerHasDataCallback, &buf[0], buf.size(), rtcTrack.payloadType, M.getCodec(dtscIdx)); } @@ -1954,7 +2288,7 @@ namespace Mist{ *(uint32_t *)&buf[0] = htonl(len); std::copy(avcc.getPPS(i), avcc.getPPS(i) + avcc.getPPSLen(i), std::back_inserter(buf)); - rtcTrack.rtpPacketizer.sendData(&udp, onRTPPacketizerHasDataCallback, &buf[0], buf.size(), + rtcTrack.rtpPacketizer.sendData(0, onRTPPacketizerHasDataCallback, &buf[0], buf.size(), rtcTrack.payloadType, M.getCodec(dtscIdx)); } } @@ -1972,14 +2306,6 @@ namespace Mist{ classPointer->handleWebRTCInputOutputFromThread(); } - static int onDTLSHandshakeWantsToWriteCallback(const uint8_t *data, int *nbytes){ - if (!classPointer){ - FAIL_MSG("Requested to send DTLS handshake data but the `classPointer` hasn't been set."); - return -1; - } - return classPointer->onDTLSHandshakeWantsToWrite(data, nbytes); - } - static void onRTPSorterHasPacketCallback(const uint64_t track, const RTP::Packet &p){ if (!classPointer){ FAIL_MSG("We received a sorted RTP packet but our `classPointer` is invalid."); diff --git a/src/output/output_webrtc.h b/src/output/output_webrtc.h index 67f95ae4..68d0fa7b 100644 --- a/src/output/output_webrtc.h +++ b/src/output/output_webrtc.h @@ -3,7 +3,6 @@ #include "output.h" #include "output_http.h" #include -#include #include #include #include @@ -14,6 +13,7 @@ #include #include #include "output_webrtc_srtp.h" +#include #define NACK_BUFFER_SIZE 1024 @@ -67,7 +67,19 @@ namespace Mist{ double jitter; }; - /* ------------------------------------------------ */ + class WebRTCSocket{ + public: + WebRTCSocket(); + Socket::UDPConnection* udpSock; + SRTPReader srtpReader; ///< Used to unprotect incoming RTP and RTCP data. Uses the keys that + ///< were exchanged with DTLS. + SRTPWriter srtpWriter; ///< Used to protect our RTP and RTCP data when sending data to another + ///< peer. Uses the keys that were exchanged with DTLS. + std::map outBuffers; + size_t sendRTCP(const char * data, size_t len); + size_t ackNACK(uint32_t pSSRC, uint16_t seq); + Util::ResizeablePointer dataBuffer; + }; class OutWebRTC : public HTTPOutput{ public: @@ -82,10 +94,13 @@ namespace Mist{ virtual bool dropPushTrack(uint32_t trackId, const std::string & dropReason); void handleWebsocketIdle(); virtual void onFail(const std::string &msg, bool critical = false); - bool onFinish(); bool doesWebsockets(){return true;} void handleWebRTCInputOutputFromThread(); - int onDTLSHandshakeWantsToWrite(const uint8_t *data, int *nbytes); + bool handleUDPSocket(Socket::UDPConnection & sock); + bool handleUDPSocket(WebRTCSocket & wSock); + void sendSCTPPacket(const char * data, size_t len); + void sendPaced(uint64_t uSendWindow); + void onSCTP(const char * data, size_t len, uint16_t stream, uint32_t ppid); void onRTPSorterHasPacket(size_t tid, const RTP::Packet &pkt); void onDTSCConverterHasPacket(const DTSC::Packet &pkt); void onDTSCConverterHasInitData(const size_t trackID, const std::string &initData); @@ -95,7 +110,7 @@ namespace Mist{ inline virtual bool keepGoing(){return config->is_active && (noSignalling || myConn);} virtual void requestHandler(); protected: - virtual void idleTime(uint64_t ms){udp.sendPaced(ms*1000);} + virtual void idleTime(uint64_t ms){sendPaced(ms*1000);} private: bool noSignalling; uint64_t lastRecv; @@ -109,9 +124,8 @@ namespace Mist{ void ackNACK(uint32_t SSRC, uint16_t seq); bool handleWebRTCInputOutput(); ///< Reads data from the UDP socket. Returns true when we read ///< some data, othewise false. - void handleReceivedSTUNPacket(); - void handleReceivedDTLSPacket(); - void handleReceivedRTPOrRTCPPacket(); + void handleReceivedSTUNPacket(WebRTCSocket &wSock); + void handleReceivedRTPOrRTCPPacket(WebRTCSocket &wSock); bool handleSignalingCommandRemoteOfferForInput(SDP::Session &sdpSession); bool handleSignalingCommandRemoteOfferForOutput(SDP::Session &sdpSession); void sendSignalingError(const std::string &commandType, const std::string &errorMessage); @@ -136,20 +150,19 @@ namespace Mist{ SDP::Session sdp; ///< SDP parser. SDP::Answer sdpAnswer; ///< WIP: Replacing our `sdp` member .. Certificate cert; ///< The TLS certificate. Used to generate a fingerprint in SDP answers. - DTLSSRTPHandshake dtlsHandshake; ///< Implements the DTLS handshake using the mbedtls library (fork). - SRTPReader srtpReader; ///< Used to unprotect incoming RTP and RTCP data. Uses the keys that - ///< were exchanged with DTLS. - SRTPWriter srtpWriter; ///< Used to protect our RTP and RTCP data when sending data to another - ///< peer. Uses the keys that were exchanged with DTLS. - Socket::UDPConnection udp; ///< Our UDP socket over which WebRTC data is received and sent. + Socket::UDPConnection mainSocket; //< Main socket created during the initial handshake + std::map sockets; ///< UDP sockets over which WebRTC data is received and sent. + std::set rtpSockets; ///< UDP sockets over which (S)RTP data is transmitted/received + std::set sctpSockets; ///< UDP sockets over which (S)RTP data is transmitted/received + uint16_t lastMediaSocket; //< Last socket number we received video/audio on + uint16_t lastMetaSocket; //< Last socket number we received non-media data on + uint16_t udpPort; ///< Port where we receive RTP, STUN, DTLS, etc. StunReader stunReader; ///< Decodes STUN messages; during a session we keep receiving STUN ///< messages to which we need to reply. std::map webrtcTracks; ///< WebRTCTracks indexed by payload type for incoming data and indexed by ///< myMeta.tracks[].trackID for outgoing data. tthread::thread *webRTCInputOutputThread; ///< The thread in which we read WebRTC data when ///< we're receive media from another peer. - uint16_t udpPort; ///< The port on which our webrtc socket is bound. This is where we receive - ///< RTP, STUN, DTLS, etc. */ uint32_t SSRC; ///< The SSRC for this local instance. Is used when generating RTCP reports. */ uint64_t rtcpTimeoutInMillis; ///< When current time in millis exceeds this timeout we have to ///< send a new RTCP packet. @@ -161,18 +174,15 @@ namespace Mist{ ///< the signaling channel. Defaults to 6mbit. uint32_t videoConstraint; - size_t audTrack, vidTrack, prevVidTrack; + size_t audTrack, vidTrack, prevVidTrack, metaTrack; double target_rate; ///< Target playback speed rate (1.0 = normal, 0 = auto) - bool didReceiveKeyFrame; /* TODO burst delay */ + bool didReceiveKeyFrame; bool setPacketOffset; int64_t packetOffset; ///< For timestamp rewrite with BMO uint64_t lastTimeSync; bool firstKey; bool repeatInit; - bool stayLive; - bool doDTLS; - bool volkswagenMode; double stats_jitter; uint64_t stats_nacknum; @@ -191,13 +201,17 @@ namespace Mist{ std::map payloadTypeToWebRTCTrack; ///< Maps e.g. RED to the corresponding track. Used when input ///< supports RED/ULPFEC; can also be used to map RTX in the ///< future. - std::map outBuffers; - uint64_t lastSR; std::set mustSendSR; int64_t ntpClockDifference; bool syncedNTPClock; + + bool sctpInited; + bool sctpConnected; + struct socket * sctp_sock; + std::map dataChannels; + std::deque queuedJSON; }; }// namespace Mist diff --git a/src/output/output_webrtc_srtp.cpp b/src/output/output_webrtc_srtp.cpp index 53c9ff45..aa0c5bc7 100644 --- a/src/output/output_webrtc_srtp.cpp +++ b/src/output/output_webrtc_srtp.cpp @@ -14,6 +14,10 @@ SRTPReader::SRTPReader(){ memset((void *)&policy, 0x00, sizeof(policy)); } +SRTPReader::~SRTPReader(){ + if (shutdown() != 0){FAIL_MSG("Failed to cleanly shutdown the srtp reader.");} +} + /* Before initializing the srtp library we shut it down first because initializing the library twice results in an error. @@ -203,6 +207,11 @@ SRTPWriter::SRTPWriter(){ memset((void *)&policy, 0x00, sizeof(policy)); } +SRTPWriter::~SRTPWriter(){ + if (shutdown() != 0){FAIL_MSG("Failed to cleanly shutdown the srtp writer.");} +} + + /* Before initializing the srtp library we shut it down first because initializing the library twice results in an error. diff --git a/src/output/output_webrtc_srtp.h b/src/output/output_webrtc_srtp.h index e92a44d9..90de3b51 100644 --- a/src/output/output_webrtc_srtp.h +++ b/src/output/output_webrtc_srtp.h @@ -14,6 +14,7 @@ class SRTPReader{ public: SRTPReader(); + ~SRTPReader(); int init(const std::string &cipher, const std::string &key, const std::string &salt); int shutdown(); int unprotectRtp(uint8_t *data, int *nbytes); /* `nbytes` should contain the number of bytes in `data`. On success `nbytes` @@ -32,6 +33,7 @@ private: class SRTPWriter{ public: SRTPWriter(); + ~SRTPWriter(); int init(const std::string &cipher, const std::string &key, const std::string &salt); int shutdown(); int protectRtp(uint8_t *data, int *nbytes); diff --git a/src/session.cpp b/src/session.cpp index 016e12e5..1d479d27 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -67,23 +67,23 @@ void userOnActive(Comms::Connections &connections, size_t idx){ } // Sanity checks if (connections.getDown(idx) < connDown[idx]){ - WARN_MSG("Connection downloaded bytes should be a counter, but has decreased in value"); + MEDIUM_MSG("Connection downloaded bytes should be a counter, but has decreased in value"); connDown[idx] = connections.getDown(idx); } if (connections.getUp(idx) < connUp[idx]){ - WARN_MSG("Connection uploaded bytes should be a counter, but has decreased in value"); + MEDIUM_MSG("Connection uploaded bytes should be a counter, but has decreased in value"); connUp[idx] = connections.getUp(idx); } if (connections.getPacketCount(idx) < connPktcount[idx]){ - WARN_MSG("Connection packet count should be a counter, but has decreased in value"); + MEDIUM_MSG("Connection packet count should be a counter, but has decreased in value"); connPktcount[idx] = connections.getPacketCount(idx); } if (connections.getPacketLostCount(idx) < connPktloss[idx]){ - WARN_MSG("Connection packet loss count should be a counter, but has decreased in value"); + MEDIUM_MSG("Connection packet loss count should be a counter, but has decreased in value"); connPktloss[idx] = connections.getPacketLostCount(idx); } if (connections.getPacketRetransmitCount(idx) < connPktretrans[idx]){ - WARN_MSG("Connection packets retransmitted should be a counter, but has decreased in value"); + MEDIUM_MSG("Connection packets retransmitted should be a counter, but has decreased in value"); connPktretrans[idx] = connections.getPacketRetransmitCount(idx); } // Add increase in stats to global stats @@ -218,7 +218,7 @@ int main(int argc, char **argv){ return 0; } } - + // Claim a spot in shared memory for this session on the global statistics page sessions.reload(); if (!sessions){ diff --git a/subprojects/usrsctp.wrap b/subprojects/usrsctp.wrap new file mode 100644 index 00000000..327cb901 --- /dev/null +++ b/subprojects/usrsctp.wrap @@ -0,0 +1,5 @@ +[wrap-git] +url = https://github.com/sctplab/usrsctp.git +revision = 0.9.5.0 +depth = 1 +