From 7e8eb634e652d18061098bd28ee367ea32f77ac3 Mon Sep 17 00:00:00 2001 From: roxlu Date: Fri, 1 Jun 2018 09:19:32 +0200 Subject: [PATCH] Implemented WebRTC --- CMakeLists.txt | 15 +- lib/certificate.cpp | 240 +++ lib/certificate.h | 34 + lib/dtls_srtp_handshake.cpp | 420 +++++ lib/dtls_srtp_handshake.h | 59 + lib/rtp_fec.cpp | 569 ++++++ lib/rtp_fec.h | 100 ++ lib/sdp_media.cpp | 1184 +++++++++++++ lib/sdp_media.h | 220 +++ lib/srtp.cpp | 422 +++++ lib/srtp.h | 43 + lib/stun.cpp | 1051 ++++++++++++ lib/stun.h | 250 +++ scripts/webrtc_compile.sh | 52 + .../webrtc_mbedtls_keying_material_fix.diff | 34 + scripts/webrtc_run.sh | 30 + scripts/webrtc_srtp_cmakelists.txt | 112 ++ scripts/webrtc_srtp_config.cmake | 181 ++ src/output/output_webrtc.cpp | 1524 +++++++++++++++++ src/output/output_webrtc.h | 173 ++ 20 files changed, 6712 insertions(+), 1 deletion(-) create mode 100644 lib/certificate.cpp create mode 100644 lib/certificate.h create mode 100644 lib/dtls_srtp_handshake.cpp create mode 100644 lib/dtls_srtp_handshake.h create mode 100644 lib/rtp_fec.cpp create mode 100644 lib/rtp_fec.h create mode 100644 lib/sdp_media.cpp create mode 100644 lib/sdp_media.h create mode 100644 lib/srtp.cpp create mode 100644 lib/srtp.h create mode 100644 lib/stun.cpp create mode 100644 lib/stun.h create mode 100755 scripts/webrtc_compile.sh create mode 100644 scripts/webrtc_mbedtls_keying_material_fix.diff create mode 100755 scripts/webrtc_run.sh create mode 100644 scripts/webrtc_srtp_cmakelists.txt create mode 100644 scripts/webrtc_srtp_config.cmake create mode 100644 src/output/output_webrtc.cpp create mode 100644 src/output/output_webrtc.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f454022d..68964ff1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,9 +143,11 @@ set(libHeaders lib/encode.h lib/bitfields.h lib/bitstream.h + lib/certificate.h lib/checksum.h lib/config.h lib/defines.h + lib/dtls_srtp_handshake.h lib/dtsc.h lib/encryption.h lib/flv_tag.h @@ -168,11 +170,15 @@ set(libHeaders lib/procs.h lib/rijndael.h lib/rtmpchunks.h + lib/rtp_fec.h lib/rtp.h lib/sdp.h + lib/sdp_media.h lib/shared_memory.h lib/socket.h + lib/srtp.h lib/stream.h + lib/stun.h lib/theora.h lib/timing.h lib/tinythread.h @@ -201,7 +207,9 @@ add_library (mist lib/encode.cpp lib/bitfields.cpp lib/bitstream.cpp + lib/certificate.cpp lib/config.cpp + lib/dtls_srtp_handshake.cpp lib/dtsc.cpp lib/dtscmeta.cpp lib/encryption.cpp @@ -224,11 +232,15 @@ add_library (mist lib/procs.cpp lib/rijndael.cpp lib/rtmpchunks.cpp + lib/rtp_fec.cpp lib/rtp.cpp lib/sdp.cpp + lib/sdp_media.cpp lib/shared_memory.cpp lib/socket.cpp + lib/srtp.cpp lib/stream.cpp + lib/stun.cpp lib/theora.cpp lib/timing.cpp lib/tinythread.cpp @@ -253,7 +265,7 @@ target_link_libraries(mist ${LIBRT} ) if (NOT DEFINED NOSSL ) - target_link_libraries(mist mbedtls mbedx509 mbedcrypto) + target_link_libraries(mist mbedtls mbedx509 mbedcrypto srtp2) endif() install( FILES ${libHeaders} @@ -460,6 +472,7 @@ makeOutput(EBML ebml) makeOutput(Push push)#LTS makeOutput(RTSP rtsp)#LTS makeOutput(WAV wav)#LTS +makeOutput(WebRTC webrtc http)#LTS if (NOT DEFINED NOSSL ) makeOutput(HTTPS https)#LTS endif() diff --git a/lib/certificate.cpp b/lib/certificate.cpp new file mode 100644 index 00000000..01b7e801 --- /dev/null +++ b/lib/certificate.cpp @@ -0,0 +1,240 @@ +#include "certificate.h" +#include "defines.h" + +Certificate::Certificate() + :rsa_ctx(NULL) +{ + memset((void*)&cert, 0x00, sizeof(cert)); + memset((void*)&key, 0x00, sizeof(key)); +} + +int Certificate::init(const 
std::string &countryName, + const std::string &organization, + const std::string& commonName) +{ + + mbedtls_ctr_drbg_context rand_ctx = {}; + mbedtls_entropy_context entropy_ctx = {}; + mbedtls_x509write_cert write_cert = {}; + + const char* personalisation = "mbedtls-self-signed-key"; + std::string subject_name = "C=" +countryName +",O=" +organization +",CN=" +commonName; + time_t time_from = { 0 }; + time_t time_to = { 0 }; + char time_from_str[20] = { 0 }; + char time_to_str[20] = { 0 }; + mbedtls_mpi serial_mpi = { 0 }; + char serial_hex[17] = { 0 }; + uint64_t serial_num = 0; + uint8_t* serial_ptr = (uint8_t*)&serial_num; + int r = 0; + int i = 0; + uint8_t buf[4096] = { 0 }; + + // validate + if (countryName.empty()) { + FAIL_MSG("Given `countryName`, C=, is empty."); + r = -1; + goto error; + } + if (organization.empty()) { + FAIL_MSG("Given `organization`, O=, is empty."); + r = -2; + goto error; + } + if (commonName.empty()) { + FAIL_MSG("Given `commonName`, CN=, is empty."); + r = -3; + goto error; + } + + // initialize random number generator + mbedtls_ctr_drbg_init(&rand_ctx); + mbedtls_entropy_init(&entropy_ctx); + r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx, (const unsigned char*)personalisation, strlen(personalisation)); + if (0 != r) { + FAIL_MSG("Failed to initialize and seed the entropy context."); + r = -10; + goto error; + } + + // initialize the public key context + mbedtls_pk_init(&key); + r = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); + if (0 != r) { + FAIL_MSG("Faild to initialize the PK context."); + r = -20; + goto error; + } + + rsa_ctx = mbedtls_pk_rsa(key); + if (NULL == rsa_ctx) { + FAIL_MSG("Failed to get the RSA context from from the public key context (key)."); + r = -30; + goto error; + } + + r = mbedtls_rsa_gen_key(rsa_ctx, mbedtls_ctr_drbg_random, &rand_ctx, 2048, 65537); + if (0 != r) { + FAIL_MSG("Failed to generate a private key."); + r = -40; + goto error; + } + + // calc the valid from and until time. + time_from = time(NULL); + time_from = (time_from < 1000000000) ? 
1000000000 : time_from; + time_to = time_from + (60 * 60 * 24 * 365); // valid for a year + + if (time_to < time_from) { + time_to = INT_MAX; + } + + r = strftime(time_from_str, sizeof(time_from_str), "%Y%m%d%H%M%S", gmtime(&time_from)); + if (0 == r) { + FAIL_MSG("Failed to generate the valid-from time string."); + r = -50; + goto error; + } + + r = strftime(time_to_str, sizeof(time_to_str), "%Y%m%d%H%M%S", gmtime(&time_to)); + if (0 == r) { + FAIL_MSG("Failed to generate the valid-to time string."); + r = -60; + goto error; + } + + r = mbedtls_ctr_drbg_random((void*)&rand_ctx, (uint8_t*)&serial_num, sizeof(serial_num)); + if (0 != r) { + FAIL_MSG("Failed to generate a random u64."); + r = -70; + goto error; + } + + for (i = 0; i < 8; ++i) { + sprintf(serial_hex + (i * 2), "%02x", serial_ptr[i]); + } + + // start creating the certificate + mbedtls_x509write_crt_init(&write_cert); + mbedtls_x509write_crt_set_md_alg(&write_cert, MBEDTLS_MD_SHA256); + mbedtls_x509write_crt_set_issuer_key(&write_cert, &key); + mbedtls_x509write_crt_set_subject_key(&write_cert, &key); + + r = mbedtls_x509write_crt_set_subject_name(&write_cert, subject_name.c_str()); + if (0 != r) { + FAIL_MSG("Failed to set the subject name."); + r = -80; + goto error; + } + + r = mbedtls_x509write_crt_set_issuer_name(&write_cert, subject_name.c_str()); + if (0 != r) { + FAIL_MSG("Failed to set the issuer name."); + r = -90; + goto error; + } + + r = mbedtls_x509write_crt_set_validity(&write_cert, time_from_str, time_to_str); + if (0 != r) { + FAIL_MSG("Failed to set the x509 validity string."); + r = -100; + goto error; + } + + r = mbedtls_x509write_crt_set_basic_constraints(&write_cert, 0, -1); + if (0 != r) { + FAIL_MSG("Failed ot set the basic constraints for the certificate."); + r = -110; + goto error; + } + + r = mbedtls_x509write_crt_set_subject_key_identifier(&write_cert); + if (0 != r) { + FAIL_MSG("Failed to set the subjectKeyIdentifier."); + r = -120; + goto error; + } + + r = mbedtls_x509write_crt_set_authority_key_identifier(&write_cert); + if (0 != r) { + FAIL_MSG("Failed to set the authorityKeyIdentifier."); + r = -130; + goto error; + } + + // set certificate serial; mpi is used to perform i/o + mbedtls_mpi_init(&serial_mpi); + mbedtls_mpi_read_string(&serial_mpi, 16, serial_hex); + r = mbedtls_x509write_crt_set_serial(&write_cert, &serial_mpi); + if (0 != r) { + FAIL_MSG("Failed to set the certificate serial."); + r = -140; + goto error; + } + + // write the certificate into a PEM structure + r = mbedtls_x509write_crt_pem(&write_cert, buf, sizeof(buf), mbedtls_ctr_drbg_random, &rand_ctx); + if (0 != r) { + FAIL_MSG("Failed to create the PEM data from the x509 write structure."); + r = -150; + goto error; + } + + // convert the PEM data into out `mbedtls_x509_cert` member. + // len should be PEM including the string null terminating + // char. @todo there must be a way to convert the write + // struct into a `mbedtls_x509_cert` w/o calling this parse + // function. 
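+  //
+  // One possible way to skip the PEM round-trip (untested sketch, kept as a
+  // note only): write the DER directly and parse that instead, after
+  // `mbedtls_x509_crt_init(&cert)` below. Note that `mbedtls_x509write_crt_der()`
+  // writes its output at the *end* of the given buffer and returns the number
+  // of bytes written:
+  //
+  //   int len = mbedtls_x509write_crt_der(&write_cert, buf, sizeof(buf),
+  //                                       mbedtls_ctr_drbg_random, &rand_ctx);
+  //   if (len > 0) {
+  //     r = mbedtls_x509_crt_parse_der(&cert, buf + sizeof(buf) - len, len);
+  //   }
+  //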
+ mbedtls_x509_crt_init(&cert); + + r = mbedtls_x509_crt_parse(&cert, (const unsigned char*)buf, strlen((char*)buf) + 1); + if (0 != r) { + mbedtls_strerror(r, (char*)buf, sizeof(buf)); + FAIL_MSG("Failed to convert the mbedtls_x509write_crt into a mbedtls_x509_crt: %s", buf); + r = -160; + goto error; + } + + error: + + // cleanup + mbedtls_ctr_drbg_free(&rand_ctx); + mbedtls_entropy_free(&entropy_ctx); + mbedtls_x509write_crt_free(&write_cert); + mbedtls_mpi_free(&serial_mpi); + + if (r < 0) { + shutdown(); + } + + return r; +} + +int Certificate::shutdown() { + rsa_ctx = NULL; + mbedtls_pk_free(&key); + mbedtls_x509_crt_free(&cert); + return 0; +} + +std::string Certificate::getFingerprintSha256() { + + uint8_t fingerprint_raw[32] = {}; + uint8_t fingerprint_hex[128] = {}; + mbedtls_md_type_t hash_type = MBEDTLS_MD_SHA256; + + mbedtls_sha256(cert.raw.p, cert.raw.len, fingerprint_raw, 0); + + for (int i = 0; i < 32; ++i) { + sprintf((char*)(fingerprint_hex + (i * 3)), ":%02X", (int)fingerprint_raw[i]); + } + + fingerprint_hex[32 * 3] = '\0'; + + std::string result = std::string((char*)fingerprint_hex + 1, (32 * 3) - 1); + return result; +} + + + diff --git a/lib/certificate.h b/lib/certificate.h new file mode 100644 index 00000000..24a7a2fa --- /dev/null +++ b/lib/certificate.h @@ -0,0 +1,34 @@ +#pragma once +/* + + MBEDTLS BASED CERTIFICATE + ========================= + + This class can be used to generate a self-signed x509 + certificate which enables you to perform secure + communication. This certificate uses a 2048 bits RSA key. + + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class Certificate { +public: + Certificate(); + int init(const std::string &countryName, const std::string &organization, const std::string& commonName); + int shutdown(); + std::string getFingerprintSha256(); + +public: + mbedtls_x509_crt cert; + mbedtls_pk_context key; /* key context, stores private and public key. */ + mbedtls_rsa_context* rsa_ctx; /* rsa context, stored in key_ctx */ +}; diff --git a/lib/dtls_srtp_handshake.cpp b/lib/dtls_srtp_handshake.cpp new file mode 100644 index 00000000..fe35fffe --- /dev/null +++ b/lib/dtls_srtp_handshake.cpp @@ -0,0 +1,420 @@ +#include +#include "defines.h" +#include "dtls_srtp_handshake.h" + +/* Write mbedtls into a log file. */ +#define LOG_TO_FILE 0 +#if LOG_TO_FILE +# include +#endif + +/* ----------------------------------------- */ + +static void print_mbedtls_error(int r); +static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str); +static int on_mbedtls_wants_to_read(void* user, unsigned char* buf, size_t len); /* Called when mbedtls wants to read data from e.g. a socket. */ +static int on_mbedtls_wants_to_write(void* user, const unsigned char* buf, size_t len); /* Called when mbedtls wants to write data to e.g. a socket. 
*/ +static std::string mbedtls_err_to_string(int r); + +/* ----------------------------------------- */ + +DTLSSRTPHandshake::DTLSSRTPHandshake() + :write_callback(NULL) + ,cert(NULL) + ,key(NULL) +{ + memset((void*)&entropy_ctx, 0x00, sizeof(entropy_ctx)); + memset((void*)&rand_ctx, 0x00, sizeof(rand_ctx)); + memset((void*)&ssl_ctx, 0x00, sizeof(ssl_ctx)); + memset((void*)&ssl_conf, 0x00, sizeof(ssl_conf)); + memset((void*)&cookie_ctx, 0x00, sizeof(cookie_ctx)); + memset((void*)&timer_ctx, 0x00, sizeof(timer_ctx)); +} + +int DTLSSRTPHandshake::init(mbedtls_x509_crt* certificate, + mbedtls_pk_context* privateKey, + int(*writeCallback)(const uint8_t* data, int* nbytes) +) +{ + + int r = 0; + mbedtls_ssl_srtp_profile srtp_profiles[] = { + MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80, + MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32 + }; + + if (!writeCallback) { + FAIL_MSG("No writeCallack function given."); + r = -3; + goto error; + } + + if (!certificate) { + FAIL_MSG("Given certificate is null."); + r = -5; + goto error; + } + + if (!privateKey) { + FAIL_MSG("Given key is null."); + r = -10; + goto error; + } + + cert = certificate; + key = privateKey; + + /* init the contexts */ + mbedtls_entropy_init(&entropy_ctx); + mbedtls_ctr_drbg_init(&rand_ctx); + mbedtls_ssl_init(&ssl_ctx); + mbedtls_ssl_config_init(&ssl_conf); + mbedtls_ssl_cookie_init(&cookie_ctx); + + /* seed and setup the random number generator */ + r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx, (const unsigned char*)"mist-srtp", 9); + if (0 != r) { + print_mbedtls_error(r); + r = -20; + goto error; + } + + /* load defaults into our ssl_conf */ + r = mbedtls_ssl_config_defaults(&ssl_conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM, MBEDTLS_SSL_PRESET_DEFAULT); + if (0 != r) { + print_mbedtls_error(r); + r = -30; + goto error; + } + + mbedtls_ssl_conf_authmode(&ssl_conf, MBEDTLS_SSL_VERIFY_NONE); + mbedtls_ssl_conf_rng(&ssl_conf, mbedtls_ctr_drbg_random, &rand_ctx); + mbedtls_ssl_conf_dbg(&ssl_conf, print_mbedtls_debug_message, stdout); + mbedtls_ssl_conf_ca_chain(&ssl_conf, cert, NULL); + mbedtls_ssl_conf_cert_profile(&ssl_conf, &mbedtls_x509_crt_profile_default); + mbedtls_debug_set_threshold(10); + + /* enable SRTP */ + r = mbedtls_ssl_conf_dtls_srtp_protection_profiles(&ssl_conf, srtp_profiles, sizeof(srtp_profiles) / sizeof(srtp_profiles[0])); + if (0 != r) { + print_mbedtls_error(r); + r = -40; + goto error; + } + + /* cert certificate chain + key, so we can verify the client-hello signed data */ + r = mbedtls_ssl_conf_own_cert(&ssl_conf, cert, key); + if (0 != r) { + print_mbedtls_error(r); + r = -50; + goto error; + } + + /* cookie setup (e.g. to prevent ddos). */ + r = mbedtls_ssl_cookie_setup(&cookie_ctx, mbedtls_ctr_drbg_random, &rand_ctx); + if (0 != r) { + print_mbedtls_error(r); + r = -60; + goto error; + } + + /* register callbacks for dtls cookies (server only). */ + mbedtls_ssl_conf_dtls_cookies(&ssl_conf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &cookie_ctx); + + /* setup the ssl context for use. note that ssl_conf will be referenced internall by the context and therefore should be kept around. 
*/ + r = mbedtls_ssl_setup(&ssl_ctx, &ssl_conf); + if (0 != r) { + print_mbedtls_error(r); + r = -70; + goto error; + } + + /* set bio handlers */ + mbedtls_ssl_set_bio(&ssl_ctx, (void*)this, on_mbedtls_wants_to_write, on_mbedtls_wants_to_read, NULL); + + /* set temp id, just adds some exta randomness */ + { + std::string remote_id = "mist"; + r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char*)remote_id.c_str(), remote_id.size()); + if (0 != r) { + print_mbedtls_error(r); + r = -80; + goto error; + } + } + + /* set timer callbacks */ + mbedtls_ssl_set_timer_cb(&ssl_ctx, &timer_ctx, mbedtls_timing_set_delay, mbedtls_timing_get_delay); + + write_callback = writeCallback; + + error: + + if (r < 0) { + shutdown(); + } + + return r; +} + +int DTLSSRTPHandshake::shutdown() { + + /* cleanup the refs from the settings. */ + cert = NULL; + key = NULL; + buffer.clear(); + cipher.clear(); + remote_key.clear(); + remote_salt.clear(); + local_key.clear(); + local_salt.clear(); + + /* free our contexts; we do not free the `settings.cert` and `settings.key` as they are owned by the user of this class. */ + mbedtls_entropy_free(&entropy_ctx); + mbedtls_ctr_drbg_free(&rand_ctx); + mbedtls_ssl_free(&ssl_ctx); + mbedtls_ssl_config_free(&ssl_conf); + mbedtls_ssl_cookie_free(&cookie_ctx); + + return 0; +} + +/* ----------------------------------------- */ + +int DTLSSRTPHandshake::parse(const uint8_t* data, size_t nbytes) { + + if (NULL == data) { + ERROR_MSG("Given `data` is NULL."); + return -1; + } + + if (0 == nbytes) { + ERROR_MSG("Given nbytes is 0."); + return -2; + } + + if (MBEDTLS_SSL_HANDSHAKE_OVER == ssl_ctx.state) { + ERROR_MSG("Already finished the handshake."); + return -3; + } + + /* copy incoming data into a temporary buffer which is read via our `bio` read function. */ + int r = 0; + std::copy(data, data + nbytes, std::back_inserter(buffer)); + + do { + + r = mbedtls_ssl_handshake(&ssl_ctx); + + switch (r) { + /* 0 = handshake done. */ + case 0: { + if (0 != extractKeyingMaterial()) { + ERROR_MSG("Failed to extract keying material after handshake was done."); + return -2; + } + return 0; + } + /* see the dtls server example; this is used to prevent certain attacks (ddos) */ + case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: { + if (0 != resetSession()) { + ERROR_MSG("Failed to reset the session which is necessary when we need to verify the HELLO."); + return -3; + } + break; + } + case MBEDTLS_ERR_SSL_WANT_READ: { + DONTEVEN_MSG("mbedtls wants a bit more data before it can continue parsing the DTLS handshake."); + break; + } + default: { + ERROR_MSG("A serious mbedtls error occured."); + print_mbedtls_error(r); + return -2; + } + } + } + while (MBEDTLS_ERR_SSL_WANT_WRITE == r); + + return 0; +} + +/* ----------------------------------------- */ + +int DTLSSRTPHandshake::resetSession() { + + std::string remote_id = "mist"; /* @todo for now we hardcoded this... */ + int r = 0; + + r = mbedtls_ssl_session_reset(&ssl_ctx); + if (0 != r) { + print_mbedtls_error(r); + return -1; + } + + r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char*)remote_id.c_str(), remote_id.size()); + if (0 != r) { + print_mbedtls_error(r); + return -2; + } + + buffer.clear(); + + return 0; +} + +/* + master key is 128 bits => 16 bytes. 
+ master salt is 112 bits => 14 bytes +*/ +int DTLSSRTPHandshake::extractKeyingMaterial() { + + int r = 0; + uint8_t keying_material[MBEDTLS_DTLS_SRTP_MAX_KEY_MATERIAL_LENGTH] = {}; + size_t keying_material_len = sizeof(keying_material); + + r = mbedtls_ssl_get_dtls_srtp_key_material(&ssl_ctx, keying_material, &keying_material_len); + if (0 != r) { + print_mbedtls_error(r); + return -1; + } + + /* @todo following code is for server mode only */ + mbedtls_ssl_srtp_profile srtp_profile = mbedtls_ssl_get_dtls_srtp_protection_profile(&ssl_ctx); + switch (srtp_profile) { + case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80: { + cipher = "SRTP_AES128_CM_SHA1_80"; + break; + } + case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32: { + cipher = "SRTP_AES128_CM_SHA1_32"; + break; + } + default: { + ERROR_MSG("Unhandled SRTP profile, cannot extract keying material."); + return -6; + } + } + + remote_key.assign((char*)(&keying_material[0]) + 0, 16); + local_key.assign((char*)(&keying_material[0]) + 16, 16); + remote_salt.assign((char*)(&keying_material[0]) + 32, 14); + local_salt.assign((char*)(&keying_material[0]) + 46, 14); + + DONTEVEN_MSG("Extracted the DTLS SRTP keying material with cipher %s.", cipher.c_str()); + DONTEVEN_MSG("Remote DTLS SRTP key size is %zu.", remote_key.size()); + DONTEVEN_MSG("Remote DTLS SRTP salt size is %zu.", remote_salt.size()); + DONTEVEN_MSG("Local DTLS SRTP key size is %zu.", local_key.size()); + DONTEVEN_MSG("Local DTLS SRTP salt size is %zu.", local_salt.size()); + + return 0; +} + +/* ----------------------------------------- */ + +/* + + This function is called by mbedtls whenever it wants to read + some data. The documentation states the following: "For DTLS, + you need to provide either a non-NULL f_recv_timeout + callback, or a f_recv that doesn't block." As this + implementation is completely decoupled from any I/O and uses + a "push" model instead of a "pull" model we have to copy new + input bytes into a temporary buffer (see parse), but we act + as if we were using a non-blocking socket, which means: + + - we return MBETLS_ERR_SSL_WANT_READ when there is no data left to read + - when there is data in our temporary buffer, we read from that + +*/ +static int on_mbedtls_wants_to_read(void* user, unsigned char* buf, size_t len) { + + DTLSSRTPHandshake* hs = static_cast(user); + if (NULL == hs) { + ERROR_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake."); + return -1; + } + + /* figure out how much we can read. */ + if (hs->buffer.size() == 0) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + + size_t nbytes = hs->buffer.size(); + if (nbytes > len) { + nbytes = len; + } + + /* "read" into the given buffer. 
*/ + memcpy(buf, &hs->buffer[0], nbytes); + hs->buffer.erase(hs->buffer.begin(), hs->buffer.begin() + nbytes); + + return (int)nbytes; +} + +static int on_mbedtls_wants_to_write(void* user, const unsigned char* buf, size_t len) { + + DTLSSRTPHandshake* hs = static_cast(user); + if (!hs) { + FAIL_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake."); + return -1; + } + + if (!hs->write_callback) { + FAIL_MSG("The `write_callback` member is NULL."); + return -2; + } + + int nwritten = (int)len; + if (0 != hs->write_callback(buf, &nwritten)) { + FAIL_MSG("Failed to write some DTLS handshake data."); + return -3; + } + + if (nwritten != (int)len) { + FAIL_MSG("The DTLS-SRTP handshake listener MUST write all the data."); + return -4; + } + + return nwritten; +} + +/* ----------------------------------------- */ + +static void print_mbedtls_error(int r) { + char buf[1024] = {}; + mbedtls_strerror(r, buf, sizeof(buf)); + ERROR_MSG("mbedtls error: %s", buf); +} + +static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str) { + DONTEVEN_MSG("%s:%04d: %.*s", file, line, strlen(str) - 1, str); + +#if LOG_TO_FILE + static std::ofstream ofs; + if (!ofs.is_open()) { + ofs.open("mbedtls.log", std::ios::out); + } + if (!ofs.is_open()) { + return; + } + ofs << str; + ofs.flush(); +#endif + +} + +static std::string mbedtls_err_to_string(int r) { + switch (r) { + case MBEDTLS_ERR_SSL_WANT_READ: { return "MBEDTLS_ERR_SSL_WANT_READ"; } + case MBEDTLS_ERR_SSL_WANT_WRITE: { return "MBEDTLS_ERR_SSL_WANT_WRITE"; } + default: { + print_mbedtls_error(r); + return "UNKNOWN"; + } + } +} + +/* ---------------------------------------- */ + + diff --git a/lib/dtls_srtp_handshake.h b/lib/dtls_srtp_handshake.h new file mode 100644 index 00000000..dfb319a9 --- /dev/null +++ b/lib/dtls_srtp_handshake.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* ----------------------------------------- */ + +class DTLSSRTPHandshake { +public: + DTLSSRTPHandshake(); + int init(mbedtls_x509_crt* certificate, mbedtls_pk_context* privateKey, int(*writeCallback)(const uint8_t* data, int* nbytes)); // writeCallback should return 0 on succes < 0 on error. nbytes holds the number of bytes to be sent and needs to be set to the number of bytes actually sent. + int shutdown(); + int parse(const uint8_t* data, size_t nbytes); + bool hasKeyingMaterial(); + +private: + int extractKeyingMaterial(); + int resetSession(); + +private: + mbedtls_x509_crt* cert; /* Certificate, we do not own the key. Make sure it's kept alive during the livetime of this class instance. */ + mbedtls_pk_context* key; /* Private key, we do not own the key. Make sure it's kept alive during the livetime of this class instance. */ + mbedtls_entropy_context entropy_ctx; + mbedtls_ctr_drbg_context rand_ctx; + mbedtls_ssl_context ssl_ctx; + mbedtls_ssl_config ssl_conf; + mbedtls_ssl_cookie_ctx cookie_ctx; + mbedtls_timing_delay_context timer_ctx; + +public: + int (*write_callback)(const uint8_t* data, int* nbytes); + std::deque buffer; /* Accessed from BIO callbback. We copy the bytes you pass into `parse()` into this temporary buffer which is read by a trigger to `mbedlts_ssl_handshake()`. */ + std::string cipher; /* selected SRTP cipher. 
*/ + std::string remote_key; + std::string remote_salt; + std::string local_key; + std::string local_salt; +}; + +/* ----------------------------------------- */ + +inline bool DTLSSRTPHandshake::hasKeyingMaterial() { + return (0 != remote_key.size() + && 0 != remote_salt.size() + && 0 != local_key.size() + && 0 != local_salt.size()); +} + +/* ----------------------------------------- */ diff --git a/lib/rtp_fec.cpp b/lib/rtp_fec.cpp new file mode 100644 index 00000000..893b0204 --- /dev/null +++ b/lib/rtp_fec.cpp @@ -0,0 +1,569 @@ +#include "defines.h" +#include "rtp_fec.h" +#include "rtp.h" + +namespace RTP{ + /// Based on the `block PT` value, we can either find the + /// contents of the codec payload (e.g. H264, VP8) or a ULPFEC header + /// (RFC 5109). The structure of the ULPFEC data is as follows. + /// + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | RTP Header (12 octets or more) | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | FEC Header (10 octets) | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | FEC Level 0 Header | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | FEC Level 0 Payload | + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | FEC Level 1 Header | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | FEC Level 1 Payload | + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Cont. | + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// + /// FEC HEADER: + /// + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |E|L|P|X| CC |M| PT recovery | SN base | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | TS recovery | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | length recovery | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// + /// + /// FEC LEVEL HEADER + /// + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Protection Length | mask | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | mask cont. (present only when L = 1) | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// + PacketFEC::PacketFEC() { + } + + PacketFEC::~PacketFEC() { + receivedSeqNums.clear(); + coveredSeqNums.clear(); + } + + bool PacketFEC::initWithREDPacket(const char* data, size_t nbytes) { + + if (!data) { + FAIL_MSG("Given fecData pointer is NULL."); + return false; + } + + if (nbytes < 23) { + FAIL_MSG("Given fecData is too small. Should be at least: 12 (RTP) + 1 (RED) + 10 (FEC) 23 bytes."); + return false; + } + + if (coveredSeqNums.size() != 0) { + FAIL_MSG("It seems we're already initialized; coveredSeqNums already set."); + return false; + } + + if (receivedSeqNums.size() != 0) { + FAIL_MSG("It seems we're already initialized; receivedSeqNums is not empty."); + return false; + } + + // Decode RED header. + RTP::Packet rtpPkt(data, nbytes); + uint8_t* redHeader = (uint8_t*)(data + rtpPkt.getHsize()); + uint8_t moreBlocks = redHeader[0] & 0x80; + if (moreBlocks == 1) { + FAIL_MSG("RED header indicates there are multiple blocks. 
Haven't seen this before (@todo implement, exiting now)."); + // \todo do not EXIT! + return false; + } + + // Copy the data, starting at the FEC header (skip RTP + RED header) + size_t numHeaderBytes = rtpPkt.getHsize() + 1; + if (numHeaderBytes > nbytes) { + FAIL_MSG("Invalid FEC packet; too small to contain FEC data."); + return false; + } + + fecPacketData.assign(NULL, 0); + fecPacketData.append(data + numHeaderBytes, nbytes - numHeaderBytes); + + // Extract the sequence numbers this packet protects. + if (!extractCoveringSequenceNumbers()) { + FAIL_MSG("Failed to extract the protected sequence numbers for this FEC."); + // @todo we probably want to reset our set. + return false; + } + + return true; + } + + uint8_t PacketFEC::getExtensionFlag() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get extension-flag from the FEC header; fecPacketData member is not set. Not initialized?"); + return 0; + } + + return ((fecPacketData[0] & 0x80) >> 7); + } + + uint8_t PacketFEC::getLongMaskFlag() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get the long-mask-flag from the FEC header. fecPacketData member is not set. Not initialized?"); + return 0; + } + + return ((fecPacketData[0] & 0x40) >> 6); + } + + // Returns 0 (error), 2 or 6, wich are the valid sizes of the mask. + uint8_t PacketFEC::getNumBytesUsedForMask() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get the number of bytes used by the mask. fecPacketData member is not set. Not initialized?"); + return 0; + } + + if (getLongMaskFlag() == 0) { + return 2; + } + + return 6; + } + + uint16_t PacketFEC::getSequenceBaseNumber() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get the sequence base number. fecPacketData member is not set. Not initialized?"); + return 0; + } + + return (uint16_t) (fecPacketData[2] << 8) | fecPacketData[3]; + } + + char* PacketFEC::getFECHeader() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get fec header. fecPacketData member is not set. Not initialized?"); + } + + return fecPacketData; + } + + char* PacketFEC::getLevel0Header() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get the level 0 header. fecPacketData member is not set. Not initialized?"); + return NULL; + } + + return (char*)(fecPacketData + 10); + } + + char* PacketFEC::getLevel0Payload() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get the level 0 payload. fecPacketData member is not set. Not initialized?"); + return NULL; + } + + // 10 bytes for FEC header + // 2 bytes for `Protection Length` + // 2 or 6 bytes for `mask`. + return (char*)(fecPacketData + 10 + 2 + getNumBytesUsedForMask()); + } + + uint16_t PacketFEC::getLevel0ProtectionLength() { + + if (fecPacketData.size() == 0) { + FAIL_MSG("Cannot get the level 0 protection length. fecPacketData member is not set. Not initialized?"); + return 0; + } + + char* level0Header = getLevel0Header(); + if (!level0Header) { + FAIL_MSG("Failed to get the level 0 header, cannot get protection length."); + return 0; + } + + uint16_t protectionLength = (level0Header[0] << 8) | level0Header[1]; + return protectionLength; + } + + uint16_t PacketFEC::getLengthRecovery() { + + char* fecHeader = getFECHeader(); + if (!fecHeader) { + FAIL_MSG("Cannot get the FEC header which we need to get the `length recovery` field. Not initialized?"); + return 0; + } + + uint16_t lengthRecovery = (fecHeader[8] << 8) | fecHeader[9]; + return lengthRecovery; + } + + // Based on InsertFecPacket of forward_error_correction.cc from + // Chromium. 
(used as reference). The `mask` from the
+  // FEC-level-header can be 2 or 6 bytes long. Whenever a bit is
+  // set to 1 it means that we have to calculate the sequence
+  // number for that bit. To calculate the sequence number we
+  // start with the `SN base` value (base sequence number) and
+  // use the bit offset to increment the SN-base value. E.g.
+  // when it's bit 4 and SN-base is 230, it means that this FEC
+  // packet protects the media packet with sequence number
+  // 230. We have to start counting the bit numbers from the
+  // most-significant-bit (e.g. 1 << 7).
+  bool PacketFEC::extractCoveringSequenceNumbers() {
+
+    if (coveredSeqNums.size() != 0) {
+      FAIL_MSG("Cannot extract protected sequence numbers; looks like we already did that.");
+      return false;
+    }
+
+    size_t maskNumBytes = getNumBytesUsedForMask();
+    if (maskNumBytes != 2 && maskNumBytes != 6) {
+      FAIL_MSG("Invalid mask size (%zu), cannot extract sequence numbers.", maskNumBytes);
+      return false;
+    }
+
+    char* maskPtr = getLevel0Header();
+    if (!maskPtr) {
+      FAIL_MSG("Failed to get the level-0 header ptr. Cannot extract protected sequence numbers.");
+      return false;
+    }
+
+    uint16_t seqNumBase = getSequenceBaseNumber();
+    if (seqNumBase == 0) {
+      WARN_MSG("Base sequence number is 0; it's possible but unlikely.");
+    }
+
+    // Skip the `Protection Length`
+    maskPtr = maskPtr + 2;
+
+    for (uint16_t byteDX = 0; byteDX < maskNumBytes; ++byteDX) {
+      uint8_t maskByte = maskPtr[byteDX];
+      for (uint16_t bitDX = 0; bitDX < 8; ++bitDX) {
+        if (maskByte & (1 << (7 - bitDX))) {
+          uint16_t seqNum = seqNumBase + (byteDX << 3) + bitDX;
+          coveredSeqNums.insert(seqNum);
+        }
+      }
+    }
+
+    return true;
+  }
+
+  // \todo rename coversSequenceNumber
+  bool PacketFEC::coversSequenceNumber(uint16_t sn) {
+    return (coveredSeqNums.count(sn) == 0) ? false : true;
+  }
+
+  void PacketFEC::addReceivedSequenceNumber(uint16_t sn) {
+    if (false == coversSequenceNumber(sn)) {
+      FAIL_MSG("Trying to add a received sequence number that this instance is not handling (%u).", sn);
+      return;
+    }
+
+    receivedSeqNums.insert(sn);
+  }
+
+  /// This function can be called to recover a missing packet. A
+  /// FEC packet is received with a list of media packets it
+  /// might be able to recover; this PacketFEC is received after
+  /// we should have received the media packets it's protecting.
+  ///
+  /// Here we first fill a list with the received sequence
+  /// numbers that we're protecting; when we're missing one
+  /// packet this function will try to recover it.
+  ///
+  /// The `receivedMediaPackets` is the history of media packets
+  /// that you received and keep in memory. These are used
+  /// when XORing when we reconstruct a packet.
+  void PacketFEC::tryToRecoverMissingPacket(std::map<uint16_t, Packet>& receivedMediaPackets, Packet& reconstructedPacket) {
+
+    // Mark all the media packets that we protect and which have
+    // been received as "received" in our internal list.
+    std::set<uint16_t>::iterator protIt = coveredSeqNums.begin();
+    while (protIt != coveredSeqNums.end()) {
+      if (receivedMediaPackets.count(*protIt) == 1) {
+        addReceivedSequenceNumber(*protIt);
+      }
+      protIt++;
+    }
+
+    // We have received all media packets that we could recover;
+    // so there is no need for this FEC packet.
+    // @todo Jaron shall we reuse allocs/PacketFECs?
+    if (receivedSeqNums.size() == coveredSeqNums.size()) {
+      return;
+    }
+
+    if (coveredSeqNums.size() != (receivedSeqNums.size() + 1)) {
+      // missing more than 1 packet. we can only recover when
+      // one packet is lost.
+ return; + } + + // Find missing sequence number. + uint16_t missingSeqNum = 0; + protIt = coveredSeqNums.begin(); + while (protIt != coveredSeqNums.end()) { + if (receivedSeqNums.count(*protIt) == 0) { + missingSeqNum = *protIt; + break; + } + ++protIt; + } + if (!coversSequenceNumber(missingSeqNum)) { + WARN_MSG("We cannot recover %u.", missingSeqNum); + return; + } + + // Copy FEC into new RTP-header + char* fecHeader = getFECHeader(); + if (!fecHeader) { + FAIL_MSG("Failed to get the fec header. Cannot recover."); + return; + } + recoverData.assign(NULL, 0); + recoverData.append(fecHeader, 12); + + // Copy FEC into new RTP-payload + char* level0Payload = getLevel0Payload(); + if (!level0Payload) { + FAIL_MSG("Failed to get the level-0 payload data (XOR'd media data from FEC packet)."); + return; + } + uint16_t level0ProtLen = getLevel0ProtectionLength(); + if (level0ProtLen == 0) { + FAIL_MSG("Failed to get the level-0 protection length."); + return; + } + recoverData.append(level0Payload, level0ProtLen); + + uint8_t recoverLength[2] = { fecHeader[8], fecHeader[9] }; + + // XOR headers + protIt = coveredSeqNums.begin(); + while (protIt != coveredSeqNums.end()) { + + uint16_t seqNum = *protIt; + if (seqNum == missingSeqNum) { + ++protIt; + continue; + } + + Packet& mediaPacket = receivedMediaPackets[seqNum]; + char* mediaData = mediaPacket.ptr(); + uint16_t mediaSize = mediaPacket.getPayloadSize(); + uint8_t* mediaSizePtr = (uint8_t*)&mediaSize; + + WARN_MSG(" => XOR header with %u, size: %u.", seqNum, mediaSize); + + // V, P, X, CC, M, PT + recoverData[0] ^= mediaData[0]; + recoverData[1] ^= mediaData[1]; + + // Timestamp + recoverData[4] ^= mediaData[4]; + recoverData[5] ^= mediaData[5]; + recoverData[6] ^= mediaData[6]; + recoverData[7] ^= mediaData[7]; + + // Length of recovered media packet + recoverLength[0] ^= mediaSizePtr[1]; + recoverLength[1] ^= mediaSizePtr[0]; + + ++protIt; + } + + uint16_t recoverPayloadSize = ntohs(*(uint16_t*)recoverLength); + + // XOR payloads + protIt = coveredSeqNums.begin(); + while (protIt != coveredSeqNums.end()) { + uint16_t seqNum = *protIt; + if (seqNum == missingSeqNum) { + ++protIt; + continue; + } + Packet& mediaPacket = receivedMediaPackets[seqNum]; + char* mediaData = mediaPacket.ptr() + mediaPacket.getHsize(); + for (size_t i = 0; i < recoverPayloadSize; ++i) { + recoverData[12 + i] ^= mediaData[i]; + } + ++protIt; + } + + // And setup the reconstructed packet. + reconstructedPacket = Packet(recoverData, recoverPayloadSize); + reconstructedPacket.setSequence(missingSeqNum); + // @todo check what other header fields we need to fix. + } + + + void FECSorter::addPacket(const Packet &pack){ + if (tmpVideoLossPrevention & SDP_LOSS_PREVENTION_ULPFEC) { + packetHistory[pack.getSequence()] = pack; + while (packetHistory.begin()->first < pack.getSequence() - 500){ + packetHistory.erase(packetHistory.begin()); + } + } + Sorter::addPacket(pack); + } + + /// This function will handle RED packets that may be used to + /// encapsulate ULPFEC or simply the codec payload (e.g. H264, + /// VP8). This function is created to handle FEC with + /// WebRTC. When we want to use FEC with WebRTC we have to add + /// both the `a=rtpmap: ulpfec/90000` and + /// `a=rtpmap red/90000` lines to the SDP. FEC is + /// always used together with RED (RFC 1298). It turns out + /// that with WebRTC the RED only adds one byte after the RTP + /// header (only the `F` and `block PT`, see below)`. The + /// `block PT` is the payload type of the data that + /// follows. 
This would be `` for FEC data. Though + /// these RED packets may contain FEC or just the media: + /// H264/VP8. + /// + /// RED HEADER: + /// + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |F| block PT | timestamp offset | block length | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// + void FECSorter::addREDPacket(char* dat, + unsigned int len, + uint8_t codecPayloadType, + uint8_t REDPayloadType, + uint8_t ULPFECPayloadType) + { + + RTP::Packet pkt(dat, len); + if (pkt.getPayloadType() != REDPayloadType) { + FAIL_MSG("Requested to add a RED packet, but it has an invalid payload type."); + return; + } + + // Check if the `F` flag is set. Chromium will always set + // this to 0 (at time of writing, check: https://goo.gl/y1eJ6k + uint8_t* REDHeader = (uint8_t*)(dat + pkt.getHsize()); + uint8_t moreBlocksAvailable = REDHeader[0] & 0x80; + if (moreBlocksAvailable == 1) { + FAIL_MSG("Not yet received a RED packet that had it's F bit set; @todo implement."); + exit(EXIT_FAILURE); + return; + } + + // Extract the `block PT` field which can be the media-pt, + // fec-pt. When it's just media that follows, we move all + // data one byte up and reconstruct a normal media packet. + uint8_t blockPayloadType = REDHeader[0] & 0x7F; + if (blockPayloadType == codecPayloadType) { + memmove(dat + pkt.getHsize(), dat + pkt.getHsize() + 1, len - pkt.getHsize() - 1); + dat[1] &= 0x80; + dat[1] |= codecPayloadType; + RTP::Packet mediaPacket((const char*)dat, len -1); + addPacket(mediaPacket); + return; + } + + // When the payloadType equals our ULP/FEC payload type, we + // received a REC packet (RFC 5109) that contains FEC data + // and a list of sequence number that it covers and can + // reconstruct. + // + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + // + // \todo Jaron, I'm now just generating a `PacketFEC` on the heap + // and we're not managing destruction anywhere atm; I guess + // re-use or destruction needs to be part of the algo that + // is going to deal with FEC. + // + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + if (blockPayloadType == ULPFECPayloadType) { + WARN_MSG(" => got fec packet: %u", pkt.getSequence()); + PacketFEC* fec = new PacketFEC(); + if (!fec->initWithREDPacket(dat, len)) { + delete fec; + fec = NULL; + FAIL_MSG("Failed to initialize a `PacketFEC`"); + } + fecPackets.push_back(fec); + + Packet recreatedPacket; + fec->tryToRecoverMissingPacket(packetHistory, recreatedPacket); + if (recreatedPacket.ptr() != NULL) { + char* pl = recreatedPacket.getPayload(); + WARN_MSG(" => reconstructed %u, %02X %02X %02X %02X | %02X %02X %02X %02X", recreatedPacket.getSequence(), pl[0], pl[1], pl[2], pl[3], pl[4], pl[5], pl[6], pl[7]); + addPacket(recreatedPacket); + } + return; + } + + FAIL_MSG("Unhandled RED block payload type %u. Check the answer SDP.", blockPayloadType); + } + + /// Each FEC packet is capable of recovering a limited amount + /// of media packets. Some experimentation showed that most + /// often one FEC is used to protect somewhere between 2-10 + /// media packets. Each FEC packet has a list of sequence + /// number that it can recover when all other media packets + /// have been received except the one that we want to + /// recover. This function returns the FEC packet might be able + /// to recover the given sequence number. 
+ PacketFEC* FECSorter::getFECPacketWhichCoversSequenceNumber(uint16_t sn) { + size_t nfecs = fecPackets.size(); + for (size_t i = 0; i < nfecs; ++i) { + PacketFEC* fec = fecPackets[i]; + if (fec->coversSequenceNumber(sn)) { + return fec; + } + } + return NULL; + } + + void FECPacket::sendRTCP_RR(RTP::FECSorter &sorter, uint32_t mySSRC, uint32_t theirSSRC, void* userData, void callBack(void* userData, const char* payload, uint32_t nbytes)) { + char *rtcpData = (char *)malloc(32); + if (!rtcpData){ + FAIL_MSG("Could not allocate 32 bytes. Something is seriously messed up."); + return; + } + if (!(sorter.lostCurrent + sorter.packCurrent)){sorter.packCurrent++;} + rtcpData[0] = 0x81; // version 2, no padding, one receiver report + rtcpData[1] = 201; // receiver report + Bit::htobs(rtcpData + 2, 7); // 7 4-byte words follow the header + Bit::htobl(rtcpData + 4, mySSRC); // set receiver identifier + Bit::htobl(rtcpData + 8, theirSSRC); // set source identifier + rtcpData[12] = + (sorter.lostCurrent * 255) / + (sorter.lostCurrent + sorter.packCurrent); // fraction lost since prev RR + Bit::htob24(rtcpData + 13, sorter.lostTotal); // cumulative packets lost since start + Bit::htobl(rtcpData + 16, sorter.rtpSeq | (sorter.packTotal & + 0xFFFF0000ul)); // highest sequence received + Bit::htobl(rtcpData + 20, 0); /// \TODO jitter (diff in timestamp vs packet arrival) + Bit::htobl(rtcpData + 24, 0); /// \TODO last SR (middle 32 bits of last SR or zero) + Bit::htobl(rtcpData + 28, 0); /// \TODO delay since last SR in 2b seconds + 2b fraction + callBack(userData, rtcpData, 32); + sorter.lostCurrent = 0; + sorter.packCurrent = 0; + free(rtcpData); + } + +} + diff --git a/lib/rtp_fec.h b/lib/rtp_fec.h new file mode 100644 index 00000000..5d938f49 --- /dev/null +++ b/lib/rtp_fec.h @@ -0,0 +1,100 @@ +#pragma once +#include "rtp.h" +#include "sdp_media.h" +#include "util.h" +#include + +namespace RTP{ + + /// Util class that can be used to retrieve information from a + /// FEC packet. A FEC packet is contains recovery data. This + /// data can be used to reconstruct a media packet. This class + /// was created and tested for the WebRTC implementation where + /// each FEC packet is encapsulated by a RED packet (RFC 1298). + /// A RED packet may contain ordinary payload data -or- FEC + /// data (RFC 5109). We assume that the data given into + /// `initWithREDPacket()` contains FEC data and you did a bit + /// of parsing to figure this out: by checking if the `block + /// PT` from the RED header is the ULPFEC payload type; if so + /// this PacketFEC class can be used. + class PacketFEC{ + public: + PacketFEC(); + ~PacketFEC(); + bool initWithREDPacket( + const char *data, + size_t nbytes); /// Initialize using the given data. `data` must point to the first byte of + /// the RTP packet which contains the RED and FEC headers and data. + uint8_t getExtensionFlag(); ///< From fec header: should be 0, see + ///< https://tools.ietf.org/html/rfc5109#section-7.3. + uint8_t + getLongMaskFlag(); ///< From fec header: returns 0 when the short mask version is used (16 + ///< bits), otherwise 1 (48 bits). The mask is used to calculate what + ///< sequence numbers are protected, starting at the base sequence number. + uint16_t getSequenceBaseNumber(); ///< From fec header: get the base sequence number. The base + ///< sequence number is used together with the mask to + ///< determine what the sequence numbers of the media packets + ///< are that the fec data protects. 
uint8_t getNumBytesUsedForMask(); ///< From fec level header: a fec packet can protect up to
+                                      ///< 48 media packets. Which sequence numbers are protected
+                                      ///< is stored using a mask bit string. This returns either 2 or 6.
+    char *getLevel0Header(); ///< Get a pointer to the start of the fec-level-0 header (contains the
+                             ///< protection-length and mask)
+    char *getLevel0Payload(); ///< Get a pointer to the actual FEC data. This is the XOR'd header
+                              ///< and payload.
+    char *getFECHeader();     ///< Get a pointer to the first byte of the FEC header.
+    uint16_t getLevel0ProtectionLength(); ///< Get the length of the `getLevel0Payload()`.
+    uint16_t
+    getLengthRecovery(); ///< Get the `length recovery` value (Little Endian). This value is used
+                         ///< while XORing to recover the length of the missing media packet.
+    bool coversSequenceNumber(uint16_t sn); ///< Returns true when this `PacketFEC` instance is used
+                                            ///< to protect the given sequence number.
+    void
+    addReceivedSequenceNumber(uint16_t sn); ///< Whenever you receive a media packet (complete) call
+                                            ///< this as we need to know if enough media packets
+                                            ///< exist that are needed to recover another one.
+    void tryToRecoverMissingPacket(
+        std::map<uint16_t, Packet> &receivedMediaPackets,
+        Packet &reconstructedPacket); ///< Pass in a `std::map` indexed by sequence number of -all-
+                                      ///< the media packets that you keep as history. When this
+                                      ///< `PacketFEC` is capable of recovering a media packet it
+                                      ///< will fill the packet passed by reference.
+
+  private:
+    bool
+    extractCoveringSequenceNumbers(); ///< Used internally to fill the `coveredSeqNums` member which
+                                      ///< tells us what media packets this FEC packet protects.
+
+  public:
+    Util::ResizeablePointer fecPacketData;
+    Util::ResizeablePointer recoverData;
+    std::set<uint16_t>
+        coveredSeqNums; ///< The sequence numbers of the packets that this FEC protects.
+    std::set<uint16_t>
+        receivedSeqNums; ///< We keep track of sequence numbers that were received (at some higher
+                         ///< level). We can only recover 1 media packet and this is used to check
+                         ///< if this `PacketFEC` instance is capable of recovering anything.
+  };
+
+  class FECSorter : public Sorter{
+  public:
+    void addPacket(const Packet &pack);
+    void addREDPacket(char *dat, unsigned int len, uint8_t codecPayloadType, uint8_t REDPayloadType,
+                      uint8_t ULPFECPayloadType);
+    PacketFEC *getFECPacketWhichCoversSequenceNumber(uint16_t sn);
+    uint8_t tmpVideoLossPrevention; ///< TMP used to drop packets for FEC; see output_webrtc.cpp
+                                    ///< `handleSignalingCommandRemoteOfferForInput()`. This
+                                    ///< variable should be removed when cleaning up.
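+
+    /// Typical use (illustrative sketch based on the functions above): feed
+    /// every received RED packet into `addREDPacket()` together with the
+    /// payload types negotiated in the SDP. Media blocks are unwrapped and
+    /// passed on to `addPacket()`; ULPFEC blocks become `PacketFEC` instances
+    /// that may recover a single missing media packet from `packetHistory`:
+    ///
+    ///   RTP::FECSorter sorter;
+    ///   sorter.addREDPacket(buf, len, vp8PayloadType, redPayloadType, ulpfecPayloadType);
+    ///
+    /// (`buf`/`len` are the raw RTP packet bytes; `vp8PayloadType`,
+    /// `redPayloadType` and `ulpfecPayloadType` are hypothetical names for the
+    /// values taken from the answer SDP.)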
+  private:
+    std::map<uint16_t, Packet> packetHistory;
+    std::vector<PacketFEC *> fecPackets;
+  };
+
+  class FECPacket : public Packet{
+  public:
+    void sendRTCP_RR(RTP::FECSorter &sorter, uint32_t mySSRC, uint32_t theirSSRC, void *userData,
+                     void callBack(void *userData, const char *payload, uint32_t nbytes));
+  };
+
+}// namespace RTP
+
diff --git a/lib/sdp_media.cpp b/lib/sdp_media.cpp
new file mode 100644
index 00000000..50e20a1c
--- /dev/null
+++ b/lib/sdp_media.cpp
@@ -0,0 +1,1184 @@
+#include "defines.h"
+#include "sdp_media.h"
+#include <algorithm>
+#include <sstream>
+
+namespace SDP{
+
+  std::string codecRTP2Mist(const std::string &codec){
+    if (codec == "H265"){
+      return "HEVC";
+    }else if (codec == "H264"){
+      return "H264";
+    }else if (codec == "VP8"){
+      return "VP8";
+    }else if (codec == "AC3"){
+      return "AC3";
+    }else if (codec == "PCMA"){
+      return "ALAW";
+    }else if (codec == "PCMU"){
+      return "ULAW";
+    }else if (codec == "L8"){
+      return "PCM";
+    }else if (codec == "L16"){
+      return "PCM";
+    }else if (codec == "L20"){
+      return "PCM";
+    }else if (codec == "L24"){
+      return "PCM";
+    }else if (codec == "MPA"){
+      // can also be MP2, the data should be inspected.
+      return "MP3";
+    }else if (codec == "MPEG4-GENERIC"){
+      return "AAC";
+    }else if (codec == "OPUS"){
+      return "opus";
+    }else if (codec == "ULPFEC"){
+      return "";
+    }else if (codec == "RED"){
+      return "";
+    }
+    ERROR_MSG("%s support not implemented", codec.c_str());
+    return "";
+  }
+
+  std::string codecMist2RTP(const std::string &codec){
+    if (codec == "HEVC"){
+      return "H265";
+    }else if (codec == "H264"){
+      return "H264";
+    }else if (codec == "VP8"){
+      return "VP8";
+    }else if (codec == "AC3"){
+      return "AC3";
+    }else if (codec == "ALAW"){
+      return "PCMA";
+    }else if (codec == "ULAW"){
+      return "PCMU";
+    }else if (codec == "PCM"){
+      return "L24";
+    }else if (codec == "MP2"){
+      return "MPA";
+    }else if (codec == "MP3"){
+      return "MPA";
+    }else if (codec == "AAC"){
+      return "MPEG4-GENERIC";
+    }else if (codec == "opus"){
+      return "OPUS";
+    }else if (codec == "ULPFEC"){
+      return "";
+    }else if (codec == "RED"){
+      return "";
+    }
+    ERROR_MSG("%s support not implemented", codec.c_str());
+    BACKTRACE;
+    return "";
+  }
+
+  static std::vector<std::string> sdp_split(
+      const std::string &str, const std::string &delim,
+      bool keepEmpty); // Split a string on the given delimiter and return a vector with the parts.
+  static bool
+  sdp_extract_payload_type(const std::string &str,
+                           uint64_t &result); // Extract the payload number from a SDP line that
+                                              // starts like: `a=something:[payloadtype]`.
+  static bool sdp_get_name_value_from_varval(
+      const std::string &str, std::string &var,
+      std::string &value); // Extracts the `name` and `value` from a string like `<name>=<value>`.
+                           // The `name` will always be converted to lowercase!
+  static void sdp_get_all_name_values_from_string(
+      const std::string &str, std::map<std::string, std::string>
+                                  &result); // Extracts all the name/value pairs from a string like:
+                                            // `<name>=<value>;<name>=<value>`. The `name` part will
+                                            // always be converted to lowercase.
+  static bool
+  sdp_get_attribute_value(const std::string &str,
+                          std::string &result); // Extract an "attribute" value, which is formatted
+                                                // like: `a=something:<value>`
+  static std::string string_to_upper(const std::string &str);
+
+  MediaFormat::MediaFormat(){
+    payloadType = SDP_PAYLOAD_TYPE_NONE;
+    associatedPayloadType = SDP_PAYLOAD_TYPE_NONE;
+    audioSampleRate = 0;
+    audioNumChannels = 0;
+    audioBitSize = 0;
+    videoRate = 0;
+  }
+
+  /// \TODO what criteria do you (Jaron) want to use?
+ MediaFormat::operator bool() const{ + if (payloadType == SDP_PAYLOAD_TYPE_NONE){return false;} + if (encodingName.empty()){return false;} + return true; + } + + uint32_t MediaFormat::getAudioSampleRate() const{ + if (0 != audioSampleRate){return audioSampleRate;} + if (payloadType != SDP_PAYLOAD_TYPE_NONE){ + switch (payloadType){ + case 0:{ + return 8000; + } + case 8:{ + return 8000; + } + case 10:{ + return 44100; + } + case 11:{ + return 44100; + } + } + } + return 0; + } + + uint32_t MediaFormat::getAudioNumChannels() const{ + if (0 != audioNumChannels){return audioNumChannels;} + if (payloadType != SDP_PAYLOAD_TYPE_NONE){ + switch (payloadType){ + case 0:{ + return 1; + } + case 8:{ + return 1; + } + case 10:{ + return 2; + } + case 11:{ + return 1; + } + } + } + return 0; + } + + uint32_t MediaFormat::getAudioBitSize() const{ + + if (0 != audioBitSize){return audioBitSize;} + + if (payloadType != SDP_PAYLOAD_TYPE_NONE){ + switch (payloadType){ + case 10:{ + return 16; + } + case 11:{ + return 16; + } + } + } + + if (encodingName == "L8"){return 8;} + if (encodingName == "L16"){return 16;} + if (encodingName == "L20"){return 20;} + if (encodingName == "L24"){return 24;} + + return 0; + } + + uint32_t MediaFormat::getVideoRate() const{ + + if (0 != videoRate){return videoRate;} + + if (encodingName == "H264"){ + return 90000; + }else if (encodingName == "H265"){ + return 90000; + }else if (encodingName == "VP8"){ + return 90000; + }else if (encodingName == "vp9"){ + return 90000; + } + + return 0; + } + + /// \todo Maybe we should create one member `rate` which is used by audio and video (?) + uint32_t MediaFormat::getVideoOrAudioRate() const{ + uint32_t r = getAudioSampleRate(); + if (0 == r){r = getVideoRate();} + return r; + } + + std::string MediaFormat::getFormatParameterForName(const std::string &name) const{ + std::string name_lower = name; + std::transform(name_lower.begin(), name_lower.end(), name_lower.begin(), ::tolower); + std::map::const_iterator it = formatParameters.find(name_lower); + if (it == formatParameters.end()){return "";} + return it->second; + } + + uint64_t MediaFormat::getPayloadType() const{return payloadType;} + + int32_t MediaFormat::getPacketizationModeForH264(){ + + if (encodingName != "H264"){ + ERROR_MSG("Encoding is not H264."); + return -1; + } + + std::string val = getFormatParameterForName("packetization-mode"); + if (val.empty()){ + WARN_MSG( + "No packetization-mode found for this format. We default to packetization-mode = 0."); + return 0; + } + + std::stringstream ss; + ss << val; + int32_t pm = 0; + ss >> pm; + + return pm; + } + + std::string MediaFormat::getProfileLevelIdForH264() { + + if (encodingName != "H264") { + ERROR_MSG("Encoding is not H264, cannot get profile-level-id."); + return ""; + } + + return getFormatParameterForName("profile-level-id"); + } + + Media::Media(){ + framerate = 0.0; + supportsRTCPMux = false; + supportsRTCPReducedSize = false; + candidatePort = 0; + SSRC = 0; + } + + /// \TODO what other checks do you want to perform? + Media::operator bool() const{ + if (formats.size() == 0){return false;} + if (type.empty()){return false;} + return true; + } + + /// Parses a SDP media line like `m=video 9 UDP/TLS/RTP/SAVPF + /// 96 97 98 99 100 101 102` For each payloadtype it will + /// create a MediaFormat entry and initializes it with some + /// default settings. 
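+  ///
+  /// For example (illustrative values), `m=audio 9 UDP/TLS/RTP/SAVPF 111 103`
+  /// results in `type = "audio"`, `proto = "UDP/TLS/RTP/SAVPF"`,
+  /// `payloadTypes = "111 103"` and two entries in `formats`, keyed by
+  /// payload type 111 and 103.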
+ bool Media::parseMediaLine(const std::string &line){ + + // split and verify + std::vector words = sdp_split(line, " ", false); + if (!words.size()){ + ERROR_MSG("Invalid media line."); + return false; + } + + // check if we're dealing with audio or video. + if (words[0] == "m=audio"){ + type = "audio"; + }else if (words[0] == "m=video"){ + type = "video"; + }else{ + ERROR_MSG("Unhandled media type: `%s`.", words[0].c_str()); + return false; + } + + // proto: UDP/TLS/RTP/SAVP + proto = words[2]; + + // create static and dynamic tracks. + for (size_t i = 3; i < words.size(); ++i){ + SDP::MediaFormat format; + format.payloadType = JSON::Value(words[i]).asInt(); + formats[format.payloadType] = format; + if (!payloadTypes.empty()){payloadTypes += " ";} + payloadTypes += words[i]; + } + + return true; + } + + bool Media::parseRtpMapLine(const std::string &line){ + + MediaFormat *format = getFormatForSdpLine(line); + if (NULL == format){ + ERROR_MSG( + "Cannot parse the a=rtpmap line because we did not find the track for the payload type."); + return false; + } + + // convert to fullcaps + std::string mediaType = line.substr(line.find(' ', 8) + 1); + std::string encodingName = mediaType.substr(0, mediaType.find('/')); + for (unsigned int i = 0; i < encodingName.size(); ++i){ + if (encodingName[i] <= 122 && encodingName[i] >= 97){encodingName[i] -= 32;} + } + format->encodingName = encodingName; + format->rtpmap = line.substr(line.find("=") + 1); + + // extract audio info + if (type == "audio"){ + std::string extraInfo = mediaType.substr(mediaType.find('/') + 1); + size_t lastSlash = extraInfo.find('/'); + if (lastSlash != std::string::npos){ + format->audioSampleRate = atoll(extraInfo.substr(0, lastSlash).c_str()); + format->audioNumChannels = atoll(extraInfo.substr(lastSlash + 1).c_str()); + }else{ + format->audioSampleRate = atoll(extraInfo.c_str()); + format->audioNumChannels = 1; + } + } + + return true; + } + + bool Media::parseRtspControlLine(const std::string &line){ + + if (line.substr(0, 10) != "a=control:"){ + ERROR_MSG( + "Cannot parse the given rtsp control url line because it's incorrectly formatted: `%s`.", + line.c_str()); + return false; + } + + control = line.substr(10); + if (control.empty()){ + ERROR_MSG("Failed to parse the rtsp control line."); + return false; + } + + return true; + } + + bool Media::parseFrameRateLine(const std::string &line){ + + if (line.substr(0, 12) != "a=framerate:"){ + ERROR_MSG("Cannot parse the `a=framerate:` line because it's incorrectly formatted: `%s`.", + line.c_str()); + return false; + } + + framerate = atof(line.c_str() + 12) * 1000; + return true; + } + + /// Parses a line like: + /// `a=fmtp:97 + /// streamtype=5;profile-level-id=2;mode=AAC-hbr;config=1408;sizelength=13;indexlength=3;indexdeltalength=3;bitrate=32000` + /// `a=fmtp:96 + /// packetization-mode=1;profile-level-id=4d0029;sprop-parameter-sets=Z00AKeKQCADDYC3AQEBpB4kRUA==,aO48gA==` + /// `a=fmtp:97 apt=96` + bool Media::parseFormatParametersLine(const std::string &line){ + + MediaFormat *format = getFormatForSdpLine(line); + if (!format){ + ERROR_MSG("No format found for the given `a=fmtp:` line. The payload type () should be " + "part of the media line."); + return false; + } + + // start parsing the parameters after the first character. + size_t start = line.find(" "); + if (start == std::string::npos){ + ERROR_MSG( + "Invalid formatted a=fmtp line. No space between format and data. 
`a=fmtp: `"); + return false; + } + start = start + 1; + sdp_get_all_name_values_from_string(line.substr(start), format->formatParameters); + + // When this format is associated with another format + // which is the case for RTX, we store the associated + // payload type too. `apt` means "Associated Payload Type". + if (format->formatParameters.count("apt") != 0) { + std::stringstream ss(format->formatParameters["apt"]); + ss >> format->associatedPayloadType; + } + return true; + } + + bool Media::parseRtcpFeedbackLine(const std::string &line){ + + // does this feedback mechanism apply to all or only a specific format? + MediaFormat *format = NULL; + size_t num_formats = 0; + if (line.substr(0, 11) == "a=rtcp-fb:*"){ + num_formats = formats.size(); + format = &formats[0]; + }else{ + num_formats = 1; + format = getFormatForSdpLine(line); + } + + // make sure we found something valid. + if (!format){ + ERROR_MSG("No format found for the given `a=rtcp-fb` line. The payload type () should " + "be part of the media line."); + return false; + } + if (num_formats == 0){ + ERROR_MSG("num_formats is 0. Seems like no format has been added. Invalid media line in SDP " + "maybe?"); + return false; + } + std::string fb = line.substr(11); + if (fb.empty()){ + ERROR_MSG("The given `a=rtcp-fb` line doesn't contain a rtcp-fb-val."); + return false; + } + + // add the feedback mechanism to the found format(s) + for (size_t i = 0; i < num_formats; ++i){format[i].rtcpFormats.insert(fb);} + + return true; + } + /// Extracts the fingerpint hash and value, from a line like: + /// a=fingerprint:sha-256 + /// C7:98:6F:A9:55:75:C0:73:F2:EB:CF:14:B8:6E:58:FE:A5:F1:B0:C7:41:B7:BA:D3:4A:CF:7E:5C:69:8B:FA:F4 + bool Media::parseFingerprintLine(const std::string &sdpLine){ + + // extract the type. + size_t start = sdpLine.find(":"); + if (start == std::string::npos){ + ERROR_MSG("Invalid `a=fingerprint: ` line, no `:` found."); + return false; + } + size_t end = sdpLine.find(" ", start); + if (end == std::string::npos){ + ERROR_MSG("Invalid `a=fingerprint: ` line, no found after `:`."); + return false; + } + if (end <= start){ + ERROR_MSG("Invalid `a=fingerpint: ` line. Space before the `:` found."); + return false; + } + fingerprintHash = sdpLine.substr(start, end - start); + fingerprintValue = sdpLine.substr(end); + return true; + } + + bool Media::parseSSRCLine(const std::string &sdpLine){ + + if (0 != SSRC){ + // We set our SSRC to the first one that we found. + return true; + } + + size_t firstSpace = sdpLine.find(" "); + if (firstSpace == std::string::npos){ + ERROR_MSG("Failed to parse the `a=ssrc:` line."); + return false; + } + if (firstSpace < 7){ + ERROR_MSG("We found an invalid space position. 
Cannot get SSRC."); + return false; + } + + std::string ssrcStr = sdpLine.substr(7, firstSpace - 7); + std::stringstream ss; + ss << ssrcStr; + ss >> SSRC; + + return true; + } + + MediaFormat *Media::getFormatForSdpLine(const std::string &sdpLine){ + uint64_t payloadType = 0; + if (!sdp_extract_payload_type(sdpLine, payloadType)){ + ERROR_MSG("Cannot get track for the given SDP line: %s", sdpLine.c_str()); + return NULL; + } + return getFormatForPayloadType(payloadType); + } + + MediaFormat *Media::getFormatForPayloadType(uint64_t &payloadType){ + std::map::iterator it = formats.find(payloadType); + if (it == formats.end()){ + ERROR_MSG("No format found for payload type: %u.", payloadType); + return NULL; + } + return &it->second; + } + + // This will check if there is a `SDP::MediaFormat` with a + // codec that matches the given codec name. Note that we will + // convert the given `encName` into fullcaps as SDP stores all + // codecs in caps. + MediaFormat *Media::getFormatForEncodingName(const std::string &encName){ + + std::string encNameCaps = codecMist2RTP(encName); + std::map::iterator it = formats.begin(); + while (it != formats.end()){ + MediaFormat &mf = it->second; + if (mf.encodingName == encNameCaps){return &mf;} + ++it; + } + + return NULL; + } + + std::vector Media::getFormatsForEncodingName(const std::string &encName){ + + std::string encNameCaps = string_to_upper(encName); + std::vector result; + std::map::iterator it = formats.begin(); + while (it != formats.end()){ + MediaFormat &mf = it->second; + if (mf.encodingName == encNameCaps){result.push_back(&mf);} + ++it; + } + + return result; + } + + MediaFormat* Media::getRetransMissionFormatForPayloadType(uint64_t pt) { + + std::vector rtxFormats = getFormatsForEncodingName("RTX"); + if (rtxFormats.size() == 0) { + return NULL; + } + + for (size_t i = 0; i < rtxFormats.size(); ++i) { + if (rtxFormats[i]->associatedPayloadType == pt) { + return rtxFormats[i]; + } + } + + return NULL; + } + + std::string Media::getIcePwdForFormat(const MediaFormat &fmt){ + if (!fmt.icePwd.empty()){return fmt.icePwd;} + return icePwd; + } + + uint32_t Media::getSSRC() const{return SSRC;} + + // Get the `SDP::Media*` for a given type, e.g. audio or video. + Media *Session::getMediaForType(const std::string &type){ + size_t n = medias.size(); + for (size_t i = 0; i < n; ++i){ + if (medias[i].type == type){return &medias[i];} + } + return NULL; + } + + /// Get the `SDP::MediaFormat` which represents the format and + /// e.g. encoding, rtp attributes for a specific codec (H264, OPUS, etc.) + /// + /// @param mediaType `video` or `audio` + /// @param encodingName Encoding name in fullcaps, e.g. `H264`, `OPUS`, etc. 
+ MediaFormat *Session::getMediaFormatByEncodingName(const std::string &mediaType, + const std::string &encodingName){ + SDP::Media *media = getMediaForType(mediaType); + if (!media){ + WARN_MSG("No SDP::Media found for media type %s.", mediaType.c_str()); + return NULL; + } + + SDP::MediaFormat *mediaFormat = media->getFormatForEncodingName(encodingName); + if (!mediaFormat){ + WARN_MSG("No SDP::MediaFormat found for encoding name %s.", encodingName.c_str()); + return NULL; + } + return mediaFormat; + } + + bool Session::hasReceiveOnlyMedia(){ + size_t numMedias = medias.size(); + for (size_t i = 0; i < numMedias; ++i){ + if (medias[i].direction == "recvonly"){return true;} + } + return false; + } + + bool Session::parseSDP(const std::string &sdp){ + + if (sdp.empty()){ + FAIL_MSG("Requested to parse an empty SDP."); + return false; + } + + SDP::Media *currMedia = NULL; + std::stringstream ss(sdp); + std::string line; + + while (std::getline(ss, line, '\n')){ + + // validate line + if (!line.empty() && *line.rbegin() == '\r'){line.erase(line.size() - 1, 1);} + if (line.empty()){ + continue; + } + + // Parse session (or) media data. + else if (line.substr(0, 2) == "m="){ + SDP::Media media; + if (!media.parseMediaLine(line)){ + ERROR_MSG("Failed to parse the m= line."); + return false; + } + medias.push_back(media); + currMedia = &medias.back(); + // set properties which can be global and may be overwritten per stream + currMedia->iceUFrag = iceUFrag; + currMedia->icePwd = icePwd; + } + + // the lines below assume that we found a media line already. + if (!currMedia){continue;} + + // parse properties we need later. + if (line.substr(0, 8) == "a=rtpmap"){ + currMedia->parseRtpMapLine(line); + }else if (line.substr(0, 10) == "a=control:"){ + currMedia->parseRtspControlLine(line); + }else if (line.substr(0, 12) == "a=framerate:"){ + currMedia->parseFrameRateLine(line); + }else if (line.substr(0, 7) == "a=fmtp:"){ + currMedia->parseFormatParametersLine(line); + }else if (line.substr(0, 11) == "a=rtcp-fb:"){ + currMedia->parseRtcpFeedbackLine(line); + }else if (line.substr(0, 10) == "a=sendonly"){ + currMedia->direction = "sendonly"; + }else if (line.substr(0, 10) == "a=sendrecv"){ + currMedia->direction = "sendrecv"; + }else if (line.substr(0, 10) == "a=recvonly"){ + currMedia->direction = "recvonly"; + }else if (line.substr(0, 11) == "a=ice-ufrag"){ + sdp_get_attribute_value(line, currMedia->iceUFrag); + }else if (line.substr(0, 9) == "a=ice-pwd"){ + sdp_get_attribute_value(line, currMedia->icePwd); + }else if (line.substr(0, 10) == "a=rtcp-mux"){ + currMedia->supportsRTCPMux = true; + }else if (line.substr(0, 10) == "a=rtcp-rsize"){ + currMedia->supportsRTCPReducedSize = true; + }else if (line.substr(0, 7) == "a=setup"){ + sdp_get_attribute_value(line, currMedia->setupMethod); + }else if (line.substr(0, 13) == "a=fingerprint"){ + currMedia->parseFingerprintLine(line); + }else if (line.substr(0, 6) == "a=mid:"){ + sdp_get_attribute_value(line, currMedia->mediaID); + }else if (line.substr(0, 7) == "a=ssrc:"){ + currMedia->parseSSRCLine(line); + } + }// while + + return true; + } + + static std::vector sdp_split(const std::string &str, const std::string &delim, + bool keepEmpty){ + std::vector strings; + std::ostringstream word; + for (size_t n = 0; n < str.size(); ++n){ + if (std::string::npos == delim.find(str[n])){ + word << str[n]; + }else{ + if (false == word.str().empty() || true == keepEmpty){strings.push_back(word.str());} + word.str(""); + } + } + if (false == word.str().empty() || 
true == keepEmpty){strings.push_back(word.str());} + return strings; + } + + static bool sdp_extract_payload_type(const std::string &str, uint64_t &result){ + // extract payload type. + size_t start_pos = str.find_first_of(':'); + size_t end_pos = str.find_first_of(' ', start_pos); + if (start_pos == std::string::npos || end_pos == std::string::npos || + (start_pos + 1) >= end_pos){ + FAIL_MSG("Invalid `a=rtpmap` line. Has not payload type."); + return false; + } + // make sure payload type was part of the media line and is supported. + result = JSON::Value(str.substr(start_pos + 1, end_pos - (start_pos + 1))).asInt(); + return true; + } + + // Extract the `name` and `value` from a string like + // `=`. This function will return on success, + // when it extract the `name` and `value` and returns false + // when the given input string doesn't contains a + // `=` pair. This function is for example used + // when parsing the `a=fmtp:` line. + // + // Note that we will always convert the `var` to lowercase. + static bool sdp_get_name_value_from_varval(const std::string &str, std::string &var, + std::string &value){ + + if (str.empty()){ + ERROR_MSG("Cannot get `name` and `value` from string because the given string is empty. " + "String is: `%s`", + str.c_str()); + return false; + } + + size_t pos = str.find("="); + if (pos == std::string::npos){ + WARN_MSG("Cannot get `name` and `value` from string becuase it doesn't contain a `=` sign. " + "String is: `%s`. Returning the string as is.", + str.c_str()); + value = str; + return true; + } + + var = str.substr(0, pos); + value = str.substr(pos + 1, str.length() - pos); + std::transform(var.begin(), var.end(), var.begin(), ::tolower); + return true; + } + + // This function will extract name=value pairs from a string + // which are separated by a ";" delmiter. In the future we + // might want to pass this delimiter as an parameter. + // Currently this is used to parse a `a=fmtp` line. + static void sdp_get_all_name_values_from_string(const std::string &str, + std::map &result){ + + std::string varval; + std::string name; + std::string value; + size_t start = 0; + size_t end = str.find(";"); + while (end != std::string::npos){ + varval = str.substr(start, end - start); + if (sdp_get_name_value_from_varval(varval, name, value)){result[name] = value;} + start = end + 1; + end = str.find(";", start); + } + + // the last element needs to read separately + varval = str.substr(start, end); + if (sdp_get_name_value_from_varval(varval, name, value)){result[name] = value;} + } + + // Extract an "attribute" value, which is formatted like: + // `a=something:` + static bool sdp_get_attribute_value(const std::string &str, std::string &result){ + + if (str.empty()){ + ERROR_MSG("Cannot get attribute value because the given string is empty."); + return false; + } + + size_t start = str.find(":"); + if (start == std::string::npos){ + ERROR_MSG("Cannot get attribute value because we did not find the : character in %s.", + str.c_str()); + return false; + } + + result = str.substr(start + 1, result.length() - start); + return true; + } + + Answer::Answer() : + isVideoEnabled(false), + isAudioEnabled(false), + candidatePort(0), + videoLossPrevention(SDP_LOSS_PREVENTION_NONE) + {} + + bool Answer::parseOffer(const std::string &sdp){ + + if (!sdpOffer.parseSDP(sdp)){ + FAIL_MSG("Cannot parse given offer SDP."); + return false; + } + + return true; + } + + bool Answer::hasVideo(){ + SDP::Media *m = sdpOffer.getMediaForType("video"); + return (m != NULL) ? 
true : false; + } + + bool Answer::hasAudio(){ + SDP::Media *m = sdpOffer.getMediaForType("audio"); + return (m != NULL) ? true : false; + } + + bool Answer::enableVideo(const std::string &codecName){ + if (!enableMedia("video", codecName, answerVideoMedia, answerVideoFormat)){ + DONTEVEN_MSG("Failed to enable video."); + return false; + } + isVideoEnabled = true; + return true; + } + + bool Answer::enableAudio(const std::string &codecName){ + if (!enableMedia("audio", codecName, answerAudioMedia, answerAudioFormat)){ + DONTEVEN_MSG("Not enabling audio."); + return false; + } + isAudioEnabled = true; + return true; + } + + void Answer::setCandidate(const std::string &ip, uint16_t port){ + if (ip.empty()){WARN_MSG("Given candidate IP is empty. It's fine if you want to unset it.");} + candidateIP = ip; + candidatePort = port; + } + + void Answer::setFingerprint(const std::string &fingerprintSha){ + if (fingerprintSha.empty()){ + WARN_MSG( + "Given fingerprint is empty; fine when you want to unset it; otherwise check your code."); + } + fingerprint = fingerprintSha; + } + + void Answer::setDirection(const std::string &dir){ + if (dir.empty()){WARN_MSG("Given direction string is empty; fine if you want to unset.");} + direction = dir; + } + + bool Answer::setupVideoDTSCTrack(DTSC::Track &result){ + + if (!isVideoEnabled){ + FAIL_MSG("Video is disabled; cannot setup DTSC::Track."); + return false; + } + + result.codec = codecRTP2Mist(answerVideoFormat.encodingName); + if (result.codec.empty()){ + FAIL_MSG("Failed to convert the format codec into one that MistServer understands. %s.", + answerVideoFormat.encodingName.c_str()); + return false; + } + + result.type = "video"; + result.rate = answerVideoFormat.getVideoRate(); + result.trackID = answerVideoFormat.payloadType; + return true; + } + + bool Answer::setupAudioDTSCTrack(DTSC::Track &result){ + + if (!isAudioEnabled){ + FAIL_MSG("Audio is disabled; cannot setup DTSC::Track."); + return false; + } + + result.codec = codecRTP2Mist(answerAudioFormat.encodingName); + if (result.codec.empty()){ + FAIL_MSG("Failed to convert the format codec into one that MistServer understands. %s.", + answerAudioFormat.encodingName.c_str()); + return false; + } + + result.type = "audio"; + result.rate = answerAudioFormat.getAudioSampleRate(); + result.channels = answerAudioFormat.getAudioNumChannels(); + result.size = answerAudioFormat.getAudioBitSize(); + result.trackID = answerAudioFormat.payloadType; + return true; + } + + std::string Answer::toString(){ + + if (direction.empty()){ + FAIL_MSG("Cannot create SDP answer; direction not set. call setDirection()."); + return ""; + } + if (candidateIP.empty()){ + FAIL_MSG("Cannot create SDP answer; candidate not set. call setCandidate()."); + return ""; + } + if (fingerprint.empty()){ + FAIL_MSG("Cannot create SDP answer; fingerprint not set. 
call setFingerpint()."); + return ""; + } + + std::string result; + output.clear(); + + // session + addLine("v=0"); + addLine("o=- %s 0 IN IP4 0.0.0.0", generateSessionId().c_str()); + addLine("s=-"); + addLine("t=0 0"); + addLine("a=ice-lite"); + + // session: bundle (audio and video use same candidate) + if (isVideoEnabled && isAudioEnabled){ + if (answerVideoMedia.mediaID.empty()){ + FAIL_MSG("video media has no media id; necessary for BUNDLE."); + return ""; + } + if (answerAudioMedia.mediaID.empty()){ + FAIL_MSG("audio media has no media id; necessary for BUNDLE."); + return ""; + } + std::string bundled; + for (size_t i = 0; i < sdpOffer.medias.size(); ++i){ + if (sdpOffer.medias[i].type == "audio" || sdpOffer.medias[i].type == "video"){ + if (!bundled.empty()){bundled += " ";} + bundled += sdpOffer.medias[i].mediaID; + } + } + addLine("a=group:BUNDLE %s", bundled.c_str()); + } + + // add medias (order is important) + for (size_t i = 0; i < sdpOffer.medias.size(); ++i){ + + SDP::Media &mediaOffer = sdpOffer.medias[i]; + std::string type = mediaOffer.type; + SDP::Media *media = NULL; + SDP::MediaFormat *fmtMedia = NULL; + SDP::MediaFormat* fmtRED = NULL; + SDP::MediaFormat* fmtULPFEC = NULL; + + bool isEnabled = false; + std::vector supportedPayloadTypes; + if (type != "audio" && type != "video"){continue;} + + // port = 9 (default), port = 0 (disable this media) + if (type == "audio"){ + isEnabled = isAudioEnabled; + media = &answerAudioMedia; + fmtMedia = &answerAudioFormat; + }else if (type == "video"){ + isEnabled = isVideoEnabled; + media = &answerVideoMedia; + fmtMedia = &answerVideoFormat; + fmtRED = media->getFormatForEncodingName("RED"); + fmtULPFEC = media->getFormatForEncodingName("ULPFEC"); + } + + if (!media) { + WARN_MSG("No media found."); + continue; + } + if (!fmtMedia) { + WARN_MSG("No format found."); + continue; + } + // we collect all supported payload types (e.g. RED and + // ULPFEC have their own payload type). We then serialize + // them payload types into a string that is used with the + // `m=` line to indicate we have support for these. + supportedPayloadTypes.push_back((uint8_t)fmtMedia->payloadType); + if ((videoLossPrevention & SDP_LOSS_PREVENTION_ULPFEC) + && fmtRED + && fmtULPFEC) + { + supportedPayloadTypes.push_back(fmtRED->payloadType); + supportedPayloadTypes.push_back(fmtULPFEC->payloadType); + } + + std::stringstream ss; + size_t nels = supportedPayloadTypes.size(); + for (size_t k = 0; k < nels; ++k) { + ss << (int)supportedPayloadTypes[k]; + if ((k + 1) < nels) { + ss << " "; + } + } + std::string payloadTypes = ss.str(); + + if (isEnabled){ + addLine("m=%s 9 UDP/TLS/RTP/SAVPF %s", type.c_str(), payloadTypes.c_str()); + }else{ + addLine("m=%s %u UDP/TLS/RTP/SAVPF %s", type.c_str(), 0, mediaOffer.payloadTypes.c_str()); + } + + addLine("c=IN IP4 0.0.0.0"); + if (!isEnabled){ + // We have to add the direction otherwise we'll receive an error + // like "Answer tried to set recv when offer did not set send" + // from Firefox. 
+ addLine("a=%s", direction.c_str()); + continue; + } + + addLine("a=rtcp:9"); + addLine("a=%s", direction.c_str()); + addLine("a=setup:passive"); + addLine("a=fingerprint:sha-256 %s", fingerprint.c_str()); + addLine("a=ice-ufrag:%s", fmtMedia->iceUFrag.c_str()); + addLine("a=ice-pwd:%s", fmtMedia->icePwd.c_str()); + addLine("a=rtcp-mux"); + addLine("a=rtcp-rsize"); + addLine("a=%s", fmtMedia->rtpmap.c_str()); + + // BEGIN FEC/RTX: testing with just FEC or RTX + if ((videoLossPrevention & SDP_LOSS_PREVENTION_ULPFEC) + && fmtRED + && fmtULPFEC) + { + addLine("a=rtpmap:%u ulpfec/90000", fmtULPFEC->payloadType); + addLine("a=rtpmap:%u red/90000", fmtRED->payloadType); + } + if (videoLossPrevention & SDP_LOSS_PREVENTION_NACK) { + addLine("a=rtcp-fb:%u nack", fmtMedia->payloadType); + } + // END FEC/RTX + if (type == "video"){addLine("a=rtcp-fb:%u goog-remb", fmtMedia->payloadType);} + + if (!media->mediaID.empty()){addLine("a=mid:%s", media->mediaID.c_str());} + + if (fmtMedia->encodingName == "H264"){ + std::string usedProfile = fmtMedia->getFormatParameterForName("profile-level-id"); + if (usedProfile != "42e01f"){ + WARN_MSG("The selected profile-level-id was not 42e01f. We rewrite it into this because that's what we support atm."); + usedProfile = "42e01f"; + } + + addLine("a=fmtp:%u profile-level-id=%s;level-asymmetry-allowed=1;packetization-mode=1", + fmtMedia->payloadType, usedProfile.c_str()); + }else if (fmtMedia->encodingName == "OPUS"){ + addLine("a=fmtp:%u minptime=10;useinbandfec=1", fmtMedia->payloadType); + } + + addLine("a=candidate:1 1 udp 2130706431 %s %u typ host", candidateIP.c_str(), candidatePort); + addLine("a=end-of-candidates"); + } + + // combine all the generated lines. + size_t nlines = output.size(); + for (size_t i = 0; i < nlines; ++i){result += output[i] + "\r\n";} + + return result; + } + + void Answer::addLine(const std::string &fmt, ...){ + + char buffer[1024] ={}; + va_list args; + va_start(args, fmt); + vsnprintf(buffer, sizeof(buffer), fmt.c_str(), args); + va_end(args); + + output.push_back(buffer); + } + + // This function will check if the offer you passed into + // `parseOffer()` contains a media line for the given + // `type`. When found we also check if it contains a codec for + // the given `codecName`. When both are found copy it to the + // given `outMedia` and `outFormat` (which are our answer + // objects. We also generate the values for ice-pwd and + // ice-ufrag which are used during STUN. + // + // @param codecName (string) Can be a comma separated + // string with codecs that you + // support; we select the first + // one that we find. + bool Answer::enableMedia(const std::string &type, const std::string &codecList, + SDP::Media &outMedia, SDP::MediaFormat &outFormat){ + Media *media = sdpOffer.getMediaForType(type); + if (!media){ + INFO_MSG("Cannot enable %s codec; offer doesn't have %s media.", codecList.c_str(), + type.c_str()); + return false; + } + + std::vector codecs = splitString(codecList, ','); + if (codecs.size() == 0){ + FAIL_MSG("Failed to split the given codecList."); + return false; + } + + // ok, this is a bit ugly sorry for that... but when H264 was + // requested we have to check if the packetization mode is + // what we support. Firefox does packetization-mode 0 and 1 + // and provides both formats in their SDP. 
It may happen that + // an SDP contains multiple format-specs for H264 + SDP::MediaFormat *format = NULL; + for (size_t i = 0; i < codecs.size(); ++i){ + std::string codec = codecMist2RTP(codecs[i]); + std::vector formats = media->getFormatsForEncodingName(codec); + for (size_t j = 0; j < formats.size(); ++j){ + if (codec == "H264"){ + if (formats[j]->getPacketizationModeForH264() != 1){ + FAIL_MSG( + "Skipping this H264 format because it uses a packetization mode we don't support."); + format = NULL; + continue; + } + if (formats[j]->getProfileLevelIdForH264() != "42e01f") { + FAIL_MSG("Skipping this H264 format because it uses an unsupported profile-level-id."); + format = NULL; + continue; + } + } + format = formats[j]; + break; + } + if (format){break;} + } + + if (!format){ + FAIL_MSG("Cannot enable %s; codec not found %s.", type.c_str(), codecList.c_str()); + return false; + } + + INFO_MSG("Enabling media for codec: %s", format->encodingName.c_str()); + + outMedia = *media; + outFormat = *format; + outFormat.rtcpFormats.clear(); + outFormat.icePwd = generateIcePwd(); + outFormat.iceUFrag = generateIceUFrag(); + + return true; + } + + std::string Answer::generateSessionId(){ + + srand(time(NULL)); + uint64_t id = Util::getMicros(); + id += rand(); + + std::stringstream ss; + ss << id; + + return ss.str(); + } + + std::string Answer::generateIceUFrag(){return generateRandomString(4);} + + std::string Answer::generateIcePwd(){return generateRandomString(22);} + + std::string Answer::generateRandomString(const int len){ + + static const char alphanum[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + + std::string s; + for (int i = 0; i < len; ++i){s.push_back(alphanum[rand() % (sizeof(alphanum) - 1)]);} + + return s; + } + + std::vector Answer::splitString(const std::string &str, char delim){ + + std::stringstream ss(str); + std::string segment; + std::vector result; + + while (std::getline(ss, segment, delim)){result.push_back(segment);} + + return result; + } + + static std::string string_to_upper(const std::string &str){ + std::string result = str; + char *p = (char *)result.c_str(); + while (*p != 0){ + if (*p >= 'a' && *p <= 'z'){*p = *p & ~0x20;} + p++; + } + return result; + } + + +} + diff --git a/lib/sdp_media.h b/lib/sdp_media.h new file mode 100644 index 00000000..bc2653f1 --- /dev/null +++ b/lib/sdp_media.h @@ -0,0 +1,220 @@ +#pragma once +#include +#include +#include +#include +#include "dtsc.h" + +#define SDP_PAYLOAD_TYPE_NONE 9999 /// Define that is used to indicate a payload type is not set. +#define SDP_LOSS_PREVENTION_NONE 0 +#define SDP_LOSS_PREVENTION_NACK (1 << 1) /// Use simple NACK based loss prevention. (e.g. send a NACK to pusher of video stream when a packet is lost) +#define SDP_LOSS_PREVENTION_ULPFEC (1 << 2) /// Use FEC (See rtp.cpp, PacketRED). When used we try to add the correct `a=rtpmap` for RED and ULPFEC to the SDP when supported by the offer. + +namespace SDP{ + + /// A MediaFormat stores the infomation that is specific for an + /// encoding. With RTSP there is often just one media format + /// per media line. Though with WebRTC, where an SDP is used to + /// determine a common capability, one media line can contain + /// different formats. These formats are indicated by with the + /// attribute of the media line. For each there may + /// be one or more custom properties. For each property, like + /// the `encodingName` (e.g. VP8, VP9, H264, etc). 
we create a + /// new `SDP::MediaFormat` object and store it in the `formats` + /// member of `SDP::Media`. + /// + /// When you want to retrieve some specific data and there is a + /// getter function defined for it, you SHOULD use this + /// function as these functions add some extra logic based on + /// the set members. + class MediaFormat{ + public: + MediaFormat(); + std::string getFormatParameterForName( + const std::string &name) const; ///< Get a parameter which was part of the `a=fmtp:` line. + uint32_t getAudioSampleRate() + const; ///< Returns the audio sample rate. When `audioSampleRate` has been set this will be + ///< returned, otherwise we use the `payloadType` to determine the samplerate or + ///< return 0 when we fail to determine to samplerate. + uint32_t getAudioNumChannels() + const; ///< Returns the number of audio channels. When `audioNumChannels` has been set this + ///< will be returned, otherwise we use the `payloadType` when it's set to determine + ///< the samplerate or we return 0 when we can't determine the number of channels. + uint32_t getAudioBitSize() + const; ///< Returns the audio bitsize. When `audioBitSize` has been set this will be + ///< returned, othwerise we use the `encodingName` to determine the right + ///< `audioBitSize` or 0 when we can't determine the `audioBitSize` + uint32_t + getVideoRate() const; ///< Returns the video time base. When `videoRate` has been set this will + ///< be returned, otherwise we use the `encodingName` to determine the + ///< right value or 0 when we can't determine the video rate. + uint32_t getVideoOrAudioRate() const; ///< Returns whichever rate has been set. + uint64_t getPayloadType() const; ///< Returns the `payloadType` member. + int32_t + getPacketizationModeForH264(); ///< When this represents a h264 format this will return the + ///< packetization mode when it was provided in the SDP + std::string getProfileLevelIdForH264(); ///< When this represents a H264 format, this will return the profile-level-id from the format parameters. + + operator bool() const; + + public: + uint64_t payloadType; ///< The payload type as set in the media line (the is -the- + ///< payloadType). + uint64_t associatedPayloadType; ///< From `a=fmtp: apt=`; maps this format to another payload type. + int32_t audioSampleRate; ///< Samplerate of the audio type. + int32_t audioNumChannels; ///< Number of audio channels extracted from the `a=fmtp` or set in + ///< `setDefaultsForPayloadType()`. + int32_t audioBitSize; ///< Number of bits used in case this is an audio type 8, 16, set in + ///< `setDefaultsForCodec()` and `setDefaultsForPayloadType()`. + int32_t videoRate; ///< Video framerate, e.g. 9000 + std::string encodingName; ///< Stores the UPPERCASED encoding name from the `a=rtpmap: + std::string iceUFrag; ///< From `a=ice-ufrag:, used with WebRTC / STUN. + std::string icePwd; ///< From `a=ice-pwd:`, used with WebRTC / STUN + std::string rtpmap; ///< The `a= value; value between brackets. + std::map + formatParameters; ///< Stores the var-val pairs from `a=fmtp:` entry e.g. = + ///< `packetization-mode=1;profile-level-id=4d0029;sprop-parameter-sets=Z00AKeKQCADDYC3AQEBpB4kRUA==,aO48gA==` + ///< */ + std::set + rtcpFormats; ///< Stores the `fb-val` from the line with `a=rtcp-fb: `. + }; + + class Media{ + public: + Media(); + bool parseMediaLine(const std::string &sdpLine); ///< Parses `m=` line. Creates a `MediaFormat` + ///< entry for each of the found values. 
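+    /// Illustrative example (not part of the original patch): a line like
+    /// `a=rtpmap:111 opus/48000/2` sets `encodingName` to "OPUS",
+    /// `audioSampleRate` to 48000 and `audioNumChannels` to 2 on the
+    /// MediaFormat that the `m=` line registered for payload type 111.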
+ bool parseRtpMapLine( + const std::string &sdpLine); ///< Parses `a=rtpmap:` line which contains the some codec + ///< specific info. When this line contains the samplerate and + ///< number of audio channels they will be extracted. + bool parseRtspControlLine(const std::string &sdpLine); ///< Parses `a=control:` + bool parseFrameRateLine(const std::string &sdpLine); ///< Parses `a=framerate:` + bool parseFormatParametersLine(const std::string &sdpLine); ///< Parses `a=fmtp:`. + bool parseRtcpFeedbackLine( + const std::string &sdpLine); ///< Parses `a=rtcp-fb:`. See RFC4584 + bool parseFingerprintLine( + const std::string + &sdpLine); ///< Parses `a=fingerprint: `. See + ///< https://tools.ietf.org/html/rfc8122#section-5, used with WebRTC + bool parseSSRCLine(const std::string &sdpLine); ///< Parses `a=ssrc:`. + MediaFormat *getFormatForSdpLine( + const std::string + &sdpLine); ///< Returns the track to which this SDP line applies. This means that the + ///< SDP line should be formatteed like: `a=something:[payloadtype]`. + MediaFormat *getFormatForPayloadType( + uint64_t &payloadType); ///< Finds the `MediaFormat` in `formats`. Returns NULL when no + ///< format was found for the given payload type. . + MediaFormat *getFormatForEncodingName( + const std::string + &encName); ///< Finds the `MediaFormats in `formats`. Returns NULL when no format was + ///< found for the given encoding name. E.g. `VP8`, `VP9`, `H264` + std::vector getFormatsForEncodingName(const std::string &encName); + std::string getIcePwdForFormat( + const MediaFormat + &fmt); ///< The `a=ice-pwd` can be session global or media specific. This function will + ///< check if the `SDP::MediaFormat` has a ice-pwd that we should use. + uint32_t + getSSRC() const; ///< Returns the first SSRC `a=ssrc:` value found for the media. + operator bool() const; + MediaFormat* getRetransMissionFormatForPayloadType(uint64_t pt); ///< When available, it resurns the RTX format that is directly associated with the media (not encapsulated with a RED header). RTX can be combined with FEC in which case it's supposed to be stored in RED packets. The `encName` should be something like H264,VP8; e.g. the format for which you want to get the RTX format. + + public: + std::string type; ///< The `media` field of the media line: `m= `, + ///< like "video" or "audio" + std::string proto; ///< The `proto` field of the media line: `m= `, + ///< like "video" or "audio" + std::string control; ///< From `a=control:` The RTSP control url. + std::string direction; ///< From `a=sendonly`, `a=recvonly` and `a=sendrecv` + std::string iceUFrag; ///< From `a=ice-ufrag:, used with WebRTC / STUN. + std::string icePwd; ///< From `a=ice-pwd:`, used with WebRTC / STUN + std::string setupMethod; ///< From `a=setup:, used with WebRTC / STUN + std::string fingerprintHash; ///< From `a=fingerprint: `, e.g. sha-256, used with + ///< WebRTC / STUN + std::string + fingerprintValue; ///< From `a=fingerprint: `, the actual fingerprint, used + ///< with WebRTC / STUN, see https://tools.ietf.org/html/rfc8122#section-5 + std::string mediaID; ///< From `a=mid:`. When generating an WebRTC answer this value must + ///< be the same as in the offer. + std::string candidateIP; ///< Used when we generate a WebRTC answer. + uint16_t candidatePort; ///< Used when we generate a WebRTC answer. + uint32_t SSRC; ///< From `a=ssrc: `; the first SSRC that we encountered. + double framerate; ///< From `a=framerate`. 
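+    ///< Note: parseFrameRateLine() stores this value multiplied by 1000.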
+ bool supportsRTCPMux; ///< From `a=rtcp-mux`, indicates if it can mux RTP and RTCP on one + ///< transport channel. + bool supportsRTCPReducedSize; ///< From `a=rtcp-rsize`, reduced size RTCP packets. + std::string + payloadTypes; ///< From `m=` line, all the payload types as string, separated by space. + std::map + formats; ///< Formats indexed by payload type. Payload type is the number in the + ///< field(s) from the `m=` line. + }; + + class Session{ + public: + bool parseSDP(const std::string &sdp); + Media *getMediaForType( + const std::string &type); ///< Get a `SDP::Media*` for the given type, e.g. `video` or + ///< `audio`. Returns NULL when the type was not found. + MediaFormat *getMediaFormatByEncodingName(const std::string &mediaType, + const std::string &encodingName); + bool hasReceiveOnlyMedia(); ///< Returns true when one of the media sections has a `a=recvonly` + ///< attribute. This is used to determine if the other peer only + ///< wants to receive or also sent data. */ + + public: + std::vector medias; ///< For each `m=` line we create a `SDP::Media` instance. The + ///< stream specific infomration is stored in a `MediaFormat` + std::string icePwd; ///< From `a=ice-pwd`, this property can be session-wide or media specific. + ///< Used with WebRTC and STUN when calculating the message-integrity. + std::string + iceUFrag; ///< From `a=ice-ufag`, this property can be session-wide or media specific. Used + ///< with WebRTC and STUN when calculating the message-integrity. + }; + + class Answer{ + public: + Answer(); + bool parseOffer(const std::string &sdp); + bool hasVideo(); ///< Check if the offer has video. + bool hasAudio(); ///< Check if the offer has audio. + bool enableVideo(const std::string &codecName); + bool enableAudio(const std::string &codecName); + void setCandidate(const std::string &ip, uint16_t port); + void setFingerprint(const std::string &fingerprintSha); ///< Set the SHA265 that represents the + ///< certificate that is used with DTLS. + void setDirection(const std::string &dir); + bool setupVideoDTSCTrack(DTSC::Track &result); + bool setupAudioDTSCTrack(DTSC::Track &result); + std::string toString(); + + private: + bool enableMedia(const std::string &type, const std::string &codecName, SDP::Media &outMedia, + SDP::MediaFormat &outFormat); + void addLine(const std::string &fmt, ...); + std::string generateSessionId(); + std::string generateIceUFrag(); ///< Generates the `ice-ufrag` value. + std::string generateIcePwd(); ///< Generates the `ice-pwd` value. + std::string generateRandomString(const int len); + std::vector splitString(const std::string &str, char delim); + + public: + SDP::Session sdpOffer; + SDP::Media answerVideoMedia; + SDP::Media answerAudioMedia; + SDP::MediaFormat answerVideoFormat; + SDP::MediaFormat answerAudioFormat; + bool isAudioEnabled; + bool isVideoEnabled; + std::string candidateIP; ///< We use rtcp-mux and BUNDLE; so only one candidate necessary. + uint16_t candidatePort; ///< We use rtcp-mux and BUNDLE; so only one candidate necessary. + std::string fingerprint; + std::string direction; ///< The direction used when generating the answer SDP string. + std::vector output; ///< The lines that are used when adding lines (see `addLine()` + ///< for the answer sdp.). + uint8_t videoLossPrevention; ///< See the SDP_LOSS_PREVENTION_* values at the top of this header. 
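+
+    // Illustrative usage sketch (not part of the original patch). The codec
+    // names and the `offerStr`, `ip`, `port`, `certFingerprint` and
+    // `direction` variables are assumptions supplied by the caller:
+    //
+    //   SDP::Answer answer;
+    //   if (answer.parseOffer(offerStr)){
+    //     answer.enableVideo("H264");
+    //     answer.enableAudio("opus");
+    //     answer.setCandidate(ip, port);
+    //     answer.setFingerprint(certFingerprint);
+    //     answer.setDirection(direction);
+    //     std::string answerSdp = answer.toString();
+    //   }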
+ }; + +} + diff --git a/lib/srtp.cpp b/lib/srtp.cpp new file mode 100644 index 00000000..a7200b7f --- /dev/null +++ b/lib/srtp.cpp @@ -0,0 +1,422 @@ +#include +#include "defines.h" +#include "srtp.h" + +/* --------------------------------------- */ + +static std::string srtp_status_to_string(uint32_t status); + +/* --------------------------------------- */ + +SRTPReader::SRTPReader() { + memset((void*)&session, 0x00, sizeof(session)); + memset((void*)&policy, 0x00, sizeof(policy)); +} + +/* + Before initializing the srtp library we shut it down first + because initializing the library twice results in an error. +*/ +int SRTPReader::init(const std::string& cipher, const std::string& key, const std::string& salt) { + + int r = 0; + srtp_err_status_t status = srtp_err_status_ok; + srtp_profile_t profile; + memset((void*) &profile, 0x00, sizeof(profile)); + + /* validate input */ + if (cipher.empty()) { + FAIL_MSG("Given `cipher` is empty."); + r = -1; + goto error; + } + if (key.empty()) { + FAIL_MSG("Given `key` is empty."); + r = -2; + goto error; + } + if (salt.empty()) { + FAIL_MSG("Given `salt` is empty."); + r = -3; + goto error; + } + + /* re-initialize the srtp library. */ + status = srtp_shutdown(); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to shutdown the srtp lib %s", srtp_status_to_string(status).c_str()); + r = -1; + goto error; + } + + status = srtp_init(); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to initialize the SRTP library. %s", srtp_status_to_string(status).c_str()); + r = -2; + goto error; + } + + /* select the right profile from exchanged cipher */ + if ("SRTP_AES128_CM_SHA1_80" == cipher) { + profile = srtp_profile_aes128_cm_sha1_80; + } + else if ("SRTP_AES128_CM_SHA1_32" == cipher) { + profile = srtp_profile_aes128_cm_sha1_32; + } + else { + ERROR_MSG("Unsupported SRTP cipher used: %s.", cipher.c_str()); + r = -2; + goto error; + } + + /* set the crypto policy using the profile. */ + status = srtp_crypto_policy_set_from_profile_for_rtp(&policy.rtp, profile); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to set the crypto policy for RTP for cipher %s.", cipher.c_str()); + r = -3; + goto error; + } + + status = srtp_crypto_policy_set_from_profile_for_rtcp(&policy.rtcp, profile); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to set the crypto policy for RTCP for cipher %s.", cipher.c_str()); + r = -4; + goto error; + } + + /* set the keying material. */ + std::copy(key.begin(), key.end(), std::back_inserter(key_salt)); + std::copy(salt.begin(), salt.end(), std::back_inserter(key_salt)); + policy.key = (unsigned char*)&key_salt[0]; + + /* only unprotecting data for now, so using inbound; and some other settings. */ + policy.ssrc.type = ssrc_any_inbound; + policy.window_size = 1024; + policy.allow_repeat_tx = 1; + + /* create the srtp session. */ + status = srtp_create(&session, &policy); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to initialize our SRTP session. Status: %s. ", srtp_status_to_string(status).c_str()); + r = -3; + goto error; + } + + error: + if (r < 0) { + shutdown(); + } + + return r; +} + +int SRTPReader::shutdown() { + + int r = 0; + + srtp_err_status_t status = srtp_dealloc(session); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to cleanly shutdown the SRTP session. 
Status: %s", srtp_status_to_string(status).c_str()); + r -= 5; + } + + memset((void*)&policy, 0x00, sizeof(policy)); + memset((char*)&session, 0x00, sizeof(session)); + + return r; +} + +/* --------------------------------------- */ + +int SRTPReader::unprotectRtp(uint8_t* data, int* nbytes) { + + if (NULL == data) { + ERROR_MSG("Cannot unprotect the given SRTP, because data is NULL."); + return -1; + } + + if (NULL == nbytes) { + ERROR_MSG("Cannot unprotect the given SRTP, becuase nbytes is NULL."); + return -2; + } + + if (0 == (*nbytes)) { + ERROR_MSG("Cannot unprotect the given SRTP, because nbytes is 0."); + return -3; + } + + if (NULL == policy.key) { + ERROR_MSG("Cannot unprotect the SRTP packet, it seems we're not initialized."); + return -4; + } + + srtp_err_status_t status = srtp_unprotect(session, data, nbytes); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to unprotect the given SRTP. %s.", srtp_status_to_string(status).c_str()); + return -5; + } + + DONTEVEN_MSG("Unprotected SRTP into %d bytes.", *nbytes); + + return 0; +} + +int SRTPReader::unprotectRtcp(uint8_t* data, int* nbytes) { + + if (NULL == data) { + ERROR_MSG("Cannot unprotect the given SRTCP, because data is NULL."); + return -1; + } + + if (NULL == nbytes) { + ERROR_MSG("Cannot unprotect the given SRTCP, becuase nbytes is NULL."); + return -2; + } + + if (0 == (*nbytes)) { + ERROR_MSG("Cannot unprotect the given SRTCP, because nbytes is 0."); + return -3; + } + + if (NULL == policy.key) { + ERROR_MSG("Cannot unprotect the SRTCP packet, it seems we're not initialized."); + return -4; + } + + srtp_err_status_t status = srtp_unprotect_rtcp(session, data, nbytes); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to unprotect the given SRTCP. %s.", srtp_status_to_string(status).c_str()); + return -5; + } + + return 0; +} + +/* --------------------------------------- */ + +SRTPWriter::SRTPWriter() { + memset((void*)&session, 0x00, sizeof(session)); + memset((void*)&policy, 0x00, sizeof(policy)); +} + +/* + Before initializing the srtp library we shut it down first + because initializing the library twice results in an error. +*/ +int SRTPWriter::init(const std::string& cipher, const std::string& key, const std::string& salt) { + + int r = 0; + srtp_err_status_t status = srtp_err_status_ok; + srtp_profile_t profile; + memset((void*)&profile, 0x00, sizeof(profile)); + + /* validate input */ + if (cipher.empty()) { + FAIL_MSG("Given `cipher` is empty."); + r = -1; + goto error; + } + if (key.empty()) { + FAIL_MSG("Given `key` is empty."); + r = -2; + goto error; + } + if (salt.empty()) { + FAIL_MSG("Given `salt` is empty."); + r = -3; + goto error; + } + + /* re-initialize the srtp library. */ + status = srtp_shutdown(); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to shutdown the srtp lib %s", srtp_status_to_string(status).c_str()); + r = -1; + goto error; + } + + status = srtp_init(); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to initialize the SRTP library. %s", srtp_status_to_string(status).c_str()); + r = -2; + goto error; + } + + /* select the exchanged cipher */ + if ("SRTP_AES128_CM_SHA1_80" == cipher) { + profile = srtp_profile_aes128_cm_sha1_80; + } + else if ("SRTP_AES128_CM_SHA1_32" == cipher) { + profile = srtp_profile_aes128_cm_sha1_32; + } + else { + ERROR_MSG("Unsupported SRTP cipher used: %s.", cipher.c_str()); + r = -2; + goto error; + } + + /* set the crypto policy using the profile. 
*/ + status = srtp_crypto_policy_set_from_profile_for_rtp(&policy.rtp, profile); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to set the crypto policy for RTP for cipher %s.", cipher.c_str()); + r = -3; + goto error; + } + + status = srtp_crypto_policy_set_from_profile_for_rtcp(&policy.rtcp, profile); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to set the crypto policy for RTCP for cipher %s.", cipher.c_str()); + r = -4; + goto error; + } + + /* set the keying material. */ + std::copy(key.begin(), key.end(), std::back_inserter(key_salt)); + std::copy(salt.begin(), salt.end(), std::back_inserter(key_salt)); + policy.key = (unsigned char*)&key_salt[0]; + + /* only unprotecting data for now, so using inbound; and some other settings. */ + policy.ssrc.type = ssrc_any_outbound; + policy.window_size = 128; + policy.allow_repeat_tx = 0; + + /* create the srtp session. */ + status = srtp_create(&session, &policy); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to initialize our SRTP session. Status: %s. ", srtp_status_to_string(status).c_str()); + r = -3; + goto error; + } + + error: + if (r < 0) { + shutdown(); + } + + return r; +} + +int SRTPWriter::shutdown() { + + int r = 0; + + srtp_err_status_t status = srtp_dealloc(session); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to cleanly shutdown the SRTP session. Status: %s", srtp_status_to_string(status).c_str()); + r -= 5; + } + + memset((char*)&policy, 0x00, sizeof(policy)); + memset((char*)&session, 0x00, sizeof(session)); + + return r; +} + +/* --------------------------------------- */ + +int SRTPWriter::protectRtp(uint8_t* data, int* nbytes) { + + if (NULL == data) { + ERROR_MSG("Cannot protect the RTP packet because given data is NULL."); + return -1; + } + + if (NULL == nbytes) { + ERROR_MSG("Cannot protect the RTP packet because the given nbytes is NULL."); + return -2; + } + + if ((*nbytes) <= 0) { + ERROR_MSG("Cannot protect the RTP packet because the given nbytes has a value <= 0."); + return -3; + } + + if (NULL == policy.key) { + ERROR_MSG("Cannot protect the RTP packet because we're not initialized."); + return -4; + } + + srtp_err_status_t status = srtp_protect(session, (void*)data, nbytes); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to protect the RTP packet. %s.", srtp_status_to_string(status).c_str()); + return -5; + } + + return 0; +} + +/* + Make sure that `data` has `SRTP_MAX_TRAILER_LEN + 4` number + of bytes at the into which libsrtp can write the + authentication tag +*/ +int SRTPWriter::protectRtcp(uint8_t* data, int* nbytes) { + + if (NULL == data) { + ERROR_MSG("Cannot protect the RTCP packet because given data is NULL."); + return -1; + } + + if (NULL == nbytes) { + ERROR_MSG("Cannot protect the RTCP packet because nbytes is NULL."); + return -2; + } + + if ((*nbytes) <= 0) { + ERROR_MSG("Cannot protect the RTCP packet because *nbytes is <= 0."); + return -3; + } + + if (NULL == policy.key) { + ERROR_MSG("Not initialized cannot protect the RTCP packet."); + return -4; + } + + srtp_err_status_t status = srtp_protect_rtcp(session, (void*)data, nbytes); + if (srtp_err_status_ok != status) { + ERROR_MSG("Failed to protect the RTCP packet. 
%s.", srtp_status_to_string(status).c_str()); + return -3; + } + + return 0; +} + + +/* --------------------------------------- */ + +static std::string srtp_status_to_string(uint32_t status) { + + switch (status) { + case srtp_err_status_ok: { return "srtp_err_status_ok"; } + case srtp_err_status_fail: { return "srtp_err_status_fail"; } + case srtp_err_status_bad_param: { return "srtp_err_status_bad_param"; } + case srtp_err_status_alloc_fail: { return "srtp_err_status_alloc_fail"; } + case srtp_err_status_dealloc_fail: { return "srtp_err_status_dealloc_fail"; } + case srtp_err_status_init_fail: { return "srtp_err_status_init_fail"; } + case srtp_err_status_terminus: { return "srtp_err_status_terminus"; } + case srtp_err_status_auth_fail: { return "srtp_err_status_auth_fail"; } + case srtp_err_status_cipher_fail: { return "srtp_err_status_cipher_fail"; } + case srtp_err_status_replay_fail: { return "srtp_err_status_replay_fail"; } + case srtp_err_status_replay_old: { return "srtp_err_status_replay_old"; } + case srtp_err_status_algo_fail: { return "srtp_err_status_algo_fail"; } + case srtp_err_status_no_such_op: { return "srtp_err_status_no_such_op"; } + case srtp_err_status_no_ctx: { return "srtp_err_status_no_ctx"; } + case srtp_err_status_cant_check: { return "srtp_err_status_cant_check"; } + case srtp_err_status_key_expired: { return "srtp_err_status_key_expired"; } + case srtp_err_status_socket_err: { return "srtp_err_status_socket_err"; } + case srtp_err_status_signal_err: { return "srtp_err_status_signal_err"; } + case srtp_err_status_nonce_bad: { return "srtp_err_status_nonce_bad"; } + case srtp_err_status_read_fail: { return "srtp_err_status_read_fail"; } + case srtp_err_status_write_fail: { return "srtp_err_status_write_fail"; } + case srtp_err_status_parse_err: { return "srtp_err_status_parse_err"; } + case srtp_err_status_encode_err: { return "srtp_err_status_encode_err"; } + case srtp_err_status_semaphore_err: { return "srtp_err_status_semaphore_err"; } + case srtp_err_status_pfkey_err: { return "srtp_err_status_pfkey_err"; } + case srtp_err_status_bad_mki: { return "srtp_err_status_bad_mki"; } + case srtp_err_status_pkt_idx_old: { return "srtp_err_status_pkt_idx_old"; } + case srtp_err_status_pkt_idx_adv: { return "srtp_err_status_pkt_idx_adv"; } + default: { return "UNKNOWN"; } + } +} + +/* --------------------------------------- */ diff --git a/lib/srtp.h b/lib/srtp.h new file mode 100644 index 00000000..cc5452c1 --- /dev/null +++ b/lib/srtp.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include + +#define SRTP_PARSER_MASTER_KEY_LEN 16 +#define SRTP_PARSER_MASTER_SALT_LEN 14 +#define SRTP_PARSER_MASTER_LEN (SRTP_PARSER_MASTER_KEY_LEN + SRTP_PARSER_MASTER_SALT_LEN) + +/* --------------------------------------- */ + +class SRTPReader { +public: + SRTPReader(); + int init(const std::string& cipher, const std::string& key, const std::string& salt); + int shutdown(); + int unprotectRtp(uint8_t* data, int* nbytes); /* `nbytes` should contain the number of bytes in `data`. On success `nbytes` will hold the number of bytes of the decoded RTP packet. */ + int unprotectRtcp(uint8_t* data, int* nbytes); /* `nbytes` should contains the number of bytes in `data`. On success `nbytes` will hold the number of bytes the decoded RTCP packet. */ + +private: + srtp_t session; + srtp_policy_t policy; + std::vector key_salt; /* Combination of key + salt which is used to unprotect the SRTP/SRTCP data. 
*/ +}; + +/* --------------------------------------- */ + +class SRTPWriter { +public: + SRTPWriter(); + int init(const std::string& cipher, const std::string& key, const std::string& salt); + int shutdown(); + int protectRtp(uint8_t* data, int* nbytes); + int protectRtcp(uint8_t* data, int* nbytes); + +private: + srtp_t session; + srtp_policy_t policy; + std::vector key_salt; /* Combination of key + salt which is used to protect the SRTP/SRTCP data. */ +}; + +/* --------------------------------------- */ diff --git a/lib/stun.cpp b/lib/stun.cpp new file mode 100644 index 00000000..42d0f759 --- /dev/null +++ b/lib/stun.cpp @@ -0,0 +1,1051 @@ +#include "defines.h" +#include "stun.h" +#include "checksum.h" // for crc32 + +/* --------------------------------------- */ + +std::string stun_family_type_to_string(uint8_t type) { + switch (type) { + case STUN_IP4: { return "STUN_IP4"; } + case STUN_IP6: { return "STUN_IP6"; } + default: { return "UNKNOWN"; } + } +} + +std::string stun_message_type_to_string(uint16_t type) { + switch (type) { + case STUN_MSG_TYPE_NONE: { return "STUN_MSG_TYPE_NONE"; } + case STUN_MSG_TYPE_BINDING_REQUEST: { return "STUN_MSG_TYPE_BINDING_REQUEST"; } + case STUN_MSG_TYPE_BINDING_RESPONSE_SUCCESS: { return "STUN_MSG_TYPE_BINDING_RESPONSE_SUCCESS"; } + case STUN_MSG_TYPE_BINDING_RESPONSE_ERROR: { return "STUN_MSG_TYPE_BINDING_RESPONSE_ERROR"; } + case STUN_MSG_TYPE_BINDING_INDICATION: { return "STUN_MSG_TYPE_BINDING_INDICATION"; } + default: { return "UNKNOWN"; } + } +} + +std::string stun_attribute_type_to_string(uint16_t type) { + switch (type) { + case STUN_ATTR_TYPE_NONE: { return "STUN_ATTR_TYPE_NONE"; } + case STUN_ATTR_TYPE_MAPPED_ADDR: { return "STUN_ATTR_TYPE_MAPPED_ADDR"; } + case STUN_ATTR_TYPE_CHANGE_REQ: { return "STUN_ATTR_TYPE_CHANGE_REQ"; } + case STUN_ATTR_TYPE_USERNAME: { return "STUN_ATTR_TYPE_USERNAME"; } + case STUN_ATTR_TYPE_MESSAGE_INTEGRITY: { return "STUN_ATTR_TYPE_MESSAGE_INTEGRITY"; } + case STUN_ATTR_TYPE_ERR_CODE: { return "STUN_ATTR_TYPE_ERR_CODE"; } + case STUN_ATTR_TYPE_UNKNOWN_ATTRIBUTES: { return "STUN_ATTR_TYPE_UNKNOWN_ATTRIBUTES"; } + case STUN_ATTR_TYPE_CHANNEL_NUMBER: { return "STUN_ATTR_TYPE_CHANNEL_NUMBER"; } + case STUN_ATTR_TYPE_LIFETIME: { return "STUN_ATTR_TYPE_LIFETIME"; } + case STUN_ATTR_TYPE_XOR_PEER_ADDR: { return "STUN_ATTR_TYPE_XOR_PEER_ADDR"; } + case STUN_ATTR_TYPE_DATA: { return "STUN_ATTR_TYPE_DATA"; } + case STUN_ATTR_TYPE_REALM: { return "STUN_ATTR_TYPE_REALM"; } + case STUN_ATTR_TYPE_NONCE: { return "STUN_ATTR_TYPE_NONCE"; } + case STUN_ATTR_TYPE_XOR_RELAY_ADDRESS: { return "STUN_ATTR_TYPE_XOR_RELAY_ADDRESS"; } + case STUN_ATTR_TYPE_REQ_ADDRESS_FAMILY: { return "STUN_ATTR_TYPE_REQ_ADDRESS_FAMILY"; } + case STUN_ATTR_TYPE_EVEN_PORT: { return "STUN_ATTR_TYPE_EVEN_PORT"; } + case STUN_ATTR_TYPE_REQUESTED_TRANSPORT: { return "STUN_ATTR_TYPE_REQUESTED_TRANSPORT"; } + case STUN_ATTR_TYPE_DONT_FRAGMENT: { return "STUN_ATTR_TYPE_DONT_FRAGMENT"; } + case STUN_ATTR_TYPE_XOR_MAPPED_ADDRESS: { return "STUN_ATTR_TYPE_XOR_MAPPED_ADDRESS"; } + case STUN_ATTR_TYPE_RESERVATION_TOKEN: { return "STUN_ATTR_TYPE_RESERVATION_TOKEN"; } + case STUN_ATTR_TYPE_PRIORITY: { return "STUN_ATTR_TYPE_PRIORITY"; } + case STUN_ATTR_TYPE_USE_CANDIDATE: { return "STUN_ATTR_TYPE_USE_CANDIDATE"; } + case STUN_ATTR_TYPE_PADDING: { return "STUN_ATTR_TYPE_PADDING"; } + case STUN_ATTR_TYPE_RESPONSE_PORT: { return "STUN_ATTR_TYPE_RESPONSE_PORT"; } + case STUN_ATTR_TYPE_SOFTWARE: { return "STUN_ATTR_TYPE_SOFTWARE"; } + case 
STUN_ATTR_TYPE_ALTERNATE_SERVER: { return "STUN_ATTR_TYPE_ALTERNATE_SERVER"; } + case STUN_ATTR_TYPE_FINGERPRINT: { return "STUN_ATTR_TYPE_FINGERPRINT"; } + case STUN_ATTR_TYPE_ICE_CONTROLLED: { return "STUN_ATTR_TYPE_ICE_CONTROLLED"; } + case STUN_ATTR_TYPE_ICE_CONTROLLING: { return "STUN_ATTR_TYPE_ICE_CONTROLLING"; } + case STUN_ATTR_TYPE_RESPONSE_ORIGIN: { return "STUN_ATTR_TYPE_RESPONSE_ORIGIN"; } + case STUN_ATTR_TYPE_OTHER_ADDRESS: { return "STUN_ATTR_TYPE_OTHER_ADDRESS"; } + default: { return "UNKNOWN"; } + } +} + +static uint32_t poly_crc32(uint32_t inCrc, const uint8_t* data, size_t nbytes) { + + static const unsigned long crc_table[256] = { + 0x00000000,0x77073096,0xEE0E612C,0x990951BA,0x076DC419,0x706AF48F,0xE963A535, + 0x9E6495A3,0x0EDB8832,0x79DCB8A4,0xE0D5E91E,0x97D2D988,0x09B64C2B,0x7EB17CBD, + 0xE7B82D07,0x90BF1D91,0x1DB71064,0x6AB020F2,0xF3B97148,0x84BE41DE,0x1ADAD47D, + 0x6DDDE4EB,0xF4D4B551,0x83D385C7,0x136C9856,0x646BA8C0,0xFD62F97A,0x8A65C9EC, + 0x14015C4F,0x63066CD9,0xFA0F3D63,0x8D080DF5,0x3B6E20C8,0x4C69105E,0xD56041E4, + 0xA2677172,0x3C03E4D1,0x4B04D447,0xD20D85FD,0xA50AB56B,0x35B5A8FA,0x42B2986C, + 0xDBBBC9D6,0xACBCF940,0x32D86CE3,0x45DF5C75,0xDCD60DCF,0xABD13D59,0x26D930AC, + 0x51DE003A,0xC8D75180,0xBFD06116,0x21B4F4B5,0x56B3C423,0xCFBA9599,0xB8BDA50F, + 0x2802B89E,0x5F058808,0xC60CD9B2,0xB10BE924,0x2F6F7C87,0x58684C11,0xC1611DAB, + 0xB6662D3D,0x76DC4190,0x01DB7106,0x98D220BC,0xEFD5102A,0x71B18589,0x06B6B51F, + 0x9FBFE4A5,0xE8B8D433,0x7807C9A2,0x0F00F934,0x9609A88E,0xE10E9818,0x7F6A0DBB, + 0x086D3D2D,0x91646C97,0xE6635C01,0x6B6B51F4,0x1C6C6162,0x856530D8,0xF262004E, + 0x6C0695ED,0x1B01A57B,0x8208F4C1,0xF50FC457,0x65B0D9C6,0x12B7E950,0x8BBEB8EA, + 0xFCB9887C,0x62DD1DDF,0x15DA2D49,0x8CD37CF3,0xFBD44C65,0x4DB26158,0x3AB551CE, + 0xA3BC0074,0xD4BB30E2,0x4ADFA541,0x3DD895D7,0xA4D1C46D,0xD3D6F4FB,0x4369E96A, + 0x346ED9FC,0xAD678846,0xDA60B8D0,0x44042D73,0x33031DE5,0xAA0A4C5F,0xDD0D7CC9, + 0x5005713C,0x270241AA,0xBE0B1010,0xC90C2086,0x5768B525,0x206F85B3,0xB966D409, + 0xCE61E49F,0x5EDEF90E,0x29D9C998,0xB0D09822,0xC7D7A8B4,0x59B33D17,0x2EB40D81, + 0xB7BD5C3B,0xC0BA6CAD,0xEDB88320,0x9ABFB3B6,0x03B6E20C,0x74B1D29A,0xEAD54739, + 0x9DD277AF,0x04DB2615,0x73DC1683,0xE3630B12,0x94643B84,0x0D6D6A3E,0x7A6A5AA8, + 0xE40ECF0B,0x9309FF9D,0x0A00AE27,0x7D079EB1,0xF00F9344,0x8708A3D2,0x1E01F268, + 0x6906C2FE,0xF762575D,0x806567CB,0x196C3671,0x6E6B06E7,0xFED41B76,0x89D32BE0, + 0x10DA7A5A,0x67DD4ACC,0xF9B9DF6F,0x8EBEEFF9,0x17B7BE43,0x60B08ED5,0xD6D6A3E8, + 0xA1D1937E,0x38D8C2C4,0x4FDFF252,0xD1BB67F1,0xA6BC5767,0x3FB506DD,0x48B2364B, + 0xD80D2BDA,0xAF0A1B4C,0x36034AF6,0x41047A60,0xDF60EFC3,0xA867DF55,0x316E8EEF, + 0x4669BE79,0xCB61B38C,0xBC66831A,0x256FD2A0,0x5268E236,0xCC0C7795,0xBB0B4703, + 0x220216B9,0x5505262F,0xC5BA3BBE,0xB2BD0B28,0x2BB45A92,0x5CB36A04,0xC2D7FFA7, + 0xB5D0CF31,0x2CD99E8B,0x5BDEAE1D,0x9B64C2B0,0xEC63F226,0x756AA39C,0x026D930A, + 0x9C0906A9,0xEB0E363F,0x72076785,0x05005713,0x95BF4A82,0xE2B87A14,0x7BB12BAE, + 0x0CB61B38,0x92D28E9B,0xE5D5BE0D,0x7CDCEFB7,0x0BDBDF21,0x86D3D2D4,0xF1D4E242, + 0x68DDB3F8,0x1FDA836E,0x81BE16CD,0xF6B9265B,0x6FB077E1,0x18B74777,0x88085AE6, + 0xFF0F6A70,0x66063BCA,0x11010B5C,0x8F659EFF,0xF862AE69,0x616BFFD3,0x166CCF45, + 0xA00AE278,0xD70DD2EE,0x4E048354,0x3903B3C2,0xA7672661,0xD06016F7,0x4969474D, + 0x3E6E77DB,0xAED16A4A,0xD9D65ADC,0x40DF0B66,0x37D83BF0,0xA9BCAE53,0xDEBB9EC5, + 0x47B2CF7F,0x30B5FFE9,0xBDBDF21C,0xCABAC28A,0x53B39330,0x24B4A3A6,0xBAD03605, + 0xCDD70693,0x54DE5729,0x23D967BF,0xB3667A2E,0xC4614AB8,0x5D681B02,0x2A6F2B94, + 
0xB40BBE37,0xC30C8EA1,0x5A05DF1B,0x2D02EF8D + }; + + uint32_t crc32 = inCrc ^ 0xFFFFFFFF; + size_t i; + + for (i = 0; i < nbytes; i++) { + crc32 = (crc32 >> 8) ^ crc_table[ (crc32 ^ data[i]) & 0xFF ]; + } + + return (crc32 ^ 0xFFFFFFFF); +} + +/* --------------------------------------- */ + +int stun_compute_hmac_sha1(uint8_t* message, uint32_t nbytes, std::string key, uint8_t* output) { + + int r = 0; + mbedtls_md_context_t md_ctx = { 0 }; + const mbedtls_md_info_t *md_info = NULL; + + if (NULL == message) { + FAIL_MSG("Can't compute hmac_sha1 as the input message is empty."); + return -1; + } + + if (nbytes == 0) { + FAIL_MSG("Can't compute hmac_sha1 as the input length is invalid."); + return -2; + } + + if (key.size() == 0) { + FAIL_MSG("Can't compute the hmac_sha1 as the key size is 0."); + return -3; + } + + if (NULL == output) { + FAIL_MSG("Can't compute the hmac_sha as the output buffer is NULL."); + return -4; + } + + md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA1); + if (!md_info) { + FAIL_MSG("Failed to find the MBEDTLS_MD_SHA1"); + r = -5; + goto error; + } + + r = mbedtls_md_setup(&md_ctx, md_info, 1); + if (r != 0) { + FAIL_MSG("Failed to setup the md context."); + r = -6; + goto error; + } + + DONTEVEN_MSG("Calculating hmac-sha1 with key `%s` with size %zu over %zu bytes of data.", key.c_str(), key.size(), nbytes); + + r = mbedtls_md_hmac_starts(&md_ctx, (const unsigned char*)key.c_str(), key.size()); + if (r != 0) { + FAIL_MSG("Failed to start the hmac."); + r = -7; + goto error; + } + + r = mbedtls_md_hmac_update(&md_ctx, (const unsigned char*)message, nbytes); + if (r != 0) { + FAIL_MSG("Failed to update the hmac."); + r = -8; + goto error; + } + + r = mbedtls_md_hmac_finish(&md_ctx, output); + if (r != 0) { + FAIL_MSG("Failed to finish the hmac."); + r = -9; + goto error; + } + +#if 0 + printf("stun::compute_hmac_sha1 - verbose: computing hash over %u bytes, using key `%s`:\n", nbytes, key.c_str()); + printf("-----------------------------------\n\t0: "); + int nl = 0, lines = 0; + for (int i = 0; i < nbytes; ++i, ++nl) { + if (nl == 4) { + printf("\n\t"); + nl = 0; + lines++; + printf("%d: ", lines); + } + printf("%02X ", message[i]); + } + printf("\n-----------------------------------\n"); +#endif + +#if 0 + + printf("stun::compute_hmac_sha1 - verbose: computed hash: "); + int len = 20; + for(unsigned int i = 0; i < len; ++i) { + printf("%02X ", output[i]); + } + printf("\n"); +#endif + + error: + mbedtls_md_free(&md_ctx); + return r; +} + +int stun_compute_message_integrity(std::vector& buffer, std::string key, uint8_t* output) { + + uint16_t dx = 20; + uint16_t offset = 0; + uint16_t len = 0; + uint16_t type = 0; + uint8_t curr_size[2]; + + if (0 == buffer.size()) { + FAIL_MSG("Cannot compute message integrity; buffer empty."); + return -1; + } + + if (0 == key.size()) { + FAIL_MSG("Error: cannot compute message inegrity, key empty."); + return -2; + } + + curr_size[0] = buffer[2]; + curr_size[1] = buffer[3]; + + while (dx < buffer.size()) { + + type |= buffer[dx + 1] & 0x00FF; + type |= (buffer[dx + 0] << 8) & 0xFF00; + dx += 2; + + len |= (buffer[dx + 1] & 0x00FF); + len |= (buffer[dx + 0] << 8) & 0xFF00; + dx += 2; + + offset = dx; + dx += len; + + /* skip padding. 
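STUN attributes are aligned on 32-bit boundaries (RFC 5389), so advance dx to the next multiple of four.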
*/ + while ( (dx & 0x03) != 0 && dx < buffer.size()) { + dx++; + } + + if (type == STUN_ATTR_TYPE_MESSAGE_INTEGRITY) { + break; + } + + type = 0; + len = 0; + } + + /* rewrite Message-Length header field */ + buffer[2] = (offset >> 8) & 0xFF; + buffer[3] = offset & 0xFF; + + /* + and compute the sha1 + we subtract the last 4 bytes, which are the attribute-type and + attribute-length of the Message-Integrity field which are not + used. + */ + if (0 != stun_compute_hmac_sha1(&buffer[0], offset - 4, key, output)) { + buffer[2] = curr_size[0]; + buffer[3] = curr_size[1]; + return -3; + } + + /* rewrite message-length. */ + buffer[2] = curr_size[0]; + buffer[3] = curr_size[1]; + + return 0; +} + +int stun_compute_fingerprint(std::vector& buffer, uint32_t& result) { + + uint32_t dx = 20; + uint16_t offset = 0; + uint16_t len = 0; /* messsage-length */ + uint16_t type = 0; + uint8_t curr_size[2]; + + if (0 == buffer.size()) { + FAIL_MSG("Cannot compute fingerprint because the buffer is empty."); + return -1; + } + + /* copy current message-length */ + curr_size[0] = buffer[2]; + curr_size[1] = buffer[3]; + + /* compute the size that should be used as Message-Length when computing the CRC32 */ + while (dx < buffer.size()) { + + type |= buffer[dx + 1] & 0x00FF; + type |= (buffer[dx + 0] << 8) & 0xFF00; + dx += 2; + + len |= buffer[dx + 1] & 0x00FF; + len |= (buffer[dx + 0] << 8) & 0xFF00; + dx += 2; + + offset = dx; + dx += len; + + /* skip padding. */ + while ( (dx & 0x03) != 0 && dx < buffer.size()) { + dx++; + } + + if (type == STUN_ATTR_TYPE_FINGERPRINT) { + break; + } + + type = 0; + len = 0; + } + + /* rewrite message-length */ + offset -= 16; + buffer[2] = (offset >> 8) & 0xFF; + buffer[3] = offset & 0xFF; + + + // result = (checksum::crc32LE(0 ^ 0xFFFFFFFF, (const char*)&buffer[0], offset + 12) ^ 0xFFFFFFFF) ^ 0x5354554e; + result = poly_crc32(0L, &buffer[0], offset + 12) ^ 0x5354554e; + + /* and reset the size */ + buffer[2] = curr_size[0]; + buffer[3] = curr_size[1]; + + return 0; +} + +/* --------------------------------------- */ + +StunAttribute::StunAttribute() + :type(STUN_ATTR_TYPE_NONE) + ,length(0) +{ +} + +void StunAttribute::print() { + + DONTEVEN_MSG("StunAttribute.type: %s", stun_attribute_type_to_string(type).c_str()); + DONTEVEN_MSG("StunAttribute.length: %u", length); + + switch (type) { + + case STUN_ATTR_TYPE_XOR_MAPPED_ADDRESS: { + DONTEVEN_MSG("StunAttribute.xor_address.family: %s", stun_family_type_to_string(xor_address.family).c_str()); + DONTEVEN_MSG("StunAttribute.xor_address.port: %u", xor_address.port); + DONTEVEN_MSG("StunAttribute.xor_address.ip: %s", (char*)xor_address.ip); + break; + } + + case STUN_ATTR_TYPE_USERNAME: { + DONTEVEN_MSG("StunAttribute.username.value: `%.*s`", length, username.value); + break; + } + + case STUN_ATTR_TYPE_SOFTWARE: { + DONTEVEN_MSG("StunAttribute.software.value: `%.*s`", length, software.value); + break; + } + + case STUN_ATTR_TYPE_ICE_CONTROLLING: { + uint8_t* p = (uint8_t*)&ice_controlling.tie_breaker; + DONTEVEN_MSG("StunAttribute.ice_controlling.tie_breaker: 0x%04x%04x", *(uint32_t*)(p + 4), *(uint32_t*)(p)); + break; + } + + case STUN_ATTR_TYPE_PRIORITY: { + DONTEVEN_MSG("StunAttribute.priority.value: %u", priority.value); + break; + } + + case STUN_ATTR_TYPE_MESSAGE_INTEGRITY: { + std::stringstream ss; + for(int i = 0; i < 20; ++i) { + ss << std::hex << (int) message_integrity.sha1[i]; + } + std::string str = ss.str(); + DONTEVEN_MSG("StunAttribute.message_integrity.sha1: %s", str.c_str()); + break; + } + + case 
STUN_ATTR_TYPE_FINGERPRINT: { + DONTEVEN_MSG("StunAttribute.fingerprint.value: 0x%08x", fingerprint.value); + break; + } + } +} + +/* --------------------------------------- */ + +StunMessage::StunMessage() + :type(STUN_MSG_TYPE_NONE) + ,length(0) + ,cookie(0x2112a442) +{ + transaction_id[0] = 0; + transaction_id[1] = 0; + transaction_id[2] = 0; +} + +void StunMessage::setType(uint16_t messageType) { + type = messageType; +} + +void StunMessage::setTransactionId(uint32_t a, uint32_t b, uint32_t c) { + transaction_id[0] = a; + transaction_id[1] = b; + transaction_id[2] = c; +} + +void StunMessage::removeAttributes() { + attributes.clear(); +} + +void StunMessage::addAttribute(StunAttribute& attr) { + attributes.push_back(attr); +} + +void StunMessage::print() { + DONTEVEN_MSG("StunMessage.type: %s", stun_message_type_to_string(type).c_str()); + DONTEVEN_MSG("StunMessage.length: %u", length); + DONTEVEN_MSG("StunMessage.cookie: 0x%08X", cookie); + DONTEVEN_MSG("StunMessage.transaction_id: 0x%08X, 0x%08X, 0x%08X", transaction_id[0], transaction_id[1], transaction_id[2]); +} + +StunAttribute* StunMessage::getAttributeByType(uint16_t type) { + size_t nattribs = attributes.size(); + for (size_t i = 0; i < nattribs; ++i) { + if (attributes[i].type == type) { + return &attributes[i]; + } + } + return NULL; +} +/* --------------------------------------- */ +StunReader::StunReader() + :buffer_data(NULL) + ,buffer_size(0) + ,read_dx(0) +{ +} + +int StunReader::parse(uint8_t* data, size_t nbytes, size_t& nparsed, StunMessage& msg) { + + StunAttribute attr; + size_t attr_offset = 0; + nparsed = 0; + + if (NULL == data) { + FAIL_MSG("Cannot parse stun message because given data ptr is a NULL."); + return -1; + } + + if (nbytes < 20) { + FAIL_MSG("Cannot parse stun message because given nbytes is < 20."); + return -2; + } + + buffer_data = data; + buffer_size = nbytes; + read_dx = 0; + + /* Read stun header. */ + msg.type = readU16(); + msg.length = readU16(); + msg.cookie = readU32(); + msg.transaction_id[0] = readU32(); + msg.transaction_id[1] = readU32(); + msg.transaction_id[2] = readU32(); + + if ((nbytes - 20) < msg.length) { + FAIL_MSG("Buffer is too small to contain the full stun message."); + return -3; + } + + /* Read all the attributes. 
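Each attribute is encoded as a 16-bit type, a 16-bit length and a value padded up to a 32-bit boundary; attribute types we do not recognize are skipped in the switch below.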
*/ + while ((read_dx + 4) < buffer_size) { + + attr.type = readU16(); + attr.length = readU16(); + attr_offset = read_dx; + + switch (attr.type) { + + case STUN_ATTR_TYPE_USERNAME: { + if (0 != parseUsername(attr)) { + FAIL_MSG("Failed to read the username."); + return -4; + } + break; + } + + case STUN_ATTR_TYPE_XOR_MAPPED_ADDRESS: { + if (0 != parseXorMappedAddress(attr)) { + FAIL_MSG("Failed to read the xor-mapped-address."); + return -4; + } + break; + } + + case STUN_ATTR_TYPE_ICE_CONTROLLING: { + if (0 != parseIceControlling(attr)) { + FAIL_MSG("Failed to read the ice-contontrolling attribute."); + return -4; + } + break; + } + + case STUN_ATTR_TYPE_PRIORITY: { + if (0 != parsePriority(attr)) { + FAIL_MSG("Failed to read the priority attribute."); + return -4; + } + break; + } + + case STUN_ATTR_TYPE_MESSAGE_INTEGRITY: { + if (0 != parseMessageIntegrity(attr)) { + FAIL_MSG("Failed to parse the message integrity."); + return -4; + } + break; + } + + case STUN_ATTR_TYPE_FINGERPRINT: { + if (0 != parseFingerprint(attr)) { + FAIL_MSG("Failed to parse the fingerprint."); + return -4; + } + break; + } + + case STUN_ATTR_TYPE_SOFTWARE: { + if (0 != parseSoftware(attr)) { + FAIL_MSG("Failed to parse the software attribute."); + return -4; + } + break; + } + + default: { + DONTEVEN_MSG("Unhandled stun attribute: 0x%04X, %s", attr.type, stun_attribute_type_to_string(attr.type).c_str()); + break; + } + } + + /* Move the read_dx so it's positioned after the currently parsed attribute */ + read_dx = attr_offset + attr.length; + while ( (read_dx & 0x03) != 0 && (read_dx < buffer_size)) { + read_dx++; + } + + msg.attributes.push_back(attr); + + attr.print(); + } + + nparsed = read_dx; + + return 0; +} + +/* --------------------------------------- */ + +int StunReader::parseFingerprint(StunAttribute& attr) { + + if ((read_dx + 4) > buffer_size) { + FAIL_MSG("Cannot read FINGERPRINT because the buffer is too small."); + return -1; + } + + attr.fingerprint.value = readU32(); + + return 0; +} + +int StunReader::parseMessageIntegrity(StunAttribute& attr) { + + if ((read_dx + 20) > buffer_size) { + FAIL_MSG("Cannot read the MESSAGE-INTEGRITY because the buffer is too small."); + return -1; + } + + attr.message_integrity.sha1 = buffer_data + read_dx; + + return 0; +} + +int StunReader::parsePriority(StunAttribute& attr) { + + if ((read_dx + 4) > buffer_size) { + FAIL_MSG("Cannot read the PRIORITY attribute because the buffer is too small."); + return -1; + } + + attr.priority.value = readU32(); + + return 0; +} + +int StunReader::parseSoftware(StunAttribute& attr) { + + if ((read_dx + attr.length) > buffer_size) { + FAIL_MSG("Cannot read SOFTWARE attribute because the buffer is too small."); + return -1; + } + + attr.software.value = (char*)(buffer_data + read_dx); + + return 0; +} + +int StunReader::parseIceControlling(StunAttribute& attr) { + + if ((read_dx + 8) > buffer_size) { + FAIL_MSG("Cannot read the ICE-CONTROLLING attribute because the buffer is too small."); + return -1; + } + + attr.ice_controlling.tie_breaker = readU64(); + + return 0; +} + +int StunReader::parseUsername(StunAttribute& attr) { + + if ((read_dx + attr.length) > buffer_size) { + FAIL_MSG("Cannot read USRENAME attribute because the buffer is too small."); + return -1; + } + + attr.username.value = (char*)(buffer_data + read_dx); + + return 0; +} + +int StunReader::parseXorMappedAddress(StunAttribute& attr) { + + if ( (read_dx + 8) > buffer_size) { + FAIL_MSG("Cannot read XOR_MAPPED_ADDRESS because the buffer is too small."); + 
return -1; + } + + /* Skip the first byte, should be ignored by readers. */ + read_dx++; + + /* Read family */ + attr.xor_address.family = readU8(); + + if (STUN_IP4 != attr.xor_address.family) { + FAIL_MSG("Currently we only implemented the IP4 XOR_MAPPED_ADDRESS"); + return -2; + } + + uint8_t cookie[] = { 0x42, 0xA4, 0x12, 0x21 }; + uint32_t ip = 0; + uint8_t* ip_ptr = (uint8_t*) &ip; + uint8_t* port_ptr = (uint8_t*) &attr.xor_address.port; + + /* Read the port. */ + attr.xor_address.port = readU16(); + port_ptr[0] = port_ptr[0] ^ cookie[2]; + port_ptr[1] = port_ptr[1] ^ cookie[3]; + + /* Read IP4. */ + ip = readU32(); + ip_ptr[0] = ip_ptr[0] ^ cookie[0]; + ip_ptr[1] = ip_ptr[1] ^ cookie[1]; + ip_ptr[2] = ip_ptr[2] ^ cookie[2]; + ip_ptr[3] = ip_ptr[3] ^ cookie[3]; + + sprintf((char*)attr.xor_address.ip, + "%u.%u.%u.%u", + ip_ptr[3], + ip_ptr[2], + ip_ptr[1], + ip_ptr[0]); + + return 0; +} + +/* --------------------------------------- */ + +uint8_t StunReader::readU8() { + + if ( (read_dx + 1) > buffer_size) { + FAIL_MSG("Cannot readU8(), out of bounds."); + return 0; + } + + uint8_t v = 0; + v = buffer_data[read_dx]; + read_dx = read_dx + 1; + + return v; +} + +uint16_t StunReader::readU16() { + + if ( (read_dx + 2) > buffer_size) { + FAIL_MSG("Cannot readU16(), out of bounds."); + return 0; + } + + uint16_t v = 0; + uint8_t* p = (uint8_t*)&v; + p[0] = buffer_data[read_dx + 1]; + p[1] = buffer_data[read_dx + 0]; + read_dx = read_dx + 2; + + return v; +} + +uint32_t StunReader::readU32() { + + if ( (read_dx + 4) > buffer_size) { + FAIL_MSG("Cannot readU32(), out of bounds."); + return 0; + } + + uint32_t v = 0; + uint8_t* p = (uint8_t*)&v; + p[0] = buffer_data[read_dx + 3]; + p[1] = buffer_data[read_dx + 2]; + p[2] = buffer_data[read_dx + 1]; + p[3] = buffer_data[read_dx + 0]; + read_dx = read_dx + 4; + + return v; +} + +uint64_t StunReader::readU64() { + + if ( (read_dx + 8) > buffer_size) { + FAIL_MSG("Cannot readU64(), out of bounds."); + return 0; + } + + uint64_t v = 0; + uint8_t* p = (uint8_t*)&v; + + p[0] = buffer_data[read_dx + 7]; + p[1] = buffer_data[read_dx + 6]; + p[2] = buffer_data[read_dx + 5]; + p[3] = buffer_data[read_dx + 4]; + p[4] = buffer_data[read_dx + 3]; + p[5] = buffer_data[read_dx + 2]; + p[6] = buffer_data[read_dx + 1]; + p[7] = buffer_data[read_dx + 0]; + + read_dx = read_dx + 8; + + return v; +} + +/* --------------------------------------- */ + +StunWriter::StunWriter() + :padding_byte(0) +{ +} + +int StunWriter::begin(StunMessage& msg, uint8_t paddingByte) { + + /* set the byte that we use when adding padding. */ + padding_byte = paddingByte; + + /* make sure we start with an empty buffer. 
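begin() discards anything that was written earlier and writes a fresh 20-byte STUN header; end() later patches the length field once all attributes have been added.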
*/ + buffer.clear(); + + /* writer header */ + writeU16(msg.type); /* type */ + writeU16(0); /* length */ + writeU32(msg.cookie); /* magic cookie */ + writeU32(msg.transaction_id[0]); /* transaction id */ + writeU32(msg.transaction_id[1]); /* transaction id */ + writeU32(msg.transaction_id[2]); /* transaction id */ + + return 0; +} + +int StunWriter::end() { + + if (buffer.size() < 20) { + FAIL_MSG("Cannot finalize the stun message because the header wasn't written."); + return -1; + } + + rewriteU16(2, buffer.size() - 20); + + return 0; +} + +/* --------------------------------------- */ + +int StunWriter::writeXorMappedAddress(sockaddr_in addr) { + + if (AF_INET != addr.sin_family) { + FAIL_MSG("Currently we only support ip4 xor-mapped-address attributes."); + return -1; + } + + return writeXorMappedAddress(STUN_IP4, ntohs(addr.sin_port), ntohl(addr.sin_addr.s_addr)); +} + +int StunWriter::writeXorMappedAddress(uint8_t family, uint16_t port, const std::string& ip) { + + uint32_t ip_int = 0; + if (0 != convertIp4StringToInt(ip, ip_int)) { + FAIL_MSG("Cannot write xor-mapped-address, because we failed to convert the given IP4 string into a uint32_t."); + return -1; + } + + return writeXorMappedAddress(family, port, ip_int); +} + +/* `ip` is in host byte order. */ +int StunWriter::writeXorMappedAddress(uint8_t family, uint16_t port, uint32_t ip) { + + if (buffer.size() < 20) { + FAIL_MSG("Cannot write the xor-mapped-address. Make sure you wrote the header first."); + return -1; + } + + if (STUN_IP4 != family) { + FAIL_MSG("Cannot write the xor-mapped-address, we only support ip4 for now."); + return -2; + } + + /* xor the port */ + uint8_t cookie[] = { 0x42, 0xA4, 0x12, 0x21 }; + uint8_t* port_ptr = (uint8_t*)&port; + port_ptr[0] = port_ptr[0] ^ cookie[2]; + port_ptr[1] = port_ptr[1] ^ cookie[3]; + + /* xor the ip */ + uint8_t* ip_ptr = (uint8_t*)&ip; + ip_ptr[0] = ip_ptr[0] ^ cookie[0]; + ip_ptr[1] = ip_ptr[1] ^ cookie[1]; + ip_ptr[2] = ip_ptr[2] ^ cookie[2]; + ip_ptr[3] = ip_ptr[3] ^ cookie[3]; + + /* write header */ + writeU16(STUN_ATTR_TYPE_XOR_MAPPED_ADDRESS); + writeU16(8); + writeU8(0); + writeU8(family); + + /* port and ip */ + writeU16(port); + writeU32(ip); + + writePadding(); + + return 0; +} + +int StunWriter::writeUsername(const std::string& username) { + + if (buffer.size() < 20) { + FAIL_MSG("Cannot write username because you didn't call `begin()` and the STUN header hasn't been written yet.."); + return -1; + } + + writeU16(STUN_ATTR_TYPE_USERNAME); + writeU16(username.size()); + writeString(username); + writePadding(); + + return 0; +} + +int StunWriter::writeSoftware(const std::string& software) { + + if (buffer.size() < 20) { + FAIL_MSG("Cannot write software because it seems that you didn't call `begin()` which writes the stun header."); + return -1; + } + + if (software.size() > 763) { + FAIL_MSG("Given software length is too big. "); + return -2; + } + + writeU16(STUN_ATTR_TYPE_SOFTWARE); + writeU16(software.size()); + writeString(software); + writePadding(); + + return 0; +} + +int StunWriter::writeMessageIntegrity(const std::string& password) { + + if (buffer.size() < 20) { + FAIL_MSG("Cannot write the message integrity because it seems that you didn't call `begin()` which writes the stun header."); + return -1; + } + + if (0 == password.size()) { + FAIL_MSG("The password is empty, cannot write the message integrity."); + return -2; + } + + writeU16(STUN_ATTR_TYPE_MESSAGE_INTEGRITY); + writeU16(20); + + /* calculate the sha1 over the current buffer. 
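Per RFC 5389 section 15.4 the HMAC-SHA1 must be computed with the message length adjusted to cover the MESSAGE-INTEGRITY attribute itself; stun_compute_message_integrity() takes care of rewriting and restoring that length field.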
*/ + uint8_t sha1[20] = {}; + if (0 != stun_compute_message_integrity(buffer, password, sha1)) { + FAIL_MSG("Failed to write the message integrity."); + return -3; + } + + /* store the message-integrity */ + std::copy(sha1, sha1 + 20, std::back_inserter(buffer)); + + writePadding(); + + return 0; +} + +/* https://tools.ietf.org/html/rfc5389#section-15.5 */ +int StunWriter::writeFingerprint() { + + if (buffer.size() < 20) { + FAIL_MSG("Cannot write the fingerprint because it seems that you didn't write the header, call `begin()` first."); + return -1; + } + + writeU16(STUN_ATTR_TYPE_FINGERPRINT); + writeU16(4); + + uint32_t fingerprint = 0; + if (0 != stun_compute_fingerprint(buffer, fingerprint)) { + FAIL_MSG("Failed to compute the fingerprint."); + return -2; + } + + writeU32(fingerprint); + writePadding(); + + return 0; +} + +/* --------------------------------------- */ + +int StunWriter::convertIp4StringToInt(const std::string& ip, uint32_t& result) { + + if (0 == ip.size()) { + FAIL_MSG("Given ip string is empty."); + return -1; + } + + in_addr addr; + if (1 != inet_pton(AF_INET, ip.c_str(), &addr)) { + FAIL_MSG("inet_pton() failed, cannot convert ip4 string into uint32_t."); + return -2; + } + + result = ntohl(addr.s_addr); + + return 0; +} + +/* --------------------------------------- */ + +void StunWriter::writeU8(uint8_t v) { + + buffer.push_back(v); +} + +void StunWriter::writeU16(uint16_t v) { + + uint8_t* p = (uint8_t*)&v; + buffer.push_back(p[1]); + buffer.push_back(p[0]); +} + +void StunWriter::writeU32(uint32_t v) { + + uint8_t* p = (uint8_t*)&v; + buffer.push_back(p[3]); + buffer.push_back(p[2]); + buffer.push_back(p[1]); + buffer.push_back(p[0]); +} + +void StunWriter::writeU64(uint64_t v) { + + uint8_t* p = (uint8_t*)&v; + buffer.push_back(p[7]); + buffer.push_back(p[6]); + buffer.push_back(p[5]); + buffer.push_back(p[4]); + buffer.push_back(p[3]); + buffer.push_back(p[2]); + buffer.push_back(p[1]); + buffer.push_back(p[0]); +} + +void StunWriter::rewriteU16(size_t dx, uint16_t v) { + + if ((dx + 2) > buffer.size()) { + FAIL_MSG("Trying to rewriteU16, but our buffer is too small to contain a u16."); + return; + } + + uint8_t* p = (uint8_t*) &v; + buffer[dx + 0] = p[1]; + buffer[dx + 1] = p[0]; +} + +void StunWriter::rewriteU32(size_t dx, uint32_t v) { + + if ((dx + 4) > buffer.size()) { + FAIL_MSG("Trying to rewrite U32 in Stun::StunWriter::rewriteU32() but index is out of bounds.\n"); + return; + } + + uint8_t* p = (uint8_t*)&v; + buffer[dx + 0] = p[3]; + buffer[dx + 1] = p[2]; + buffer[dx + 2] = p[1]; + buffer[dx + 3] = p[0]; +} + +void StunWriter::writeString(const std::string& str) { + std::copy(str.begin(), str.end(), std::back_inserter(buffer)); +} + +void StunWriter::writePadding() { + + while ((buffer.size() & 0x03) != 0) { + buffer.push_back(padding_byte); + } +} + +/* --------------------------------------- */ diff --git a/lib/stun.h b/lib/stun.h new file mode 100644 index 00000000..0e69f9b4 --- /dev/null +++ b/lib/stun.h @@ -0,0 +1,250 @@ +#pragma once +#include +#include +#include +#include + +/* --------------------------------------- */ + +#define STUN_IP4 0x01 +#define STUN_IP6 0x02 + +#define STUN_MSG_TYPE_NONE 0x0000 +#define STUN_MSG_TYPE_BINDING_REQUEST 0x0001 +#define STUN_MSG_TYPE_BINDING_RESPONSE_SUCCESS 0x0101 +#define STUN_MSG_TYPE_BINDING_RESPONSE_ERROR 0x0111 +#define STUN_MSG_TYPE_BINDING_INDICATION 0x0011 + +#define STUN_ATTR_TYPE_NONE 0x0000 +#define STUN_ATTR_TYPE_MAPPED_ADDR 0x0001 +#define STUN_ATTR_TYPE_CHANGE_REQ 0x0003 +#define 
STUN_ATTR_TYPE_USERNAME 0x0006 +#define STUN_ATTR_TYPE_MESSAGE_INTEGRITY 0x0008 +#define STUN_ATTR_TYPE_ERR_CODE 0x0009 +#define STUN_ATTR_TYPE_UNKNOWN_ATTRIBUTES 0x000a +#define STUN_ATTR_TYPE_CHANNEL_NUMBER 0x000c +#define STUN_ATTR_TYPE_LIFETIME 0x000d +#define STUN_ATTR_TYPE_XOR_PEER_ADDR 0x0012 +#define STUN_ATTR_TYPE_DATA 0x0013 +#define STUN_ATTR_TYPE_REALM 0x0014 +#define STUN_ATTR_TYPE_NONCE 0x0015 +#define STUN_ATTR_TYPE_XOR_RELAY_ADDRESS 0x0016 +#define STUN_ATTR_TYPE_REQ_ADDRESS_FAMILY 0x0017 +#define STUN_ATTR_TYPE_EVEN_PORT 0x0018 +#define STUN_ATTR_TYPE_REQUESTED_TRANSPORT 0x0019 +#define STUN_ATTR_TYPE_DONT_FRAGMENT 0x001a +#define STUN_ATTR_TYPE_XOR_MAPPED_ADDRESS 0x0020 +#define STUN_ATTR_TYPE_RESERVATION_TOKEN 0x0022 +#define STUN_ATTR_TYPE_PRIORITY 0x0024 +#define STUN_ATTR_TYPE_USE_CANDIDATE 0x0025 +#define STUN_ATTR_TYPE_PADDING 0x0026 +#define STUN_ATTR_TYPE_RESPONSE_PORT 0x0027 +#define STUN_ATTR_TYPE_SOFTWARE 0x8022 +#define STUN_ATTR_TYPE_ALTERNATE_SERVER 0x8023 +#define STUN_ATTR_TYPE_FINGERPRINT 0x8028 +#define STUN_ATTR_TYPE_ICE_CONTROLLED 0x8029 +#define STUN_ATTR_TYPE_ICE_CONTROLLING 0x802a +#define STUN_ATTR_TYPE_RESPONSE_ORIGIN 0x802b +#define STUN_ATTR_TYPE_OTHER_ADDRESS 0x802c + +/* --------------------------------------- */ + +std::string stun_message_type_to_string(uint16_t type); +std::string stun_attribute_type_to_string(uint16_t type); +std::string stun_family_type_to_string(uint8_t type); + +/* + Compute the hmac-sha1 over message. + uint8_t* message: the data over which we compute the hmac sha + uint32_t nbytes: the number of bytse in message + std::string key: key to use for hmac + uint8_t* output: we write the sha1 into this buffer. +*/ +int stun_compute_hmac_sha1(uint8_t* message, uint32_t nbytes, std::string key, uint8_t* output); + +/* + Compute the Message-Integrity of a stun message. + This will not change the given buffer. + + std::vector& buffer: the buffer that contains a valid stun message + std::string key: key to use for hmac + uint8_t* output: will be filled with the correct hmac-sha1 of that represents the integrity message value. +*/ +int stun_compute_message_integrity(std::vector& buffer, std::string key, uint8_t* output); + +/* + Compute the fingerprint value for the stun message. + This will not change the given buffer. + std::vector& buffer: the buffer that contains a valid stun message. + uint32_t& result: will be set to the calculated crc value. +*/ +int stun_compute_fingerprint(std::vector& buffer, uint32_t& result); + +/* --------------------------------------- */ + +/* https://tools.ietf.org/html/rfc5389#section-15.10 */ +class StunAttribSoftware { +public: + char* value; +}; + +class StunAttribFingerprint { +public: + uint32_t value; +}; + +/* https://tools.ietf.org/html/rfc5389#section-15.4 */ +class StunAttribMessageIntegrity { +public: + uint8_t* sha1; +}; + +/* https://tools.ietf.org/html/rfc5245#section-19.1 */ +class StunAttribPriority { +public: + uint32_t value; +}; + +/* https://tools.ietf.org/html/rfc5245#section-19.1 */ +class StunAttribIceControllling { +public: + uint64_t tie_breaker; +}; + +/* https://tools.ietf.org/html/rfc3489#section-11.2.6 */ +class StunAttribUsername { +public: + char* value; /* Must use `length` member of attribute that indicates the number of valid bytes in the username. 
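The pointer references the parse buffer directly and the string is not NUL-terminated.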
*/ +}; + +/* https://tools.ietf.org/html/rfc5389#section-15.2 */ +class StunAttribXorMappedAddress { +public: + uint8_t family; + uint16_t port; + uint8_t ip[16]; +}; + +/* --------------------------------------- */ + +class StunAttribute { +public: + StunAttribute(); + void print(); + +public: + uint16_t type; + uint16_t length; + union { + StunAttribXorMappedAddress xor_address; + StunAttribUsername username; + StunAttribIceControllling ice_controlling; + StunAttribPriority priority; + StunAttribSoftware software; + StunAttribMessageIntegrity message_integrity; + StunAttribFingerprint fingerprint; + }; +}; + +/* --------------------------------------- */ + +class StunMessage { +public: + StunMessage(); + void setType(uint16_t type); + void setTransactionId(uint32_t a, uint32_t b, uint32_t c); + void print(); + void addAttribute(StunAttribute& attr); + void removeAttributes(); + StunAttribute* getAttributeByType(uint16_t type); + +public: + uint16_t type; + uint16_t length; + uint32_t cookie; + uint32_t transaction_id[3]; + std::vector attributes; +}; + +/* --------------------------------------- */ + +class StunReader { +public: + StunReader(); + int parse(uint8_t* data, size_t nbytes, size_t& nparsed, StunMessage& msg); /* `nparsed` and `msg` are filled. */ + +private: + int parseXorMappedAddress(StunAttribute& attr); + int parseUsername(StunAttribute& attr); + int parseIceControlling(StunAttribute& attr); + int parsePriority(StunAttribute& attr); + int parseSoftware(StunAttribute& attr); + int parseMessageIntegrity(StunAttribute& attr); + int parseFingerprint(StunAttribute& attr); + + uint8_t readU8(); + uint16_t readU16(); + uint32_t readU32(); + uint64_t readU64(); + +private: + uint8_t* buffer_data; + size_t buffer_size; + size_t read_dx; +}; + +/* --------------------------------------- */ + +class StunWriter { +public: + StunWriter(); + + /* write header and finalize. call for each stun message */ + int begin(StunMessage& msg, uint8_t paddingByte = 0x00); /* I've added the padding byte here so that we can use the examples that can be found here https://tools.ietf.org/html/rfc5769#section-2.2 as they use 0x20 or 0x00 as the padding byte which is correct as you are free to use w/e padding byte you want. */ + int end(); + + /* write attributes */ + int writeXorMappedAddress(sockaddr_in addr); + int writeXorMappedAddress(uint8_t family, uint16_t port, uint32_t ip); + int writeXorMappedAddress(uint8_t family, uint16_t port, const std::string& ip); + int writeUsername(const std::string& username); + int writeSoftware(const std::string& software); + int writeMessageIntegrity(const std::string& password); /* When using WebRtc this is the ice-upwd of the other agent. */ + int writeFingerprint(); /* Must be the last attribute in the message. When adding a fingerprint, make sure that it is added after the message-integrity (when you also use a message-integrity). 
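The fingerprint value is the CRC-32 of the message up to the FINGERPRINT attribute, XOR-ed with 0x5354554e (the ASCII string 'STUN'), as described in the RFC section linked above.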
*/ + + /* get buffer */ + uint8_t* getBufferPtr(); + size_t getBufferSize(); + +private: + void writeU8(uint8_t v); + void writeU16(uint16_t v); + void writeU32(uint32_t v); + void writeU64(uint64_t v); + void rewriteU16(size_t dx, uint16_t v); + void rewriteU32(size_t dx, uint32_t v); + void writeString(const std::string& str); + void writePadding(); + int convertIp4StringToInt(const std::string& ip, uint32_t& result); + +private: + std::vector buffer; + uint8_t padding_byte; +}; + +/* --------------------------------------- */ + +inline uint8_t* StunWriter::getBufferPtr() { + + if (0 == buffer.size()) { + return NULL; + } + + return &buffer[0]; +} + +inline size_t StunWriter::getBufferSize() { + return buffer.size(); +} + +/* --------------------------------------- */ diff --git a/scripts/webrtc_compile.sh b/scripts/webrtc_compile.sh new file mode 100755 index 00000000..9099478f --- /dev/null +++ b/scripts/webrtc_compile.sh @@ -0,0 +1,52 @@ +#!/bin/sh + +pd=${PWD} +d=${PWD}/../ +config="Release" + +if [ ! -d ${d}/external ] ; then + mkdir ${d}/external +fi + +if [ ! -d ${d}/external/mbedtls ] ; then + #prepare mbedtls for build + cd ${d}/external/ + git clone https://github.com/diederickh/mbedtls + + cd ${d}/external/mbedtls + git checkout -b dtls_srtp_support + git merge 15179bfbaa794506c06f923f85d7c71f0dfd89e9 + + git am < ${pd}/webrtc_mbedtls_keying_material_fix.diff + if [ $? -ne 0 ] ; then + echo "Failed to apply patch" + exit + fi +fi + +if [ ! -d ${d}/build ] ; then + mkdir ${d}/build +fi + +if [ ! -d ${d}/installed ] ; then + mkdir ${d}/installed + #Build mbedtls + mkdir -p ${d}/external/mbedtls/build + cd ${d}/external/mbedtls/build + cmake -DCMAKE_INSTALL_PREFIX=${d}/installed -DENABLE_PROGRAMS=Off .. + cmake --build . --config ${config} --target install -- -j 8 +fi + + +cd ${d} +export PATH="${PATH}:${d}/installed/include" +cmake -DCMAKE_CXX_FLAGS="-I${d}/installed/include/ -L${d}/installed/lib/" \ + -DCMAKE_PREFIX_PATH=${d}/installed/include \ + -DCMAKE_MODULE_PATH=${d}/installed/ \ + -DPERPETUAL=1 \ + -DDEBUG=3 \ + -GNinja \ + . 
+ +ninja + diff --git a/scripts/webrtc_mbedtls_keying_material_fix.diff b/scripts/webrtc_mbedtls_keying_material_fix.diff new file mode 100644 index 00000000..89ced3e6 --- /dev/null +++ b/scripts/webrtc_mbedtls_keying_material_fix.diff @@ -0,0 +1,34 @@ +From ba52913047a6821dac15f8320c8857cef589bb6f Mon Sep 17 00:00:00 2001 +From: roxlu +Date: Mon, 2 Jul 2018 22:26:21 +0200 +Subject: [PATCH] Fixes to get DTLS SRTP to work with WebRTC + +--- + library/ssl_tls.c | 4 +--- + 1 file changed, 1 insertion(+), 3 deletions(-) + +diff --git a/library/ssl_tls.c b/library/ssl_tls.c +index fe27c6a8..25b86da8 100644 +--- a/library/ssl_tls.c ++++ b/library/ssl_tls.c +@@ -6436,7 +6436,6 @@ mbedtls_ssl_srtp_profile mbedtls_ssl_get_dtls_srtp_protection_profile( const mbe + } + + int mbedtls_ssl_get_dtls_srtp_key_material( const mbedtls_ssl_context *ssl, unsigned char *key, size_t *key_len ) { +- *key_len = 0; + + /* check output buffer size */ + if ( *key_len < ssl->dtls_srtp_info.dtls_srtp_keys_len) { +@@ -7706,8 +7705,7 @@ void mbedtls_ssl_free( mbedtls_ssl_context *ssl ) + #endif + + #if defined (MBEDTLS_SSL_DTLS_SRTP) +- mbedtls_zeroize( ssl->dtls_srtp_info.dtls_srtp_keys, ssl->dtls_srtp_info.dtls_srtp_keys_len ); +- // mbedtls_free( ssl->dtls_srtp_keys ); ++ mbedtls_platform_zeroize( ssl->dtls_srtp_info.dtls_srtp_keys, ssl->dtls_srtp_info.dtls_srtp_keys_len ); + #endif /* MBEDTLS_SSL_DTLS_SRTP */ + + MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= free" ) ); +-- +2.17.1 + diff --git a/scripts/webrtc_run.sh b/scripts/webrtc_run.sh new file mode 100755 index 00000000..0e5b1b00 --- /dev/null +++ b/scripts/webrtc_run.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +cd ${PWD}/../build +set -x +#export MIST_CONTROL=1 +make MistOutWebRTC + +if [ $? -ne 0 ] ; then + echo "Failed to compile." + exit +fi + +#-fsanitize=address +#export MALLOC_CHECK_=2 +# valgrind --trace-children=yes +# ASAN_OPTIONS=symbolize=1 ASAN_SYMBOLIZER_PATH=$(shell which llvm-symbolizer) +if [ 0 -ne 0 ] ; then + reset && valgrind --trace-children=yes ./MistOutHTTPS \ + --port 4433 \ + --cert ~/.ssh/certs/arch680.rox.lu.crt \ + --key ~/.ssh/certs/arch680.rox.lu.key \ + --debug 10 +else + reset && ./MistOutHTTPS \ + --port 4433 \ + --cert ~/.ssh/certs/arch680.rox.lu.crt \ + --key ~/.ssh/certs/arch680.rox.lu.key \ + --debug 10 +fi + diff --git a/scripts/webrtc_srtp_cmakelists.txt b/scripts/webrtc_srtp_cmakelists.txt new file mode 100644 index 00000000..58e2d588 --- /dev/null +++ b/scripts/webrtc_srtp_cmakelists.txt @@ -0,0 +1,112 @@ +cmake_minimum_required(VERSION 3.8) +project(srtp2) +set(bd ${CMAKE_CURRENT_LIST_DIR}) +set(sd ${bd}) + +list(APPEND lib_sources + ${sd}/srtp/srtp.c + ${sd}/srtp/ekt.c + ${sd}/crypto/kernel/alloc.c + ${sd}/crypto/kernel/err.c + ${sd}/crypto/kernel/crypto_kernel.c + ${sd}/crypto/kernel/key.c + ${sd}/crypto/math/datatypes.c + ${sd}/crypto/math/stat.c + ${sd}/crypto/replay/rdbx.c + ${sd}/crypto/replay/rdb.c + ${sd}/crypto/replay/ut_sim.c + ${sd}/crypto/cipher/cipher.c + ${sd}/crypto/cipher/null_cipher.c + ${sd}/crypto/cipher/aes.c + ${sd}/crypto/hash/auth.c + ${sd}/crypto/hash/null_auth.c + ${sd}/crypto/cipher/aes_icm.c + ${sd}/crypto/hash/sha1.c + ${sd}/crypto/hash/hmac.c + ) + +# -- start of checks + +include(CheckIncludeFiles) +include(CheckFunctionExists) +include(CheckLibraryExists) +include(CheckTypeSize) +include(TestBigEndian) + +set(AC_APPLE_UNIVERSAL_BUILD 0) +set(CPU_CISC 1) +set(CPU_RISC 0) +set(ENABLE_DEBUG_LOGGING 0) +set(ERR_REPORTING_FILE "libsrtp_error.log") +set(ERR_REPORTING_STDOUT 0) +set(VERSION "2.3") + 
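+# The feature checks below mirror libsrtp's configure.ac; their results are substituted into crypto/include/config.h through the config.cmake template (see the configure_file() call further down).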
+check_include_files(arpa/inet.h HAVE_ARPA_INET_H) +check_include_files(byteswap.h HAVE_BYTESWAP_H) +check_function_exists(inet_aton HAVE_INET_ATON) +check_type_size(int16_t HAVE_INT16_T) +check_type_size(int32_t HAVE_INT32_T) +check_type_size(int8_t HAVE_INT8_T) +check_include_files(inttypes.h HAVE_INTTYPES_H) +check_library_exists(dl dlopen "" HAVE_LIBDL) +check_library_exists(socket socket "" HAVE_LIBSOCKET) +check_library_exists(z zlibVersion "" HAVE_LIBZ) +check_include_files(machine/types.h HAVE_MACHINE_TYPES_H) +check_include_files(memory.h HAVE_MEMORY_H) +check_include_files(netinet/in.h HAVE_NETINET_IN_H) +# @todo check winpcap +check_function_exists(sigaction HAVE_SIGACTION) +check_function_exists(socket HAVE_SOCKET) +check_include_files(stdint.h HAVE_STDINT_H) +check_include_files(stdlib.h HAVE_STDLIB_H) +check_include_files(strings.h HAVE_STRINGS_H) +check_include_files(string.h HAVE_STRING_H) +check_include_files(sys/int_types.h HAVE_SYS_INT_TYPES_H) +check_include_files(sys/socket.h HAVE_SYS_SOCKET_H) +check_include_files(sys/stat.h HAVE_SYS_STAT_H) +check_include_files(sys/types.h HAVE_SYS_TYPES_H) +check_include_files(sys/uio.h HAVE_SYS_UIO_H) +check_type_size(uint16_t HAVE_UINT16_T) +check_type_size(uint32_t HAVE_UINT32_T) +check_type_size(uint64_t HAVE_UINT64_T) +check_type_size(uint8_t HAVE_UINT8_T) +check_include_files(unistd.h HAVE_UNISTD_H) +check_function_exists(usleep HAVE_USLEEP) +check_include_files(windows.h HAVE_WINDOWS_H) +check_include_files(winsock2.h HAVE_WINSOCK2_H) +# @todo HAVE_X86 +# @todo OPENSSL +# @todo OPENSSL_CLEANSE_BROKEN +# @todo OPENSSL_KDF +# @todo PACKAGE_BUGREPORT +set(PACKAGE_BUGREPORT "testers@ddvdtech.com") +set(PACKAGE_NAME "libsrtp") +set(PACKAGE_VERSION "${VERSION}") +set(PACKAGE_STRING "${PACKAGE_NAME}_${VERSION}") +set(PACKAGE_TARNAME "${PACKAGE_STRING}.tar") +set(PACKAGE_URL "http://www.mistserver.org") +check_type_size("unsigned long" SIZEOF_UNSIGNED_LONG) +check_type_size("unsigned long long" SIZEOF_UNSIGNED_LONG_LONG) +check_include_files("stdlib.h;stdarg.h;string.h;float.h" STDC_HEADERS) +configure_file(${bd}/config.cmake ${bd}/crypto/include/config.h) + +#-------------------------------------------------------- + +include_directories( + ${bd}/include/ + ${bd}/crypto/ + ${bd}/crypto/include + ) + +add_library(srtp2 STATIC ${lib_sources}) +target_compile_definitions(srtp2 PUBLIC HAVE_CONFIG_H) + +list(APPEND include_files + ${bd}/include/srtp.h + ${bd}/crypto/include/cipher.h + ${bd}/crypto/include/auth.h + ${bd}/crypto/include/crypto_types.h + ) + +install(FILES ${include_files} DESTINATION include) +install(TARGETS srtp2 ARCHIVE DESTINATION lib) diff --git a/scripts/webrtc_srtp_config.cmake b/scripts/webrtc_srtp_config.cmake new file mode 100644 index 00000000..86cc387a --- /dev/null +++ b/scripts/webrtc_srtp_config.cmake @@ -0,0 +1,181 @@ +/* config_in.h. Generated from configure.ac by autoheader. */ + +/* Define if building universal (internal helper macro) */ +#cmakedefine AC_APPLE_UNIVERSAL_BUILD 1 + +/* Define if building for a CISC machine (e.g. Intel). */ +#cmakedefine CPU_CISC 1 + +/* Define if building for a RISC machine (assume slow byte access). */ +#cmakedefine CPU_RISC 1 + +/* Define to enabled debug logging for all mudules. */ +#cmakedefine ENABLE_DEBUG_LOGGING 1 + +/* Logging statments will be writen to this file. */ +#cmakedefine ERR_REPORTING_FILE "@ERR_REPORTING_FILE@" + +/* Define to redirect logging to stdout. */ +#cmakedefine ERR_REPORTING_STDOUT 1 + +/* Define to 1 if you have the header file. 
*/ +#cmakedefine HAVE_ARPA_INET_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_BYTESWAP_H 1 + +/* Define to 1 if you have the `inet_aton' function. */ +#cmakedefine HAVE_INET_ATON 1 + +/* Define to 1 if the system has the type `int16_t'. */ +#cmakedefine HAVE_INT16_T 1 + +/* Define to 1 if the system has the type `int32_t'. */ +#cmakedefine HAVE_INT32_T 1 + +/* Define to 1 if the system has the type `int8_t'. */ +#cmakedefine HAVE_INT8_T 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_INTTYPES_H 1 + +/* Define to 1 if you have the `dl' library (-ldl). */ +#cmakedefine HAVE_LIBDL 1 + +/* Define to 1 if you have the `socket' library (-lsocket). */ +#cmakedefine HAVE_LIBSOCKET 1 + +/* Define to 1 if you have the `z' library (-lz). */ +#cmakedefine HAVE_LIBZ 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_MACHINE_TYPES_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_MEMORY_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_NETINET_IN_H 1 + +/* Define to 1 if you have the `winpcap' library (-lwpcap) */ +#cmakedefine HAVE_PCAP 1 + +/* Define to 1 if you have the `sigaction' function. */ +#cmakedefine HAVE_SIGACTION 1 + +/* Define to 1 if you have the `socket' function. */ +#cmakedefine HAVE_SOCKET 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_STDINT_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_STDLIB_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_STRINGS_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_STRING_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_SYS_INT_TYPES_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_SYS_SOCKET_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_SYS_STAT_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_SYS_TYPES_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_SYS_UIO_H 1 + +/* Define to 1 if the system has the type `uint16_t'. */ +#cmakedefine HAVE_UINT16_T 1 + +/* Define to 1 if the system has the type `uint32_t'. */ +#cmakedefine HAVE_UINT32_T 1 + +/* Define to 1 if the system has the type `uint64_t'. */ +#cmakedefine HAVE_UINT64_T 1 + +/* Define to 1 if the system has the type `uint8_t'. */ +#cmakedefine HAVE_UINT8_T 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_UNISTD_H 1 + +/* Define to 1 if you have the `usleep' function. */ +#cmakedefine HAVE_USLEEP 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_WINDOWS_H 1 + +/* Define to 1 if you have the header file. */ +#cmakedefine HAVE_WINSOCK2_H 1 + +/* Define to use X86 inlined assembly code */ +#cmakedefine HAVE_X86 1 + +/* Define this to use OpenSSL crypto. */ +#cmakedefine OPENSSL 1 + +/* Define this if OPENSSL_cleanse is broken. */ +#cmakedefine OPENSSL_CLEANSE_BROKEN 1 + +/* Define this to use OpenSSL KDF for SRTP. */ +#cmakedefine OPENSSL_KDF 1 + +/* Define to the address where bug reports for this package should be sent. */ +#cmakedefine PACKAGE_BUGREPORT "@PACKAGE_BUGREPORT@" + +/* Define to the full name of this package. */ +#define PACKAGE_NAME "@PACKAGE_NAME@" + +/* Define to the full name and version of this package. */ +#define PACKAGE_STRING "@PACKAGE_STRING@" + +/* Define to the one symbol short name of this package. */ +#define PACKAGE_TARNAME "@PACKAGE_TARNAME@" + +/* Define to the home page for this package. 
*/ +#cmakedefine PACKAGE_URL "@PACKAGE_URL@" + +/* Define to the version of this package. */ +#define PACKAGE_VERSION "@PACKAGE_VERSION@" + +/* The size of a `unsigned long', as computed by sizeof. */ +#define SIZEOF_UNSIGNED_LONG @SIZEOF_UNSIGNED_LONG@ + +/* The size of a `unsigned long long', as computed by sizeof. */ +#define SIZEOF_UNSIGNED_LONG_LONG @SIZEOF_UNSIGNED_LONG_LONG@ + +/* Define to 1 if you have the ANSI C header files. */ +#cmakedefine STDC_HEADERS 1 + +/* Define WORDS_BIGENDIAN to 1 if your processor stores words with the most + significant byte first (like Motorola and SPARC, unlike Intel). */ +#if defined AC_APPLE_UNIVERSAL_BUILD +# if defined __BIG_ENDIAN__ +# define WORDS_BIGENDIAN 1 +# endif +#else +# ifndef WORDS_BIGENDIAN +# undef WORDS_BIGENDIAN +# endif +#endif + +/* Define to empty if `const' does not conform to ANSI C. */ +#undef const + +/* Define to `__inline__' or `__inline' if that's what the C compiler + calls it, or to nothing if 'inline' is not supported under any name. */ +#ifndef __cplusplus +#undef inline +#endif + +/* Define to `unsigned int' if does not define. */ +#undef size_t diff --git a/src/output/output_webrtc.cpp b/src/output/output_webrtc.cpp new file mode 100644 index 00000000..619afbf7 --- /dev/null +++ b/src/output/output_webrtc.cpp @@ -0,0 +1,1524 @@ +#include // ifaddr, listing ip addresses. +#include // ifaddr, listing ip addresses. +#include +#include "output_webrtc.h" + +namespace Mist{ + + OutWebRTC *classPointer = 0; + + /* ------------------------------------------------ */ + + static uint32_t generateSSRC(); + static void webRTCInputOutputThreadFunc(void* arg); + static int onDTLSHandshakeWantsToWriteCallback(const uint8_t* data, int* nbytes); + static void onDTSCConverterHasPacketCallback(const DTSC::Packet& pkt); + static void onDTSCConverterHasInitDataCallback(const uint64_t track, const std::string &initData); + static void onRTPSorterHasPacketCallback(const uint64_t track, const RTP::Packet &p); // when we receive RTP packets we store them in a sorter. Whenever there is a valid, sorted RTP packet that can be used this function is called. + static void onRTPPacketizerHasDataCallback(void* socket, char* data, unsigned int len, unsigned int channel); + static void onRTPPacketizerHasRTCPDataCallback(void* socket, const char* data, uint32_t nbytes); + static std::vector getLocalIP4Addresses(); + + /* ------------------------------------------------ */ + + WebRTCTrack::WebRTCTrack() + :payloadType(0) + ,SSRC(0) + ,timestampMultiplier(0) + ,ULPFECPayloadType(0) + ,REDPayloadType(0) + ,RTXPayloadType(0) + ,prevReceivedSequenceNumber(0) + { + } + + /* ------------------------------------------------ */ + + OutWebRTC::OutWebRTC(Socket::Connection &myConn) : HTTPOutput(myConn){ + + webRTCInputOutputThread = NULL; + udpPort = 0; + SSRC = generateSSRC(); + rtcpTimeoutInMillis = 0; + rtcpKeyFrameDelayInMillis = 2000; + rtcpKeyFrameTimeoutInMillis = 0; + rtpOutBuffer = NULL; + videoBitrate = 6 * 1000 * 1000; + RTP::MAX_SEND = 1200 - 28; + didReceiveKeyFrame = false; + + if (cert.init("NL", "webrtc", "webrtc") != 0) { + FAIL_MSG("Failed to create the certificate."); + exit(EXIT_FAILURE); + // \todo how do we handle this further? disconnect? + } + if (dtlsHandshake.init(&cert.cert, &cert.key, onDTLSHandshakeWantsToWriteCallback) != 0) { + FAIL_MSG("Failed to initialize the dtls-srtp handshake helper."); + exit(EXIT_FAILURE); + // \todo how do we handle this? 
+ } + + rtpOutBuffer = (char*)malloc(2048); + if (!rtpOutBuffer) { + // \todo Jaron how do you handle these cases? + FAIL_MSG("Failed to allocate our RTP output buffer."); + exit(EXIT_FAILURE); + } + + sdpAnswer.setFingerprint(cert.getFingerprintSha256()); + classPointer = this; + + setBlocking(false); + } + + OutWebRTC::~OutWebRTC() { + + if (webRTCInputOutputThread && webRTCInputOutputThread->joinable()) { + webRTCInputOutputThread->join(); + delete webRTCInputOutputThread; + webRTCInputOutputThread = NULL; + } + + if (rtpOutBuffer) { + free(rtpOutBuffer); + rtpOutBuffer = NULL; + } + + if (srtpReader.shutdown() != 0) { + FAIL_MSG("Failed to cleanly shutdown the srtp reader."); + } + if (srtpWriter.shutdown() != 0) { + FAIL_MSG("Failed to cleanly shutdown the srtp writer."); + } + if (dtlsHandshake.shutdown() != 0) { + FAIL_MSG("Failed to cleanly shutdown the dtls handshake."); + } + if (cert.shutdown() != 0) { + FAIL_MSG("Failed to cleanly shutdown the certificate."); + } + } + + // Initialize the WebRTC output. This is where we define what + // codes types are supported and accepted when parsing the SDP + // offer or when generating the SDP. The `capa` member is + // inherited from `Output`. + void OutWebRTC::init(Util::Config *cfg) { + HTTPOutput::init(cfg); + capa["name"] = "WebRTC"; + capa["desc"] = "Provides WebRTC output"; + capa["url_rel"] = "/webrtc/$"; + capa["url_match"] = "/webrtc/$"; + capa["codecs"][0u][0u].append("H264"); + capa["codecs"][0u][0u].append("VP8"); + capa["codecs"][0u][1u].append("opus"); + capa["methods"][0u]["handler"] = "webrtc"; + capa["methods"][0u]["type"] = "webrtc"; + capa["methods"][0u]["priority"] = 2ll; + + capa["optional"]["preferredvideocodec"]["name"] = "Preferred video codecs"; + capa["optional"]["preferredvideocodec"]["help"] = "Comma separated list of video codecs you want to support in preferred order. e.g. H264,VP8"; + capa["optional"]["preferredvideocodec"]["default"] = "H264,VP8"; + capa["optional"]["preferredvideocodec"]["type"] = "string"; + capa["optional"]["preferredvideocodec"]["option"] = "--webrtc-video-codecs"; + capa["optional"]["preferredvideocodec"]["short"] = "V"; + + capa["optional"]["preferredaudiocodec"]["name"] = "Preferred audio codecs"; + capa["optional"]["preferredaudiocodec"]["help"] = "Comma separated list of audio codecs you want to support in preferred order. e.g. OPUS"; + capa["optional"]["preferredaudiocodec"]["default"] = "OPUS"; + capa["optional"]["preferredaudiocodec"]["type"] = "string"; + capa["optional"]["preferredaudiocodec"]["option"] = "--webrtc-audio-codecs"; + capa["optional"]["preferredaudiocodec"]["short"] = "A"; + + config->addOptionsFromCapabilities(capa); + } + + // This function is executed when we receive a signaling data. + // The signaling data contains commands that are used to start + // an input or output stream. + void OutWebRTC::onWebsocketFrame() { + + if (!webSock) { + FAIL_MSG("We assume `webSock` is valid at this point."); + return; + } + + if (webSock->frameType == 0x8) { + HIGH_MSG("Not handling websocket data; close frame."); + return; + } + + JSON::Value msg = JSON::fromString(webSock->data, webSock->data.size()); + handleSignalingCommand(*webSock, msg); + } + + // This function is our first handler for commands that we + // receive via the WebRTC signaling channel (our websocket + // connection). We validate the command and check what + // command-specific function we need to call. 
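+ //
+ // As an illustrative (not normative) example, the JSON exchanged over the
+ // websocket for an SDP offer looks roughly like this, with the SDP shortened:
+ //
+ //   request: {"type": "offer_sdp", "offer_sdp": "v=0\r\n..."}
+ //   reply:   {"type": "on_answer_sdp", "result": true, "answer_sdp": "v=0\r\n..."}
+ //   error:   {"type": "on_answer_sdp", "result": false, "message": "Failed to handle the offer SDP."}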
+ void OutWebRTC::handleSignalingCommand(HTTP::Websocket& ws, const JSON::Value &command) { + + JSON::Value commandResult; + if (false == validateSignalingCommand(ws, command, commandResult)) { + return; + } + + if (command["type"] == "offer_sdp") { + if (!handleSignalingCommandRemoteOffer(ws, command)) { + sendSignalingError(ws, "on_answer_sdp", "Failed to handle the offer SDP."); + } + } + else if (command["type"] == "video_bitrate") { + if (!handleSignalingCommandVideoBitrate(ws, command)) { + sendSignalingError(ws, "on_video_bitrate", "Failed to handle the video bitrate change request."); + } + } + else if (command["type"] == "seek") { + if (!handleSignalingCommandSeek(ws, command)) { + sendSignalingError(ws, "on_seek", "Failed to handle the seek request."); + } + } + else if (command["type"] == "stop") { + INFO_MSG("Received stop() command."); + myConn.close(); + return; + } + else if (command["type"] == "keyframe_interval") { + if (!handleSignalingCommandKeyFrameInterval(ws, command)) { + sendSignalingError(ws, "on_keyframe_interval", "Failed to set the keyframe interval."); + } + } + else { + FAIL_MSG("Unhandled signal command %s.", command["type"].asString().c_str()); + } + } + + /// This function will check if the received command contains + /// the required fields. All commands need the `type` + /// field. When the `type` requires a `data` element we check + /// that too. When this function returns `true` you can assume + /// that it can be processed. In case of an error we return + /// false and send an error back to the other peer. + bool OutWebRTC::validateSignalingCommand(HTTP::Websocket& ws, const JSON::Value &command, JSON::Value& errorResult) { + + if (!command.isMember("type")) { + sendSignalingError(ws, "error", "Received an command but not type property was given."); + return false; + } + + /* seek command */ + if (command["type"] == "seek") { + if (!command.isMember("seek_time")) { + sendSignalingError(ws, "on_seek", "Received a seek request but no `seek_time` property."); + return false; + } + } + + /* offer command */ + if (command["type"] == "offer_sdp") { + if (!command.isMember("offer_sdp")) { + sendSignalingError(ws, "on_offer_sdp", "A `offer_sdp` command needs the offer SDP in the `offer_sdp` field."); + return false; + } + if (command["offer_sdp"].asString() == "") { + sendSignalingError(ws, "on_offer_sdp", "The given `offer_sdp` field is empty."); + return false; + } + } + + /* video bitrate */ + if (command["type"] == "video_bitrate") { + if (!command.isMember("video_bitrate")) { + sendSignalingError(ws, "on_video_bitrate", "No video_bitrate attribute found."); + return false; + } + } + + /* keyframe interval */ + if (command["type"] == "keyframe_interval") { + if (!command.isMember("keyframe_interval_millis")) { + sendSignalingError(ws, "on_keyframe_interval", "No keyframe_interval_millis attribute found."); + return false; + } + } + + /* when we arrive here everything is fine and validated. */ + return true; + } + + void OutWebRTC::sendSignalingError(HTTP::Websocket& ws, + const std::string& commandType, + const std::string& errorMessage) + { + JSON::Value commandResult; + commandResult["type"] = commandType; + commandResult["result"] = false; + commandResult["message"] = errorMessage; + ws.sendFrame(commandResult.toString()); + } + + bool OutWebRTC::handleSignalingCommandRemoteOffer(HTTP::Websocket &ws, const JSON::Value &command) { + + // get video and supported video formats from offer. 
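+ // A `recvonly` offer means the peer only wants to receive (play) media from us; any other offer is treated as an incoming push.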
+ SDP::Session sdpParser; + std::string sdpOffer = command["offer_sdp"].asString(); + if (!sdpParser.parseSDP(sdpOffer)) { + FAIL_MSG("Failed to parse the remote offer sdp."); + return false; + } + + // when the SDP offer contains a `a=recvonly` we expect that + // the other peer wants to receive media from us. + if (sdpParser.hasReceiveOnlyMedia()) { + return handleSignalingCommandRemoteOfferForOutput(ws, sdpParser, sdpOffer); + } + else { + return handleSignalingCommandRemoteOfferForInput(ws, sdpParser, sdpOffer); + } + + return false; + } + + // This function is called for a peer that wants to receive + // data from us. First we update our capabilities by checking + // what codecs the peer and we support. After updating the + // codecs we `initialize()` and use `selectDefaultTracks()` to + // pick the right tracks based on our supported (and updated) + // capabilities. + bool OutWebRTC::handleSignalingCommandRemoteOfferForOutput(HTTP::Websocket &ws, SDP::Session &sdpSession, const std::string& sdpOffer) { + + updateCapabilitiesWithSDPOffer(sdpSession); + initialize(); + selectDefaultTracks(); + + // \todo I'm not sure if this is the nicest location to bind the socket but it's necessary when we create our answer SDP + if (0 == udpPort) { + bindUDPSocketOnLocalCandidateAddress(0); + } + + // get codecs from selected stream which are used to create our SDP answer. + int32_t dtscVideoTrackID = -1; + int32_t dtscAudioTrackID = -1; + std::string videoCodec; + std::string audioCodec; + std::map::iterator it = myMeta.tracks.begin(); + while (it != myMeta.tracks.end()) { + DTSC::Track& Trk = it->second; + if (Trk.type == "video") { + videoCodec = Trk.codec; + dtscVideoTrackID = Trk.trackID; + } + else if (Trk.type == "audio") { + audioCodec = Trk.codec; + dtscAudioTrackID = Trk.trackID; + } + ++it; + } + + // parse offer SDP and setup the answer SDP using the selected codecs. + if (!sdpAnswer.parseOffer(sdpOffer)) { + FAIL_MSG("Failed to parse the received offer SDP"); + FAIL_MSG("%s", sdpOffer.c_str()); + return false; + } + sdpAnswer.setDirection("sendonly"); + + // setup video WebRTC Track. + if (!videoCodec.empty()) { + if (sdpAnswer.enableVideo(videoCodec)) { + WebRTCTrack videoTrack; + if (!createWebRTCTrackFromAnswer(sdpAnswer.answerVideoMedia, sdpAnswer.answerVideoFormat, videoTrack)) { + FAIL_MSG("Failed to create the WebRTCTrack for the selected video."); + return false; + } + videoTrack.rtpPacketizer = RTP::Packet(videoTrack.payloadType, rand(), 0, videoTrack.SSRC, 0); + videoTrack.timestampMultiplier = 90; + webrtcTracks[dtscVideoTrackID] = videoTrack; + } + } + + // setup audio WebRTC Track + if (!audioCodec.empty()) { + + // @todo maybe, create a isAudioSupported() function (?) + if (sdpAnswer.enableAudio(audioCodec)) { + WebRTCTrack audioTrack; + if (!createWebRTCTrackFromAnswer(sdpAnswer.answerAudioMedia, sdpAnswer.answerAudioFormat, audioTrack)) { + FAIL_MSG("Failed to create the WebRTCTrack for the selected audio."); + return false; + } + audioTrack.rtpPacketizer = RTP::Packet(audioTrack.payloadType, rand(), 0, audioTrack.SSRC, 0); + audioTrack.timestampMultiplier = 48; + webrtcTracks[dtscAudioTrackID] = audioTrack; + } + } + + // this is necessary so that we can get the remote IP when creating STUN replies. + udp.SetDestination("0.0.0.0", 4444); + + // create result message. 
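+ // The answer SDP returned here already carries our ICE candidate and DTLS fingerprint, which were filled in earlier via setCandidate() and setFingerprint().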
+ JSON::Value commandResult; + commandResult["type"] = "on_answer_sdp"; + commandResult["result"] = true; + commandResult["answer_sdp"] = sdpAnswer.toString(); + ws.sendFrame(commandResult.toString()); + + // we set parseData to `true` to start the data flow. Is also + // used to break out of our loop in `onHTTP()`. + parseData = true; + + return true; + } + + // When the receive a command with a `type` attribute set to + // `video_bitrate` we will extract the bitrate and use it for + // our REMB messages that are sent as soon as an connection has + // been established. REMB messages are used from server to + // client to define the preferred bitrate. + bool OutWebRTC::handleSignalingCommandVideoBitrate(HTTP::Websocket &ws, const JSON::Value &command) { + + videoBitrate = command["video_bitrate"].asInt(); + if (videoBitrate == 0) { + FAIL_MSG("We received an invalid video_bitrate; resetting to default."); + videoBitrate = 6 * 1000 * 1000; + return false; + } + + JSON::Value commandResult; + commandResult["type"] = "on_video_bitrate"; + commandResult["result"] = true; + commandResult["video_bitrate"] = videoBitrate; + ws.sendFrame(commandResult.toString()); + + return true; + } + + bool OutWebRTC::handleSignalingCommandSeek(HTTP::Websocket& ws, const JSON::Value &command) { + + uint64_t seek_time = command["seek_time"].asInt(); + seek(seek_time); + + JSON::Value commandResult; + commandResult["type"] = "on_seek"; + commandResult["result"] = true; + ws.sendFrame(commandResult.toString()); + + return true; + } + + bool OutWebRTC::handleSignalingCommandKeyFrameInterval(HTTP::Websocket &ws, const JSON::Value &command) { + + rtcpKeyFrameDelayInMillis = command["keyframe_interval_millis"].asInt(); + if (rtcpKeyFrameDelayInMillis < 500) { + WARN_MSG("Requested a keyframe delay < 500ms; 500ms is the minimum you can set."); + rtcpKeyFrameDelayInMillis = 500; + } + + rtcpKeyFrameTimeoutInMillis = Util::getMS() + rtcpKeyFrameDelayInMillis; + + JSON::Value commandResult; + commandResult["type"] = "on_keyframe_interval"; + commandResult["result"] = true; + ws.sendFrame(commandResult.toString()); + + return true; + } + + // Creates a WebRTCTrack for the given `SDP::Media` and + // `SDP::MediaFormat`. The `SDP::MediaFormat` must contain the + // `icePwd` and `iceUFrag` which are used when we're handling + // STUN messages (these can be used to verify the integrity). + // + // When the `SDP::Media` has it's SSRC set (not 0) we copy it. + // This is the case when we're receiving data from another peer + // and the `WebRTCTrack` is used to handle input data. When the + // SSRC is 0 we generate a new one and we assume that the other + // peer expect us to send data. 
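+ //
+ // Illustrative use, mirroring handleSignalingCommandRemoteOfferForOutput() above:
+ //
+ //   WebRTCTrack videoTrack;
+ //   if (!createWebRTCTrackFromAnswer(sdpAnswer.answerVideoMedia, sdpAnswer.answerVideoFormat, videoTrack)) {
+ //     /* handle the error */
+ //   }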
+ bool OutWebRTC::createWebRTCTrackFromAnswer(const SDP::Media& mediaAnswer, + const SDP::MediaFormat& formatAnswer, + WebRTCTrack& result) + { + if (formatAnswer.payloadType == SDP_PAYLOAD_TYPE_NONE) { + FAIL_MSG("Cannot create a WebRTCTrack, the given SDP::MediaFormat has no `payloadType` set."); + return false; + } + + if (formatAnswer.icePwd.empty()) { + FAIL_MSG("Cannot create a WebRTCTrack, the given SDP::MediaFormat has no `icePwd` set."); + return false; + } + + if (formatAnswer.iceUFrag.empty()) { + FAIL_MSG("Cannot create a WebRTCTrack, the given SDP::MediaFormat has no `iceUFrag` set."); + return false; + } + + result.payloadType = formatAnswer.getPayloadType(); + result.localIcePwd = formatAnswer.icePwd; + result.localIceUFrag = formatAnswer.iceUFrag; + + if (mediaAnswer.SSRC != 0) { + result.SSRC = mediaAnswer.SSRC; + } + else { + result.SSRC = rand(); + } + + return true; + } + + // This function checks what capabilities the offer has and + // updates the `capa[codecs]` member with the codecs that we + // and the offer supports. Note that the names in the preferred + // codec array below should use the codec names as MistServer + // defines them. SDP internally uses fullcaps. + // + // Jaron advised me to use this approach: when we receive an + // offer we update the capabilities with the matching codecs + // and once we've updated those we call `selectDefaultTracks()` + // which sets up the tracks in `myMeta`. + void OutWebRTC::updateCapabilitiesWithSDPOffer(SDP::Session &sdpSession) { + + capa["codecs"].null(); + + const char* videoCodecPreference[] = { "H264", "VP8", NULL } ; + const char** videoCodec = videoCodecPreference; + SDP::Media* videoMediaOffer = sdpSession.getMediaForType("video"); + if (videoMediaOffer) { + while (*videoCodec) { + if (sdpSession.getMediaFormatByEncodingName("video", *videoCodec)) { + capa["codecs"][0u][0u].append(*videoCodec); + } + videoCodec++; + } + } + + const char* audioCodecPreference[] = { "opus", NULL } ; + const char** audioCodec = audioCodecPreference; + SDP::Media* audioMediaOffer = sdpSession.getMediaForType("audio"); + if (audioMediaOffer) { + while (*audioCodec) { + if (sdpSession.getMediaFormatByEncodingName("audio", *audioCodec)) { + capa["codecs"][0u][1u].append(*audioCodec); + } + audioCodec++; + } + } + } + + // This function is called to handle an offer from a peer that wants to push data towards us. 
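+ // It binds our UDP candidate socket, negotiates the preferred video/audio codecs into the answer SDP, creates the receiving WebRTCTracks with their RTP sorters and RTP-to-DTSC converters, and finally spawns the thread that handles the incoming STUN, DTLS and (S)RTP traffic.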
+ bool OutWebRTC::handleSignalingCommandRemoteOfferForInput(HTTP::Websocket &webSock, SDP::Session &sdpSession, const std::string& sdpOffer) {
+
+ if (webRTCInputOutputThread != NULL) {
+ FAIL_MSG("It seems that we already have a webrtc i/o thread running.");
+ return false;
+ }
+
+ if (0 == udpPort) {
+ bindUDPSocketOnLocalCandidateAddress(0);
+ }
+
+ if (!sdpAnswer.parseOffer(sdpOffer)) {
+ FAIL_MSG("Failed to parse the received offer SDP.");
+ FAIL_MSG("%s", sdpOffer.c_str());
+ return false;
+ }
+
+ std::string prefVideoCodec = "VP8,H264";
+ if (config && config->hasOption("preferredvideocodec")) {
+ prefVideoCodec = config->getString("preferredvideocodec");
+ if (prefVideoCodec.empty()) {
+ WARN_MSG("No preferred video codec value set; resetting to default.");
+ prefVideoCodec = "VP8,H264";
+ }
+ }
+
+ std::string prefAudioCodec = "OPUS";
+ if (config && config->hasOption("preferredaudiocodec")) {
+ prefAudioCodec = config->getString("preferredaudiocodec");
+ if (prefAudioCodec.empty()) {
+ WARN_MSG("No preferred audio codec value set; resetting to default.");
+ prefAudioCodec = "OPUS";
+ }
+ }
+
+ // video
+ if (sdpAnswer.enableVideo(prefVideoCodec)) {
+
+ WebRTCTrack videoTrack;
+ videoTrack.payloadType = sdpAnswer.answerVideoFormat.getPayloadType();
+ videoTrack.localIcePwd = sdpAnswer.answerVideoFormat.icePwd;
+ videoTrack.localIceUFrag = sdpAnswer.answerVideoFormat.iceUFrag;
+ videoTrack.SSRC = sdpAnswer.answerVideoMedia.SSRC;
+
+ SDP::MediaFormat* fmtRED = sdpSession.getMediaFormatByEncodingName("video", "RED");
+ SDP::MediaFormat* fmtULPFEC = sdpSession.getMediaFormatByEncodingName("video", "ULPFEC");
+ if (fmtRED && fmtULPFEC) {
+ videoTrack.ULPFECPayloadType = fmtULPFEC->payloadType;
+ videoTrack.REDPayloadType = fmtRED->payloadType;
+ payloadTypeToWebRTCTrack[fmtRED->payloadType] = videoTrack.payloadType;
+ }
+ sdpAnswer.videoLossPrevention = SDP_LOSS_PREVENTION_NACK;
+ videoTrack.sorter.tmpVideoLossPrevention = sdpAnswer.videoLossPrevention;
+
+ DTSC::Track dtscVideo;
+ if (!sdpAnswer.setupVideoDTSCTrack(dtscVideo)) {
+ FAIL_MSG("Failed to set up the video DTSC track.");
+ return false;
+ }
+
+ videoTrack.rtpToDTSC.setProperties(dtscVideo);
+ videoTrack.rtpToDTSC.setCallbacks(onDTSCConverterHasPacketCallback, onDTSCConverterHasInitDataCallback);
+ videoTrack.sorter.setCallback(videoTrack.payloadType, onRTPSorterHasPacketCallback);
+
+ webrtcTracks[videoTrack.payloadType] = videoTrack;
+ myMeta.tracks[dtscVideo.trackID] = dtscVideo;
+ }
+
+ // audio setup
+ if (sdpAnswer.enableAudio(prefAudioCodec)) {
+
+ WebRTCTrack audioTrack;
+ audioTrack.payloadType = sdpAnswer.answerAudioFormat.getPayloadType();
+ audioTrack.localIcePwd = sdpAnswer.answerAudioFormat.icePwd;
+ audioTrack.localIceUFrag = sdpAnswer.answerAudioFormat.iceUFrag;
+ audioTrack.SSRC = sdpAnswer.answerAudioMedia.SSRC;
+
+ DTSC::Track dtscAudio;
+ if (!sdpAnswer.setupAudioDTSCTrack(dtscAudio)) {
+ FAIL_MSG("Failed to set up the audio DTSC track.");
+ return false;
+ }
+
+ audioTrack.rtpToDTSC.setProperties(dtscAudio);
+ audioTrack.rtpToDTSC.setCallbacks(onDTSCConverterHasPacketCallback, onDTSCConverterHasInitDataCallback);
+ audioTrack.sorter.setCallback(audioTrack.payloadType, onRTPSorterHasPacketCallback);
+
+ webrtcTracks[audioTrack.payloadType] = audioTrack;
+ myMeta.tracks[dtscAudio.trackID] = dtscAudio;
+ }
+
+ sdpAnswer.setDirection("recvonly");
+
+ // allow peer to push video/audio
+ if (!allowPush("")) {
+ FAIL_MSG("Failed to allow push for stream %s.", streamName.c_str());
+ /* \todo when I try to send an error message back to 
the browser it fails; probably because socket gets closed (?). */ + return false; + } + + // start our receive thread (handles STUN, DTLS, RTP input) + webRTCInputOutputThread = new tthread::thread(webRTCInputOutputThreadFunc, NULL); + rtcpTimeoutInMillis = Util::getMS() + 2000; + rtcpKeyFrameTimeoutInMillis = Util::getMS() + 2000; + + // create result command for websock client. + JSON::Value commandResult; + commandResult["type"] = "on_answer_sdp"; + commandResult["result"] = true; + commandResult["answer_sdp"] = sdpAnswer.toString(); + webSock.sendFrame(commandResult.toString()); + + return true; + } + + // Get the IP address on which we should bind our UDP socket that is + // used to receive the STUN, DTLS, SRTP, etc.. + std::string OutWebRTC::getLocalCandidateAddress() { + + std::vector localIP4Addresses = getLocalIP4Addresses(); + if (localIP4Addresses.size() == 0) { + FAIL_MSG("No IP4 addresses found."); + return ""; + } + + return localIP4Addresses[0]; + } + + bool OutWebRTC::bindUDPSocketOnLocalCandidateAddress(uint16_t port) { + + if (udpPort != 0) { + FAIL_MSG("Already bound the UDP socket."); + return false; + } + + std::string local_ip = getLocalCandidateAddress(); + if (local_ip.empty()) { + FAIL_MSG("Failed to get the local candidate address. W/o ICE will probably fail or will have issues during startup. See https://gist.github.com/roxlu/6c5ab696840256dac71b6247bab59ce9"); + } + + udpPort = udp.bind(port, local_ip); + sdpAnswer.setCandidate(local_ip, udpPort); + + return true; + } + + /* ------------------------------------------------ */ + + // This function is called from the `webRTCInputOutputThreadFunc()` + // function. The `webRTCInputOutputThreadFunc()` is basically empty + // and all work for the thread is done here. + void OutWebRTC::handleWebRTCInputOutputFromThread() { + udp.SetDestination("0.0.0.0", 4444); + while (keepGoing()) { + if (!handleWebRTCInputOutput()) { + Util::sleep(100); + } + } + } + + // Checks if there is data on our UDP socket. The data can be + // STUN, DTLS, SRTP or SRTCP. When we're receiving media from + // the browser (e.g. from webcam) this function is called from + // a separate thread. When we're pushing media to the browser + // this is called from the main thread. + bool OutWebRTC::handleWebRTCInputOutput() { + + if (!udp.Receive()) { + DONTEVEN_MSG("Waiting for data..."); + return false; + } + + myConn.addDown(udp.data_len); + + uint8_t fb = (uint8_t)udp.data[0]; + + if (fb > 127 && fb < 192) { + handleReceivedRTPOrRTCPPacket(); + } + else if (fb > 19 && fb < 64) { + handleReceivedDTLSPacket(); + } + else if (fb < 2) { + handleReceivedSTUNPacket(); + } + else { + FAIL_MSG("Unhandled WebRTC data. Type: %02X", fb); + } + + return true; + } + + void OutWebRTC::handleReceivedSTUNPacket() { + + size_t nparsed = 0; + StunMessage stun_msg; + if (stunReader.parse((uint8_t*)udp.data, udp.data_len, nparsed, stun_msg) != 0) { + FAIL_MSG("Failed to parse a stun message."); + return; + } + + if(stun_msg.type != STUN_MSG_TYPE_BINDING_REQUEST) { + INFO_MSG("We only handle STUN binding requests as we're an ice-lite implementation."); + return; + } + + // get the username for whom we got a binding request. + StunAttribute* usernameAttrib = stun_msg.getAttributeByType(STUN_ATTR_TYPE_USERNAME); + if (!usernameAttrib) { + ERROR_MSG("No username attribute found in the STUN binding request. 
Cannot create success binding response.");
+ return;
+ }
+ if (usernameAttrib->username.value == 0) {
+ ERROR_MSG("The username attribute is empty.");
+ return;
+ }
+ std::string username(usernameAttrib->username.value, usernameAttrib->length);
+ std::size_t usernameColonPos = username.find(":");
+ if (usernameColonPos == std::string::npos) {
+ ERROR_MSG("The username in the STUN attribute has an invalid format: %s.", username.c_str());
+ return;
+ }
+ std::string usernameLocal = username.substr(0, usernameColonPos);
+
+ // get the password for the username that is used to create our message-integrity.
+ std::string passwordLocal;
+ std::map<uint64_t, WebRTCTrack>::iterator rtcTrackIt = webrtcTracks.begin();
+ while (rtcTrackIt != webrtcTracks.end()) {
+ WebRTCTrack& tr = rtcTrackIt->second;
+ if (tr.localIceUFrag == usernameLocal) {
+ passwordLocal = tr.localIcePwd;
+ }
+ ++rtcTrackIt;
+ }
+ if (passwordLocal.empty()) {
+ ERROR_MSG("No local ICE password found for username %s. Did you create a WebRTCTrack?", usernameLocal.c_str());
+ return;
+ }
+
+ std::string remoteIP = "";
+ uint32_t remotePort = 0;
+ udp.GetDestination(remoteIP, remotePort);
+
+ // create the binding success response
+ stun_msg.removeAttributes();
+ stun_msg.setType(STUN_MSG_TYPE_BINDING_RESPONSE_SUCCESS);
+
+ StunWriter stun_writer;
+ stun_writer.begin(stun_msg);
+ stun_writer.writeXorMappedAddress(STUN_IP4, remotePort, remoteIP);
+ stun_writer.writeMessageIntegrity(passwordLocal);
+ stun_writer.writeFingerprint();
+ stun_writer.end();
+
+ udp.SendNow((const char*)stun_writer.getBufferPtr(), stun_writer.getBufferSize());
+ myConn.addUp(stun_writer.getBufferSize());
+ }
+
+ void OutWebRTC::handleReceivedDTLSPacket() {
+
+ if (dtlsHandshake.hasKeyingMaterial()) {
+ DONTEVEN_MSG("Not feeding data into the handshake; it has already been completed.");
+ return;
+ }
+
+ if (dtlsHandshake.parse((const uint8_t*)udp.data, udp.data_len) != 0) {
+ FAIL_MSG("Failed to parse a DTLS packet.");
+ return;
+ }
+
+ if (!dtlsHandshake.hasKeyingMaterial()) {
+ return;
+ }
+
+ if (srtpReader.init(dtlsHandshake.cipher, dtlsHandshake.remote_key, dtlsHandshake.remote_salt) != 0) {
+ FAIL_MSG("Failed to initialize the SRTP reader.");
+ return;
+ }
+
+ if (srtpWriter.init(dtlsHandshake.cipher, dtlsHandshake.local_key, dtlsHandshake.local_salt) != 0) {
+ FAIL_MSG("Failed to initialize the SRTP writer.");
+ return;
+ }
+ }
+
+ void OutWebRTC::handleReceivedRTPOrRTCPPacket() {
+
+ uint8_t pt = udp.data[1] & 0x7F;
+
+ if ((pt < 64) || (pt >= 96)) {
+
+ int len = (int)udp.data_len;
+ if (srtpReader.unprotectRtp((uint8_t*)udp.data, &len) != 0) {
+ FAIL_MSG("Failed to unprotect an RTP packet.");
+ return;
+ }
+
+ RTP::Packet rtp_pkt((const char*)udp.data, (unsigned int)len);
+ uint8_t payloadType = rtp_pkt.getPayloadType();
+ uint64_t rtcTrackID = payloadType;
+
+ // Do we need to map the payload type to a WebRTC Track? (e.g. RED)
+ if (payloadTypeToWebRTCTrack.count(payloadType) != 0) {
+ rtcTrackID = payloadTypeToWebRTCTrack[payloadType];
+ }
+
+ if (webrtcTracks.count(rtcTrackID) == 0) {
+ FAIL_MSG("Received an RTP packet for a track that we didn't prepare for. PayloadType is %llu", rtp_pkt.getPayloadType());
+ // \todo @jaron should we close the socket here?
+ return;
+ }
+
+ // Here follows a very rudimentary algorithm for requesting lost
+ // packets; after some experimentation a better algorithm should
+ // probably be used; this is what triggers our NACKs. 
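+ //
+ // For example: if the previous sequence number we received was
+ // 100 and this packet carries 104, then 101, 102 and 103 are
+ // assumed lost and we send a NACK for each of them.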
+ WebRTCTrack& rtcTrack = webrtcTracks[rtcTrackID];
+ uint16_t expectedSeqNum = rtcTrack.prevReceivedSequenceNumber + 1;
+ uint16_t currSeqNum = rtp_pkt.getSequence();
+ if (rtcTrack.prevReceivedSequenceNumber != 0
+ && expectedSeqNum != currSeqNum)
+ {
+ while (expectedSeqNum < currSeqNum) {
+ FAIL_MSG("=> nack seqnum: %u", expectedSeqNum);
+ sendRTCPFeedbackNACK(rtcTrack, expectedSeqNum);
+ expectedSeqNum++;
+ }
+ }
+
+ rtcTrack.prevReceivedSequenceNumber = rtp_pkt.getSequence();
+
+ if (payloadType == rtcTrack.REDPayloadType) {
+ rtcTrack.sorter.addREDPacket(udp.data, len, rtcTrack.payloadType, rtcTrack.REDPayloadType, rtcTrack.ULPFECPayloadType);
+ }
+ else {
+ rtcTrack.sorter.addPacket(rtp_pkt);
+ }
+ }
+ else if ((pt >= 64) && (pt < 96)) {
+
+#if 0
+ // \todo it seems that we don't need to handle RTCP messages.
+ int len = udp.data_len;
+ if (srtpReader.unprotectRtcp((uint8_t*)udp.data, &len) != 0) {
+ FAIL_MSG("Failed to unprotect RTCP.");
+ return;
+ }
+#endif
+
+ }
+ else {
+ FAIL_MSG("Unknown payload type: %u", pt);
+ }
+ }
+
+ /* ------------------------------------------------ */
+
+ int OutWebRTC::onDTLSHandshakeWantsToWrite(const uint8_t* data, int* nbytes) {
+ // \todo udp.SendNow() does not return an error code so we assume everything is sent nicely.
+ udp.SendNow((const char*)data, (size_t)*nbytes);
+ myConn.addUp(*nbytes);
+ return 0;
+ }
+
+ void OutWebRTC::onDTSCConverterHasPacket(const DTSC::Packet& pkt) {
+
+ // extract meta data (init data, width/height, etc.)
+ uint64_t trackID = pkt.getTrackId();
+ DTSC::Track& DTSCTrack = myMeta.tracks[trackID];
+ if (DTSCTrack.codec == "H264") {
+ if (DTSCTrack.init.empty()) {
+ FAIL_MSG("No init data found for trackID %llu (note: we use payloadType as trackID)", trackID);
+ return;
+ }
+ }
+ else if (DTSCTrack.codec == "VP8") {
+ if (pkt.getFlag("keyframe")) {
+ extractFrameSizeFromVP8KeyFrame(pkt);
+ }
+ }
+
+ // create rtcp packet (set bitrate and request keyframe).
+ if (DTSCTrack.codec == "H264" || DTSCTrack.codec == "VP8") {
+ uint64_t now = Util::getMS();
+
+ if (now >= rtcpTimeoutInMillis) {
+ WebRTCTrack& rtcTrack = webrtcTracks[trackID];
+ sendRTCPFeedbackREMB(rtcTrack);
+ sendRTCPFeedbackRR(rtcTrack);
+ rtcpTimeoutInMillis = now + 1000; /* @todo was 5000, lowered for FEC */
+ }
+
+ if (now >= rtcpKeyFrameTimeoutInMillis) {
+ WebRTCTrack& rtcTrack = webrtcTracks[trackID];
+ sendRTCPFeedbackPLI(rtcTrack);
+ rtcpKeyFrameTimeoutInMillis = now + rtcpKeyFrameDelayInMillis;
+ }
+ }
+
+ bufferLivePacket(pkt);
+ }
+
+ void OutWebRTC::onDTSCConverterHasInitData(const uint64_t trackID, const std::string &initData) {
+
+ if (webrtcTracks.count(trackID) == 0) {
+ ERROR_MSG("Received init data for a track that we don't manage. TrackID/PayloadType: %llu", trackID);
+ return;
+ }
+
+ MP4::AVCC avccbox;
+ avccbox.setPayload(initData);
+ if (avccbox.getSPSLen() == 0 || avccbox.getPPSLen() == 0) {
+ WARN_MSG("Received init data, but only partially. 
SPS nbytes: %u, PPS nbytes: %u.", avccbox.getSPSLen(), avccbox.getPPSLen());
+ return;
+ }
+
+ h264::sequenceParameterSet sps(avccbox.getSPS(), avccbox.getSPSLen());
+ h264::SPSMeta hMeta = sps.getCharacteristics();
+ DTSC::Track& Trk = myMeta.tracks[trackID];
+ Trk.width = hMeta.width;
+ Trk.height = hMeta.height;
+ Trk.fpks = hMeta.fps * 1000;
+
+ avccbox.multiplyPPS(57);//Inject all possible PPS packets into init
+ myMeta.tracks[trackID].init = std::string(avccbox.payload(), avccbox.payloadSize());
+ }
+
+ void OutWebRTC::onRTPSorterHasPacket(const uint64_t trackID, const RTP::Packet &pkt) {
+
+ if (webrtcTracks.count(trackID) == 0) {
+ ERROR_MSG("Received a sorted RTP packet for track %llu but we don't manage this track.", trackID);
+ return;
+ }
+
+ webrtcTracks[trackID].rtpToDTSC.addRTP(pkt);
+ }
+
+ // \todo when rtpOutBuffer was allocated on the stack (a member
+ // array with 2048 as its size) the `memcpy()` that copies the
+ // `data` ran into a segfault. valgrind pointed me to EBML (I
+ // was testing VP8); it feels like somewhere the stack gets
+ // overwritten. Shortly discussed this with Jaron and he told me
+ // this could indeed be the case. For now I'm allocating the
+ // buffer on the heap. This function will be called when we're
+ // sending data to the browser (other peer).
+ void OutWebRTC::onRTPPacketizerHasRTPPacket(char* data, uint32_t nbytes) {
+
+ memcpy(rtpOutBuffer, data, nbytes);
+
+ int protectedSize = nbytes;
+ if (srtpWriter.protectRtp((uint8_t*)rtpOutBuffer, &protectedSize) != 0) {
+ ERROR_MSG("Failed to protect the RTP packet.");
+ return;
+ }
+
+ udp.SendNow((const char*)rtpOutBuffer, (size_t)protectedSize);
+ myConn.addUp(protectedSize);
+
+ /* << TODO: remove if this doesn't work; testing output pacing >> */
+ if (didReceiveKeyFrame) {
+ //Util::sleep(4);
+ }
+ }
+
+ void OutWebRTC::onRTPPacketizerHasRTCPPacket(char* data, uint32_t nbytes) {
+
+ if (nbytes > 2048) {
+ FAIL_MSG("The received RTCP packet is too big to handle.");
+ return;
+ }
+ if (!rtpOutBuffer) {
+ FAIL_MSG("rtpOutBuffer not yet allocated.");
+ return;
+ }
+ if (!data) {
+ FAIL_MSG("Invalid RTCP packet given.");
+ return;
+ }
+
+ DONTEVEN_MSG("# Copy data into rtpOutBuffer (%u bytes)", nbytes);
+
+ memcpy((void*)rtpOutBuffer, data, nbytes);
+ int rtcpPacketSize = nbytes;
+ DONTEVEN_MSG("# Protect rtcp");
+ if (srtpWriter.protectRtcp((uint8_t*)rtpOutBuffer, &rtcpPacketSize) != 0) {
+ ERROR_MSG("Failed to protect the RTCP message.");
+ return;
+ }
+
+ DONTEVEN_MSG("# has RTCP packet, %d bytes.", rtcpPacketSize);
+
+ udp.SendNow((const char*)rtpOutBuffer, rtcpPacketSize);
+
+ /* @todo add myConn.addUp(). */
+ }
+
+ // This function was implemented (it's virtual) to handle
+ // pushing of media to the browser. This function blocks until
+ // the DTLS handshake has been finished. This prevents
+ // `sendNext()` from being called, which is correct because we
+ // don't want to send packets before we can protect them (the
+ // SRTP keys come out of the DTLS handshake).
+ void OutWebRTC::sendHeader() {
+
+ // first make sure that we complete the DTLS handshake.
+ while (!dtlsHandshake.hasKeyingMaterial()) {
+ if (!handleWebRTCInputOutput()) {
+ Util::sleep(10);
+ }
+ }
+
+ sentHeader = true;
+ }
+
+ void OutWebRTC::sendNext() {
+
+ // Once the DTLS handshake has been done, we still have to
+ // deal with STUN consent messages and RTCP. 
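+ // Browsers keep sending STUN binding requests (consent
+ // freshness) and RTCP packets on the same socket, so we keep
+ // servicing it on every call.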
+ handleWebRTCInputOutput();
+
+ char* dataPointer = 0;
+ size_t dataLen = 0;
+ thisPacket.getString("data", dataPointer, dataLen);
+
+ // make sure the webrtcTracks were set up correctly for output.
+ uint32_t tid = thisPacket.getTrackId();
+ if (webrtcTracks.count(tid) == 0) {
+ FAIL_MSG("No WebRTCTrack found for track id %llu.", tid);
+ return;
+ }
+
+ WebRTCTrack& rtcTrack = webrtcTracks[tid];
+ if (rtcTrack.timestampMultiplier == 0) {
+ FAIL_MSG("The WebRTCTrack::timestampMultiplier is 0; invalid.");
+ return;
+ }
+
+ uint64_t timestamp = thisPacket.getTime();
+ uint64_t offset = thisPacket.getInt("offset");
+ rtcTrack.rtpPacketizer.setTimestamp((timestamp + offset) * rtcTrack.timestampMultiplier);
+
+ DTSC::Track& dtscTrack = myMeta.tracks[tid];
+
+ /* ----------------------- BEGIN NEEDS CLEANUP -------------------------------------- */
+
+ bool isKeyFrame = thisPacket.getFlag("keyframe");
+ didReceiveKeyFrame = isKeyFrame;
+ if (isKeyFrame && dtscTrack.codec == "H264") {
+ uint8_t nalType = dataPointer[5] & 0x1F;
+ if (nalType == 5) {
+ sendSPSPPS(dtscTrack, rtcTrack);
+ }
+ }
+
+#if 0
+ /*
+ @todo
+ - do send this when isKeyframe is true and byte position 5 holds nal type 5
+ (4 bytes size, 1 byte nal type).
+ - Jaron has a keyframe check function
+ */
+ bool isKeyFrame = thisPacket.getFlag("keyframe");
+
+ if (!dtscTrack.init.empty()) {
+ static bool didSentInit = false;
+ if (!didSentInit || isKeyFrame) {
+
+ MP4::AVCC avcc;
+ avcc.setPayload(dtscTrack.init);
+ std::vector<char> buf;
+ for (uint32_t i = 0; i < avcc.getSPSCount(); ++i) {
+ uint32_t len = avcc.getSPSLen(i);
+ buf.clear();
+ buf.assign(4, 0);
+ *(uint32_t*)&buf[0] = htonl(len);
+ std::copy(avcc.getSPS(i), avcc.getSPS(i) + avcc.getSPSLen(i), std::back_inserter(buf));
+ rtcTrack.rtpPacketizer.sendData(&udp,
+ onRTPPacketizerHasDataCallback,
+ &buf[0],
+ buf.size(),
+ rtcTrack.payloadType,
+ dtscTrack.codec);
+ }
+
+ for (uint32_t i = 0; i < avcc.getPPSCount(); ++i) {
+ uint32_t len = avcc.getPPSLen(i);
+ buf.clear();
+ buf.assign(4, 0);
+ *(uint32_t*)&buf[0] = htonl(len);
+ std::copy(avcc.getPPS(i), avcc.getPPS(i) + avcc.getPPSLen(i), std::back_inserter(buf));
+ rtcTrack.rtpPacketizer.sendData(&udp,
+ onRTPPacketizerHasDataCallback,
+ &buf[0],
+ buf.size(),
+ rtcTrack.payloadType,
+ dtscTrack.codec);
+
+ }
+
+ didSentInit = true;
+ }
+ }
+#endif
+ /* test: repeat sending of SPS. */
+#if 0
+ static uint64_t repeater = Util::getMS() + 2000;
+ if (Util::getMS() >= repeater) {
+ FAIL_MSG("=> send next, sending SPS again.");
+ repeater = Util::getMS() + 2000;
+ MP4::AVCC avcc;
+ avcc.setPayload(dtscTrack.init);
+ std::vector<char> buf;
+ for (uint32_t i = 0; i < avcc.getSPSCount(); ++i) {
+ uint32_t len = avcc.getSPSLen(i);
+ buf.clear();
+ buf.assign(4, 0);
+ *(uint32_t*)&buf[0] = htonl(len);
+ std::copy(avcc.getSPS(i), avcc.getSPS(i) + avcc.getSPSLen(i), std::back_inserter(buf));
+ rtcTrack.rtpPacketizer.sendData(&udp,
+ onRTPPacketizerHasDataCallback,
+ &buf[0],
+ buf.size(),
+ rtcTrack.payloadType,
+ dtscTrack.codec);
+ }
+ }
+#endif
+ /* ----------------------- END NEEDS CLEANUP -------------------------------------- */
+
+ rtcTrack.rtpPacketizer.sendData(&udp,
+ onRTPPacketizerHasDataCallback,
+ dataPointer,
+ dataLen,
+ rtcTrack.payloadType,
+ dtscTrack.codec);
+ }
+
+ // When the RTP::toDTSC converter collected a complete VP8
+ // frame, it will call our callback `onDTSCConverterHasPacket()`
+ // with a valid packet that can be fed into
+ // MistServer. 
Whenever we receive a keyframe we update the
+ // width and height of the corresponding track.
+ void OutWebRTC::extractFrameSizeFromVP8KeyFrame(const DTSC::Packet &pkt) {
+
+ char *vp8PayloadBuffer = 0;
+ size_t vp8PayloadLen = 0;
+ pkt.getString("data", vp8PayloadBuffer, vp8PayloadLen);
+
+ if (!vp8PayloadBuffer || vp8PayloadLen < 9) {
+ FAIL_MSG("Cannot extract vp8 frame size. Failed to get data.");
+ return;
+ }
+
+ if (vp8PayloadBuffer[3] != 0x9d
+ || vp8PayloadBuffer[4] != 0x01
+ || vp8PayloadBuffer[5] != 0x2a)
+ {
+ FAIL_MSG("Invalid signature. It seems that either the VP8 frame is incorrect or our parsing is wrong.");
+ return;
+ }
+
+ uint32_t width = ((vp8PayloadBuffer[7] << 8) + vp8PayloadBuffer[6]) & 0x3FFF;
+ uint32_t height = ((vp8PayloadBuffer[9] << 8) + vp8PayloadBuffer[8]) & 0x3FFF;
+
+ DONTEVEN_MSG("Received VP8 keyframe with resolution: %u x %u", width, height);
+
+ if (width == 0) {
+ FAIL_MSG("VP8 frame width is 0; parse error?");
+ return;
+ }
+
+ if (height == 0) {
+ FAIL_MSG("VP8 frame height is 0; parse error?");
+ return;
+ }
+
+ uint64_t trackID = pkt.getTrackId();
+ if (myMeta.tracks.count(trackID) == 0) {
+ FAIL_MSG("No track found with ID %llu.", trackID);
+ return;
+ }
+
+ DTSC::Track& Trk = myMeta.tracks[trackID];
+ Trk.width = width;
+ Trk.height = height;
+ }
+
+ void OutWebRTC::sendRTCPFeedbackREMB(const WebRTCTrack& rtcTrack) {
+
+ if (videoBitrate == 0) {
+ FAIL_MSG("videoBitrate is 0, which is invalid. Resetting to our default value.");
+ videoBitrate = 6 * 1000 * 1000;
+ }
+
+ // create the `BR Exp` and `BR Mantissa` parts.
+ uint32_t br_exponent = 0;
+ uint32_t br_mantissa = videoBitrate;
+ while (br_mantissa > 0x3FFFF) {
+ br_mantissa >>= 1;
+ ++br_exponent;
+ }
+
+ std::vector<uint8_t> buffer;
+ buffer.push_back(0x80 | 0x0F); // V=2 (0x80) | FMT=15 (0x0F)
+ buffer.push_back(0xCE); // payload type = 206
+ buffer.push_back(0x00); // tmp length
+ buffer.push_back(0x00); // tmp length
+ buffer.push_back((SSRC >> 24) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC >> 16) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC >> 8) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC) & 0xFF); // ssrc of sender
+ buffer.push_back(0x00); // ssrc of media source (always 0)
+ buffer.push_back(0x00); // ssrc of media source (always 0)
+ buffer.push_back(0x00); // ssrc of media source (always 0)
+ buffer.push_back(0x00); // ssrc of media source (always 0)
+ buffer.push_back('R'); // `R`, `E`, `M`, `B`
+ buffer.push_back('E'); // `R`, `E`, `M`, `B`
+ buffer.push_back('M'); // `R`, `E`, `M`, `B`
+ buffer.push_back('B'); // `R`, `E`, `M`, `B`
+ buffer.push_back(0x01); // num ssrc
+ buffer.push_back((uint8_t) (br_exponent << 2) + ((br_mantissa >> 16) & 0x03)); // br-exp and br-mantissa
+ buffer.push_back((uint8_t) (br_mantissa >> 8)); // br-exp and br-mantissa
+ buffer.push_back((uint8_t) br_mantissa); // br-exp and br-mantissa
+ buffer.push_back((rtcTrack.SSRC >> 24) & 0xFF); // ssrc to which this remb packet applies.
+ buffer.push_back((rtcTrack.SSRC >> 16) & 0xFF); // ssrc to which this remb packet applies.
+ buffer.push_back((rtcTrack.SSRC >> 8) & 0xFF); // ssrc to which this remb packet applies.
+ buffer.push_back((rtcTrack.SSRC) & 0xFF); // ssrc to which this remb packet applies.
+
+ // rewrite size
+ int buffer_size_in_bytes = (int)buffer.size();
+ int buffer_size_in_words_minus1 = ((int)buffer.size() / 4) - 1;
+ buffer[2] = (buffer_size_in_words_minus1 >> 8) & 0xFF;
+ buffer[3] = buffer_size_in_words_minus1 & 0xFF;
+
+ // protect. 
+ size_t trailer_space = SRTP_MAX_TRAILER_LEN + 4;
+ for (size_t i = 0; i < trailer_space; ++i) {
+ buffer.push_back(0x00);
+ }
+
+ if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0) {
+ ERROR_MSG("Failed to protect the RTCP message.");
+ return;
+ }
+
+ udp.SendNow((const char*)&buffer[0], buffer_size_in_bytes);
+ myConn.addUp(buffer_size_in_bytes);
+ }
+
+ void OutWebRTC::sendRTCPFeedbackPLI(const WebRTCTrack& rtcTrack) {
+
+ std::vector<uint8_t> buffer;
+ buffer.push_back(0x80 | 0x01); // V=2 (0x80) | FMT=1 (0x01)
+ buffer.push_back(0xCE); // payload type = 206
+ buffer.push_back(0x00); // payload size in words minus 1 (2)
+ buffer.push_back(0x02); // payload size in words minus 1 (2)
+ buffer.push_back((SSRC >> 24) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC >> 16) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC >> 8) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC) & 0xFF); // ssrc of sender
+ buffer.push_back((rtcTrack.SSRC >> 24) & 0xFF); // ssrc of receiver
+ buffer.push_back((rtcTrack.SSRC >> 16) & 0xFF); // ssrc of receiver
+ buffer.push_back((rtcTrack.SSRC >> 8) & 0xFF); // ssrc of receiver
+ buffer.push_back((rtcTrack.SSRC) & 0xFF); // ssrc of receiver
+
+ // space for protection
+ size_t trailer_space = SRTP_MAX_TRAILER_LEN + 4;
+ for (size_t i = 0; i < trailer_space; ++i) {
+ buffer.push_back(0x00);
+ }
+
+ // protect.
+ int buffer_size_in_bytes = (int)buffer.size();
+ if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0) {
+ ERROR_MSG("Failed to protect the RTCP message.");
+ return;
+ }
+
+ udp.SendNow((const char*)&buffer[0], buffer_size_in_bytes);
+ myConn.addUp(buffer_size_in_bytes);
+ }
+
+ // Notify sender that we lost a packet. See
+ // https://tools.ietf.org/html/rfc4585#section-6.1 which
+ // describes the use of the `BLP` field; when more successive
+ // sequence numbers are lost it makes sense to implement this
+ // too.
+ void OutWebRTC::sendRTCPFeedbackNACK(const WebRTCTrack &rtcTrack, uint16_t lostSequenceNumber) {
+
+ std::vector<uint8_t> buffer;
+ buffer.push_back(0x80 | 0x01); // V=2 (0x80) | FMT=1 (0x01)
+ buffer.push_back(0xCD); // payload type = 205, RTPFB, https://tools.ietf.org/html/rfc4585#section-6.1
+ buffer.push_back(0x00); // payload size in words minus 1 (3)
+ buffer.push_back(0x03); // payload size in words minus 1 (3)
+ buffer.push_back((SSRC >> 24) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC >> 16) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC >> 8) & 0xFF); // ssrc of sender
+ buffer.push_back((SSRC) & 0xFF); // ssrc of sender
+ buffer.push_back((rtcTrack.SSRC >> 24) & 0xFF); // ssrc of receiver
+ buffer.push_back((rtcTrack.SSRC >> 16) & 0xFF); // ssrc of receiver
+ buffer.push_back((rtcTrack.SSRC >> 8) & 0xFF); // ssrc of receiver
+ buffer.push_back((rtcTrack.SSRC) & 0xFF); // ssrc of receiver
+ buffer.push_back((lostSequenceNumber >> 8) & 0xFF); // PID: missing sequence number
+ buffer.push_back((lostSequenceNumber) & 0xFF); // PID: missing sequence number
+ buffer.push_back(0x00); // BLP: Bitmask of following losses. (not implemented atm).
+ buffer.push_back(0x00); // BLP: Bitmask of following losses. (not implemented atm).
+
+ // space for protection
+ size_t trailer_space = SRTP_MAX_TRAILER_LEN + 4;
+ for (size_t i = 0; i < trailer_space; ++i) {
+ buffer.push_back(0x00);
+ }
+
+ // protect. 
+ int buffer_size_in_bytes = (int)buffer.size();
+ if (srtpWriter.protectRtcp(&buffer[0], &buffer_size_in_bytes) != 0) {
+ ERROR_MSG("Failed to protect the RTCP message.");
+ return;
+ }
+
+ udp.SendNow((const char*)&buffer[0], buffer_size_in_bytes);
+ myConn.addUp(buffer_size_in_bytes);
+ }
+
+ void OutWebRTC::sendRTCPFeedbackRR(WebRTCTrack &rtcTrack) {
+
+ ((RTP::FECPacket*)&(rtcTrack.rtpPacketizer))->sendRTCP_RR(rtcTrack.sorter,
+ SSRC,
+ rtcTrack.SSRC,
+ (void*)&udp,
+ onRTPPacketizerHasRTCPDataCallback);
+
+ }
+
+ void OutWebRTC::sendSPSPPS(DTSC::Track& dtscTrack, WebRTCTrack& rtcTrack) {
+
+ if (dtscTrack.init.empty()) {
+ WARN_MSG("No init data found in the DTSC::Track. Not sending SPS and PPS.");
+ return;
+ }
+
+ std::vector<char> buf;
+ MP4::AVCC avcc;
+ avcc.setPayload(dtscTrack.init);
+
+ /* SPS */
+ for (uint32_t i = 0; i < avcc.getSPSCount(); ++i) {
+
+ uint32_t len = avcc.getSPSLen(i);
+ if (len == 0) {
+ WARN_MSG("Empty SPS stored?");
+ continue;
+ }
+
+ buf.clear();
+ buf.assign(4, 0);
+ *(uint32_t*)&buf[0] = htonl(len);
+ std::copy(avcc.getSPS(i), avcc.getSPS(i) + avcc.getSPSLen(i), std::back_inserter(buf));
+
+ rtcTrack.rtpPacketizer.sendData(&udp,
+ onRTPPacketizerHasDataCallback,
+ &buf[0],
+ buf.size(),
+ rtcTrack.payloadType,
+ dtscTrack.codec);
+ }
+
+ /* PPS */
+ for (uint32_t i = 0; i < avcc.getPPSCount(); ++i) {
+
+ uint32_t len = avcc.getPPSLen(i);
+ if (len == 0) {
+ WARN_MSG("Empty PPS stored?");
+ continue;
+ }
+
+ buf.clear();
+ buf.assign(4, 0);
+ *(uint32_t*)&buf[0] = htonl(len);
+ std::copy(avcc.getPPS(i), avcc.getPPS(i) + avcc.getPPSLen(i), std::back_inserter(buf));
+
+ rtcTrack.rtpPacketizer.sendData(&udp,
+ onRTPPacketizerHasDataCallback,
+ &buf[0],
+ buf.size(),
+ rtcTrack.payloadType,
+ dtscTrack.codec);
+
+ }
+ }
+
+ /* ------------------------------------------------ */
+
+ // This is our thread function; it is started right after we
+ // call `allowPush()`, just before we send our answer SDP back
+ // to the client.
+ static void webRTCInputOutputThreadFunc(void* arg) {
+ if (!classPointer) {
+ FAIL_MSG("classPointer hasn't been set. 
Exiting thread."); + return; + } + classPointer->handleWebRTCInputOutputFromThread(); + } + + static int onDTLSHandshakeWantsToWriteCallback(const uint8_t* data, int* nbytes) { + if (!classPointer) { + FAIL_MSG("Requested to send DTLS handshake data but the `classPointer` hasn't been set."); + return -1; + } + return classPointer->onDTLSHandshakeWantsToWrite(data, nbytes); + } + + static void onRTPSorterHasPacketCallback(const uint64_t track, const RTP::Packet &p) { + if (!classPointer) { + FAIL_MSG("We received a sorted RTP packet but our `classPointer` is invalid."); + return; + } + classPointer->onRTPSorterHasPacket(track, p); + } + + static void onDTSCConverterHasInitDataCallback(const uint64_t track, const std::string &initData) { + if (!classPointer) { + FAIL_MSG("Received a init data, but our `classPointer` is invalid."); + return; + } + classPointer->onDTSCConverterHasInitData(track, initData); + } + + static void onDTSCConverterHasPacketCallback(const DTSC::Packet& pkt) { + if (!classPointer) { + FAIL_MSG("Received a DTSC packet that was created from RTP data, but our `classPointer` is invalid."); + return; + } + classPointer->onDTSCConverterHasPacket(pkt); + } + + static void onRTPPacketizerHasDataCallback(void* socket, char* data, unsigned int len, unsigned int channel) { + if (!classPointer) { + FAIL_MSG("Received a RTP packet but our `classPointer` is invalid."); + return; + } + classPointer->onRTPPacketizerHasRTPPacket(data, len); + } + + static void onRTPPacketizerHasRTCPDataCallback(void* socket, const char* data, uint32_t len) { + if (!classPointer) { + FAIL_MSG("Received a RTCP packet, but out `classPointer` is invalid."); + return; + } + classPointer->onRTPPacketizerHasRTCPPacket((char*)data, len); + } + + /* ------------------------------------------------ */ + + static uint32_t generateSSRC() { + + uint32_t ssrc = 0; + + do { + ssrc = rand(); + ssrc = ssrc << 16; + ssrc += rand(); + } while (ssrc == 0 || ssrc == 0xffffffff); + + return ssrc; + } + + // This function is used to return a vector of the IP4 + // addresses of the interfaces on this machine. This is used + // when we create the candidate address that is shared in our + // SDP answer. The other WebRTC-peer will use this address to + // deliver data. 
+ static std::vector<std::string> getLocalIP4Addresses() {
+
+ std::vector<std::string> result;
+ ifaddrs* ifaddr = NULL;
+ ifaddrs* ifa = NULL;
+ sockaddr_in* addr = NULL;
+ char host[128] = { 0 };
+ int s = 0;
+ int i = 0;
+
+ if (getifaddrs(&ifaddr) == -1) {
+ FAIL_MSG("Failed to get the local interface addresses.");
+ return result;
+ }
+
+ for (ifa = ifaddr, i = 0; ifa != NULL; ifa = ifa->ifa_next, i++) {
+ if (ifa->ifa_addr == NULL) {
+ continue;
+ }
+ if (ifa->ifa_addr->sa_family != AF_INET) {
+ continue;
+ }
+ addr = (sockaddr_in*)ifa->ifa_addr;
+ if (addr->sin_addr.s_addr == htonl(INADDR_LOOPBACK)) {
+ continue;
+ }
+ s = getnameinfo(ifa->ifa_addr, sizeof(sockaddr_in), host, sizeof(host), NULL, 0, NI_NUMERICHOST);
+ if (0 != s) {
+ FAIL_MSG("Failed to get name info for an interface.");
+ continue;
+ }
+ result.push_back(host);
+ }
+
+ if (ifaddr != NULL) {
+ freeifaddrs(ifaddr);
+ ifaddr = NULL;
+ }
+
+ return result;
+ }
+
+ /* ------------------------------------------------ */
+
+} // mist namespace
diff --git a/src/output/output_webrtc.h b/src/output/output_webrtc.h
new file mode 100644
index 00000000..d83f0da6
--- /dev/null
+++ b/src/output/output_webrtc.h
@@ -0,0 +1,173 @@
+/*
+
+ SOME NOTES ON MIST
+
+ - When a user wants to start pushing video into Mist we need to
+ check if the user is actually allowed to do this. When the user
+ is allowed to push we have to call the function `allowPush("")`.
+
+ SIGNALING
+
+ 1. Client sends the offer:
+
+ {
+ type: "offer_sdp",
+ offer_sdp: <the offer SDP string>,
+ }
+
+ Server responds with:
+
+ SUCCESS:
+ {
+ type: "on_answer_sdp",
+ result: true,
+ answer_sdp: <the answer SDP string>,
+ }
+
+ ERROR:
+ {
+ type: "on_answer_sdp",
+ result: false,
+ }
+
+ 2. Client requests a new bitrate:
+
+ {
+ type: "video_bitrate"
+ video_bitrate: 600000
+ }
+
+ Server responds with:
+
+ SUCCESS:
+ {
+ type: "on_video_bitrate"
+ result: true
+ }
+
+ ERROR:
+ {
+ type: "on_video_bitrate"
+ result: false
+ }
+
+ */
+#pragma once
+
+#include "output.h"
+#include "output_http.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(WEBRTC_PCAP)
+# include 
+#endif
+
+namespace Mist {
+
+ /* ------------------------------------------------ */
+
+ class WebRTCTrack {
+ public:
+ WebRTCTrack(); ///< Initializes to some defaults.
+
+ public:
+ RTP::toDTSC rtpToDTSC; ///< Converts RTP packets into DTSC packets.
+ RTP::FECSorter sorter; ///< Takes care of sorting the received RTP packets and keeps track of some statistics. Will call a callback whenever a packet can be used (e.g. not lost, in the correct order).
+ RTP::Packet rtpPacketizer; ///< Used when we're sending RTP data back to the other peer.
+ uint64_t payloadType; ///< The payload type that was extracted from the `m=` media line in the SDP.
+ std::string localIcePwd;
+ std::string localIceUFrag;
+ uint32_t SSRC; ///< The SSRC of the RTP packets.
+ uint32_t timestampMultiplier; ///< Used for outgoing streams to convert the DTSC timestamps into RTP timestamps.
+ uint8_t ULPFECPayloadType; ///< When we've enabled FEC for a video stream this holds the payload type that is used to distinguish between ordinary video RTP packets and FEC packets.
+ uint8_t REDPayloadType; ///< When using RED and ULPFEC this holds the payload type of the RED stream.
+ uint8_t RTXPayloadType; ///< The retransmission payload type when we use RTX (retransmission with separate SSRC/payload type).
+ uint16_t prevReceivedSequenceNumber; ///< The previously received sequence number. This is used to NACK packets when we lose one. 
+ };
+
+ /* ------------------------------------------------ */
+
+ class OutWebRTC : public HTTPOutput {
+ public:
+ OutWebRTC(Socket::Connection &myConn);
+ ~OutWebRTC();
+ static void init(Util::Config *cfg);
+ virtual void sendHeader();
+ virtual void sendNext();
+ virtual void onWebsocketFrame();
+ bool doesWebsockets(){return true;}
+ void handleWebRTCInputOutputFromThread();
+ int onDTLSHandshakeWantsToWrite(const uint8_t* data, int* nbytes);
+ void onRTPSorterHasPacket(const uint64_t trackID, const RTP::Packet &pkt);
+ void onDTSCConverterHasPacket(const DTSC::Packet& pkt);
+ void onDTSCConverterHasInitData(const uint64_t trackID, const std::string &initData);
+ void onRTPPacketizerHasRTPPacket(char* data, uint32_t nbytes);
+ void onRTPPacketizerHasRTCPPacket(char* data, uint32_t nbytes);
+
+ private:
+ bool handleWebRTCInputOutput(); ///< Reads data from the UDP socket. Returns true when we read some data, otherwise false.
+ void handleReceivedSTUNPacket();
+ void handleReceivedDTLSPacket();
+ void handleReceivedRTPOrRTCPPacket();
+ void handleSignalingCommand(HTTP::Websocket& webSock, const JSON::Value &command);
+ bool handleSignalingCommandRemoteOffer(HTTP::Websocket &webSock, const JSON::Value &command);
+ bool handleSignalingCommandRemoteOfferForInput(HTTP::Websocket &webSocket, SDP::Session &sdpSession, const std::string &sdpOffer);
+ bool handleSignalingCommandRemoteOfferForOutput(HTTP::Websocket &webSocket, SDP::Session &sdpSession, const std::string &sdpOffer);
+ bool handleSignalingCommandVideoBitrate(HTTP::Websocket& webSock, const JSON::Value &command);
+ bool handleSignalingCommandSeek(HTTP::Websocket& webSock, const JSON::Value &command);
+ bool handleSignalingCommandKeyFrameInterval(HTTP::Websocket &webSock, const JSON::Value &command); ///< Handles the command that can be used to set the keyframe interval for the current connection. We will send RTCP PLI messages every X millis; the other agent -should- generate keyframes when it receives PLI messages (Picture Loss Indication).
+ void sendSignalingError(HTTP::Websocket& webSock, const std::string& commandType, const std::string& errorMessage);
+ bool validateSignalingCommand(HTTP::Websocket& webSock, const JSON::Value &command, JSON::Value &errorResult);
+
+ bool createWebRTCTrackFromAnswer(const SDP::Media& mediaAnswer, const SDP::MediaFormat& formatAnswer, WebRTCTrack& result);
+ void sendRTCPFeedbackREMB(const WebRTCTrack &rtcTrack);
+ void sendRTCPFeedbackPLI(const WebRTCTrack &rtcTrack); ///< Picture Loss Indication: request keyframe.
+ void sendRTCPFeedbackRR(WebRTCTrack &rtcTrack);
+ void sendRTCPFeedbackNACK(const WebRTCTrack &rtcTrack, uint16_t missingSequenceNumber); ///< Notify sender that we're missing a sequence number.
+ void sendSPSPPS(DTSC::Track& dtscTrack, WebRTCTrack& rtcTrack);///< When we're streaming H264 to e.g. the browser we inject the PPS and SPS nals.
+ void extractFrameSizeFromVP8KeyFrame(const DTSC::Packet &pkt);
+ void updateCapabilitiesWithSDPOffer(SDP::Session &sdpSession);
+ bool bindUDPSocketOnLocalCandidateAddress(uint16_t port); ///< Binds our UDP socket onto the IP address that we shared via our SDP answer. We *have to* bind on a specific IP, see https://gist.github.com/roxlu/6c5ab696840256dac71b6247bab59ce9
+ std::string getLocalCandidateAddress();
+
+ private:
+ SDP::Session sdp; ///< SDP parser.
+ SDP::Answer sdpAnswer; ///< WIP: Replacing our `sdp` member ..
+ Certificate cert; ///< The TLS certificate. Used to generate a fingerprint in SDP answers. 
+ DTLSSRTPHandshake dtlsHandshake; ///< Implements the DTLS handshake using the mbedtls library (fork).
+ SRTPReader srtpReader; ///< Used to unprotect incoming RTP and RTCP data. Uses the keys that were exchanged with DTLS.
+ SRTPWriter srtpWriter; ///< Used to protect our RTP and RTCP data when sending data to another peer. Uses the keys that were exchanged with DTLS.
+ Socket::UDPConnection udp; ///< Our UDP socket over which WebRTC data is received and sent.
+ StunReader stunReader; ///< Decodes STUN messages; during a session we keep receiving STUN messages to which we need to reply.
+ std::map<uint64_t, WebRTCTrack> webrtcTracks; ///< WebRTCTracks indexed by payload type for incoming data and indexed by myMeta.tracks[].trackID for outgoing data.
+ tthread::thread* webRTCInputOutputThread; ///< The thread in which we read WebRTC data when we're receiving media from another peer.
+ uint16_t udpPort; ///< The port on which our webrtc socket is bound. This is where we receive RTP, STUN, DTLS, etc.
+ uint32_t SSRC; ///< The SSRC for this local instance. Is used when generating RTCP reports.
+ uint64_t rtcpTimeoutInMillis; ///< When current time in millis exceeds this timeout we have to send a new RTCP packet.
+ uint64_t rtcpKeyFrameTimeoutInMillis;
+ uint64_t rtcpKeyFrameDelayInMillis;
+ char* rtpOutBuffer; ///< Buffer into which we copy (unprotected) RTP data that we need to deliver to the other peer. This gets protected.
+ uint32_t videoBitrate; ///< The bitrate to use for incoming video streams. Can be configured via the signaling channel. Defaults to 6mbit.
+
+ bool didReceiveKeyFrame; /* TODO burst delay */
+
+#if defined(WEBRTC_PCAP)
+ PCAPWriter pcapOut; ///< Used during development to write unprotected packets that can be inspected in e.g. wireshark.
+ PCAPWriter pcapIn; ///< Used during development to write unprotected packets that can be inspected in e.g. wireshark.
+#endif
+
+ std::map<uint64_t, uint64_t> payloadTypeToWebRTCTrack; ///< Maps e.g. the RED payload type to the corresponding track. Used when the input supports RED/ULPFEC; can also be used to map RTX in the future.
+ };
+}
+
+typedef Mist::OutWebRTC mistOut;