Implemented WebRTC

This commit is contained in:
roxlu 2018-06-01 09:19:32 +02:00 committed by Thulinma
parent dce4cddadd
commit 7e8eb634e6
20 changed files with 6712 additions and 1 deletions

240
lib/certificate.cpp Normal file
View file

@ -0,0 +1,240 @@
#include "certificate.h"
#include "defines.h"
Certificate::Certificate()
:rsa_ctx(NULL)
{
memset((void*)&cert, 0x00, sizeof(cert));
memset((void*)&key, 0x00, sizeof(key));
}
int Certificate::init(const std::string &countryName,
const std::string &organization,
const std::string& commonName)
{
mbedtls_ctr_drbg_context rand_ctx = {};
mbedtls_entropy_context entropy_ctx = {};
mbedtls_x509write_cert write_cert = {};
const char* personalisation = "mbedtls-self-signed-key";
std::string subject_name = "C=" +countryName +",O=" +organization +",CN=" +commonName;
time_t time_from = { 0 };
time_t time_to = { 0 };
char time_from_str[20] = { 0 };
char time_to_str[20] = { 0 };
mbedtls_mpi serial_mpi = { 0 };
char serial_hex[17] = { 0 };
uint64_t serial_num = 0;
uint8_t* serial_ptr = (uint8_t*)&serial_num;
int r = 0;
int i = 0;
uint8_t buf[4096] = { 0 };
// validate
if (countryName.empty()) {
FAIL_MSG("Given `countryName`, C=<countryName>, is empty.");
r = -1;
goto error;
}
if (organization.empty()) {
FAIL_MSG("Given `organization`, O=<organization>, is empty.");
r = -2;
goto error;
}
if (commonName.empty()) {
FAIL_MSG("Given `commonName`, CN=<commonName>, is empty.");
r = -3;
goto error;
}
// initialize random number generator
mbedtls_ctr_drbg_init(&rand_ctx);
mbedtls_entropy_init(&entropy_ctx);
r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx, (const unsigned char*)personalisation, strlen(personalisation));
if (0 != r) {
FAIL_MSG("Failed to initialize and seed the entropy context.");
r = -10;
goto error;
}
// initialize the public key context
mbedtls_pk_init(&key);
r = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA));
if (0 != r) {
FAIL_MSG("Faild to initialize the PK context.");
r = -20;
goto error;
}
rsa_ctx = mbedtls_pk_rsa(key);
if (NULL == rsa_ctx) {
FAIL_MSG("Failed to get the RSA context from from the public key context (key).");
r = -30;
goto error;
}
r = mbedtls_rsa_gen_key(rsa_ctx, mbedtls_ctr_drbg_random, &rand_ctx, 2048, 65537);
if (0 != r) {
FAIL_MSG("Failed to generate a private key.");
r = -40;
goto error;
}
// calc the valid from and until time.
time_from = time(NULL);
time_from = (time_from < 1000000000) ? 1000000000 : time_from;
time_to = time_from + (60 * 60 * 24 * 365); // valid for a year
if (time_to < time_from) {
time_to = INT_MAX;
}
r = strftime(time_from_str, sizeof(time_from_str), "%Y%m%d%H%M%S", gmtime(&time_from));
if (0 == r) {
FAIL_MSG("Failed to generate the valid-from time string.");
r = -50;
goto error;
}
r = strftime(time_to_str, sizeof(time_to_str), "%Y%m%d%H%M%S", gmtime(&time_to));
if (0 == r) {
FAIL_MSG("Failed to generate the valid-to time string.");
r = -60;
goto error;
}
r = mbedtls_ctr_drbg_random((void*)&rand_ctx, (uint8_t*)&serial_num, sizeof(serial_num));
if (0 != r) {
FAIL_MSG("Failed to generate a random u64.");
r = -70;
goto error;
}
for (i = 0; i < 8; ++i) {
sprintf(serial_hex + (i * 2), "%02x", serial_ptr[i]);
}
// start creating the certificate
mbedtls_x509write_crt_init(&write_cert);
mbedtls_x509write_crt_set_md_alg(&write_cert, MBEDTLS_MD_SHA256);
mbedtls_x509write_crt_set_issuer_key(&write_cert, &key);
mbedtls_x509write_crt_set_subject_key(&write_cert, &key);
r = mbedtls_x509write_crt_set_subject_name(&write_cert, subject_name.c_str());
if (0 != r) {
FAIL_MSG("Failed to set the subject name.");
r = -80;
goto error;
}
r = mbedtls_x509write_crt_set_issuer_name(&write_cert, subject_name.c_str());
if (0 != r) {
FAIL_MSG("Failed to set the issuer name.");
r = -90;
goto error;
}
r = mbedtls_x509write_crt_set_validity(&write_cert, time_from_str, time_to_str);
if (0 != r) {
FAIL_MSG("Failed to set the x509 validity string.");
r = -100;
goto error;
}
r = mbedtls_x509write_crt_set_basic_constraints(&write_cert, 0, -1);
if (0 != r) {
FAIL_MSG("Failed ot set the basic constraints for the certificate.");
r = -110;
goto error;
}
r = mbedtls_x509write_crt_set_subject_key_identifier(&write_cert);
if (0 != r) {
FAIL_MSG("Failed to set the subjectKeyIdentifier.");
r = -120;
goto error;
}
r = mbedtls_x509write_crt_set_authority_key_identifier(&write_cert);
if (0 != r) {
FAIL_MSG("Failed to set the authorityKeyIdentifier.");
r = -130;
goto error;
}
// set certificate serial; mpi is used to perform i/o
mbedtls_mpi_init(&serial_mpi);
mbedtls_mpi_read_string(&serial_mpi, 16, serial_hex);
r = mbedtls_x509write_crt_set_serial(&write_cert, &serial_mpi);
if (0 != r) {
FAIL_MSG("Failed to set the certificate serial.");
r = -140;
goto error;
}
// write the certificate into a PEM structure
r = mbedtls_x509write_crt_pem(&write_cert, buf, sizeof(buf), mbedtls_ctr_drbg_random, &rand_ctx);
if (0 != r) {
FAIL_MSG("Failed to create the PEM data from the x509 write structure.");
r = -150;
goto error;
}
// convert the PEM data into out `mbedtls_x509_cert` member.
// len should be PEM including the string null terminating
// char. @todo there must be a way to convert the write
// struct into a `mbedtls_x509_cert` w/o calling this parse
// function.
mbedtls_x509_crt_init(&cert);
r = mbedtls_x509_crt_parse(&cert, (const unsigned char*)buf, strlen((char*)buf) + 1);
if (0 != r) {
mbedtls_strerror(r, (char*)buf, sizeof(buf));
FAIL_MSG("Failed to convert the mbedtls_x509write_crt into a mbedtls_x509_crt: %s", buf);
r = -160;
goto error;
}
error:
// cleanup
mbedtls_ctr_drbg_free(&rand_ctx);
mbedtls_entropy_free(&entropy_ctx);
mbedtls_x509write_crt_free(&write_cert);
mbedtls_mpi_free(&serial_mpi);
if (r < 0) {
shutdown();
}
return r;
}
int Certificate::shutdown() {
rsa_ctx = NULL;
mbedtls_pk_free(&key);
mbedtls_x509_crt_free(&cert);
return 0;
}
std::string Certificate::getFingerprintSha256() {
uint8_t fingerprint_raw[32] = {};
uint8_t fingerprint_hex[128] = {};
mbedtls_md_type_t hash_type = MBEDTLS_MD_SHA256;
mbedtls_sha256(cert.raw.p, cert.raw.len, fingerprint_raw, 0);
for (int i = 0; i < 32; ++i) {
sprintf((char*)(fingerprint_hex + (i * 3)), ":%02X", (int)fingerprint_raw[i]);
}
fingerprint_hex[32 * 3] = '\0';
std::string result = std::string((char*)fingerprint_hex + 1, (32 * 3) - 1);
return result;
}

34
lib/certificate.h Normal file
View file

@ -0,0 +1,34 @@
#pragma once
/*
MBEDTLS BASED CERTIFICATE
=========================
This class can be used to generate a self-signed x509
certificate which enables you to perform secure
communication. This certificate uses a 2048 bits RSA key.
*/
#include <string>
#include <mbedtls/config.h>
#include <mbedtls/x509_crt.h>
#include <mbedtls/x509_csr.h>
#include <mbedtls/entropy.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/md.h>
#include <mbedtls/error.h>
#include <mbedtls/sha256.h>
class Certificate {
public:
Certificate();
int init(const std::string &countryName, const std::string &organization, const std::string& commonName);
int shutdown();
std::string getFingerprintSha256();
public:
mbedtls_x509_crt cert;
mbedtls_pk_context key; /* key context, stores private and public key. */
mbedtls_rsa_context* rsa_ctx; /* rsa context, stored in key_ctx */
};

420
lib/dtls_srtp_handshake.cpp Normal file
View file

@ -0,0 +1,420 @@
#include <algorithm>
#include "defines.h"
#include "dtls_srtp_handshake.h"
/* Write mbedtls into a log file. */
#define LOG_TO_FILE 0
#if LOG_TO_FILE
# include <fstream>
#endif
/* ----------------------------------------- */
static void print_mbedtls_error(int r);
static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str);
static int on_mbedtls_wants_to_read(void* user, unsigned char* buf, size_t len); /* Called when mbedtls wants to read data from e.g. a socket. */
static int on_mbedtls_wants_to_write(void* user, const unsigned char* buf, size_t len); /* Called when mbedtls wants to write data to e.g. a socket. */
static std::string mbedtls_err_to_string(int r);
/* ----------------------------------------- */
DTLSSRTPHandshake::DTLSSRTPHandshake()
:write_callback(NULL)
,cert(NULL)
,key(NULL)
{
memset((void*)&entropy_ctx, 0x00, sizeof(entropy_ctx));
memset((void*)&rand_ctx, 0x00, sizeof(rand_ctx));
memset((void*)&ssl_ctx, 0x00, sizeof(ssl_ctx));
memset((void*)&ssl_conf, 0x00, sizeof(ssl_conf));
memset((void*)&cookie_ctx, 0x00, sizeof(cookie_ctx));
memset((void*)&timer_ctx, 0x00, sizeof(timer_ctx));
}
int DTLSSRTPHandshake::init(mbedtls_x509_crt* certificate,
mbedtls_pk_context* privateKey,
int(*writeCallback)(const uint8_t* data, int* nbytes)
)
{
int r = 0;
mbedtls_ssl_srtp_profile srtp_profiles[] = {
MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80,
MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32
};
if (!writeCallback) {
FAIL_MSG("No writeCallack function given.");
r = -3;
goto error;
}
if (!certificate) {
FAIL_MSG("Given certificate is null.");
r = -5;
goto error;
}
if (!privateKey) {
FAIL_MSG("Given key is null.");
r = -10;
goto error;
}
cert = certificate;
key = privateKey;
/* init the contexts */
mbedtls_entropy_init(&entropy_ctx);
mbedtls_ctr_drbg_init(&rand_ctx);
mbedtls_ssl_init(&ssl_ctx);
mbedtls_ssl_config_init(&ssl_conf);
mbedtls_ssl_cookie_init(&cookie_ctx);
/* seed and setup the random number generator */
r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx, (const unsigned char*)"mist-srtp", 9);
if (0 != r) {
print_mbedtls_error(r);
r = -20;
goto error;
}
/* load defaults into our ssl_conf */
r = mbedtls_ssl_config_defaults(&ssl_conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM, MBEDTLS_SSL_PRESET_DEFAULT);
if (0 != r) {
print_mbedtls_error(r);
r = -30;
goto error;
}
mbedtls_ssl_conf_authmode(&ssl_conf, MBEDTLS_SSL_VERIFY_NONE);
mbedtls_ssl_conf_rng(&ssl_conf, mbedtls_ctr_drbg_random, &rand_ctx);
mbedtls_ssl_conf_dbg(&ssl_conf, print_mbedtls_debug_message, stdout);
mbedtls_ssl_conf_ca_chain(&ssl_conf, cert, NULL);
mbedtls_ssl_conf_cert_profile(&ssl_conf, &mbedtls_x509_crt_profile_default);
mbedtls_debug_set_threshold(10);
/* enable SRTP */
r = mbedtls_ssl_conf_dtls_srtp_protection_profiles(&ssl_conf, srtp_profiles, sizeof(srtp_profiles) / sizeof(srtp_profiles[0]));
if (0 != r) {
print_mbedtls_error(r);
r = -40;
goto error;
}
/* cert certificate chain + key, so we can verify the client-hello signed data */
r = mbedtls_ssl_conf_own_cert(&ssl_conf, cert, key);
if (0 != r) {
print_mbedtls_error(r);
r = -50;
goto error;
}
/* cookie setup (e.g. to prevent ddos). */
r = mbedtls_ssl_cookie_setup(&cookie_ctx, mbedtls_ctr_drbg_random, &rand_ctx);
if (0 != r) {
print_mbedtls_error(r);
r = -60;
goto error;
}
/* register callbacks for dtls cookies (server only). */
mbedtls_ssl_conf_dtls_cookies(&ssl_conf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &cookie_ctx);
/* setup the ssl context for use. note that ssl_conf will be referenced internall by the context and therefore should be kept around. */
r = mbedtls_ssl_setup(&ssl_ctx, &ssl_conf);
if (0 != r) {
print_mbedtls_error(r);
r = -70;
goto error;
}
/* set bio handlers */
mbedtls_ssl_set_bio(&ssl_ctx, (void*)this, on_mbedtls_wants_to_write, on_mbedtls_wants_to_read, NULL);
/* set temp id, just adds some exta randomness */
{
std::string remote_id = "mist";
r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char*)remote_id.c_str(), remote_id.size());
if (0 != r) {
print_mbedtls_error(r);
r = -80;
goto error;
}
}
/* set timer callbacks */
mbedtls_ssl_set_timer_cb(&ssl_ctx, &timer_ctx, mbedtls_timing_set_delay, mbedtls_timing_get_delay);
write_callback = writeCallback;
error:
if (r < 0) {
shutdown();
}
return r;
}
int DTLSSRTPHandshake::shutdown() {
/* cleanup the refs from the settings. */
cert = NULL;
key = NULL;
buffer.clear();
cipher.clear();
remote_key.clear();
remote_salt.clear();
local_key.clear();
local_salt.clear();
/* free our contexts; we do not free the `settings.cert` and `settings.key` as they are owned by the user of this class. */
mbedtls_entropy_free(&entropy_ctx);
mbedtls_ctr_drbg_free(&rand_ctx);
mbedtls_ssl_free(&ssl_ctx);
mbedtls_ssl_config_free(&ssl_conf);
mbedtls_ssl_cookie_free(&cookie_ctx);
return 0;
}
/* ----------------------------------------- */
int DTLSSRTPHandshake::parse(const uint8_t* data, size_t nbytes) {
if (NULL == data) {
ERROR_MSG("Given `data` is NULL.");
return -1;
}
if (0 == nbytes) {
ERROR_MSG("Given nbytes is 0.");
return -2;
}
if (MBEDTLS_SSL_HANDSHAKE_OVER == ssl_ctx.state) {
ERROR_MSG("Already finished the handshake.");
return -3;
}
/* copy incoming data into a temporary buffer which is read via our `bio` read function. */
int r = 0;
std::copy(data, data + nbytes, std::back_inserter(buffer));
do {
r = mbedtls_ssl_handshake(&ssl_ctx);
switch (r) {
/* 0 = handshake done. */
case 0: {
if (0 != extractKeyingMaterial()) {
ERROR_MSG("Failed to extract keying material after handshake was done.");
return -2;
}
return 0;
}
/* see the dtls server example; this is used to prevent certain attacks (ddos) */
case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: {
if (0 != resetSession()) {
ERROR_MSG("Failed to reset the session which is necessary when we need to verify the HELLO.");
return -3;
}
break;
}
case MBEDTLS_ERR_SSL_WANT_READ: {
DONTEVEN_MSG("mbedtls wants a bit more data before it can continue parsing the DTLS handshake.");
break;
}
default: {
ERROR_MSG("A serious mbedtls error occured.");
print_mbedtls_error(r);
return -2;
}
}
}
while (MBEDTLS_ERR_SSL_WANT_WRITE == r);
return 0;
}
/* ----------------------------------------- */
int DTLSSRTPHandshake::resetSession() {
std::string remote_id = "mist"; /* @todo for now we hardcoded this... */
int r = 0;
r = mbedtls_ssl_session_reset(&ssl_ctx);
if (0 != r) {
print_mbedtls_error(r);
return -1;
}
r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char*)remote_id.c_str(), remote_id.size());
if (0 != r) {
print_mbedtls_error(r);
return -2;
}
buffer.clear();
return 0;
}
/*
master key is 128 bits => 16 bytes.
master salt is 112 bits => 14 bytes
*/
int DTLSSRTPHandshake::extractKeyingMaterial() {
int r = 0;
uint8_t keying_material[MBEDTLS_DTLS_SRTP_MAX_KEY_MATERIAL_LENGTH] = {};
size_t keying_material_len = sizeof(keying_material);
r = mbedtls_ssl_get_dtls_srtp_key_material(&ssl_ctx, keying_material, &keying_material_len);
if (0 != r) {
print_mbedtls_error(r);
return -1;
}
/* @todo following code is for server mode only */
mbedtls_ssl_srtp_profile srtp_profile = mbedtls_ssl_get_dtls_srtp_protection_profile(&ssl_ctx);
switch (srtp_profile) {
case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80: {
cipher = "SRTP_AES128_CM_SHA1_80";
break;
}
case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32: {
cipher = "SRTP_AES128_CM_SHA1_32";
break;
}
default: {
ERROR_MSG("Unhandled SRTP profile, cannot extract keying material.");
return -6;
}
}
remote_key.assign((char*)(&keying_material[0]) + 0, 16);
local_key.assign((char*)(&keying_material[0]) + 16, 16);
remote_salt.assign((char*)(&keying_material[0]) + 32, 14);
local_salt.assign((char*)(&keying_material[0]) + 46, 14);
DONTEVEN_MSG("Extracted the DTLS SRTP keying material with cipher %s.", cipher.c_str());
DONTEVEN_MSG("Remote DTLS SRTP key size is %zu.", remote_key.size());
DONTEVEN_MSG("Remote DTLS SRTP salt size is %zu.", remote_salt.size());
DONTEVEN_MSG("Local DTLS SRTP key size is %zu.", local_key.size());
DONTEVEN_MSG("Local DTLS SRTP salt size is %zu.", local_salt.size());
return 0;
}
/* ----------------------------------------- */
/*
This function is called by mbedtls whenever it wants to read
some data. The documentation states the following: "For DTLS,
you need to provide either a non-NULL f_recv_timeout
callback, or a f_recv that doesn't block." As this
implementation is completely decoupled from any I/O and uses
a "push" model instead of a "pull" model we have to copy new
input bytes into a temporary buffer (see parse), but we act
as if we were using a non-blocking socket, which means:
- we return MBETLS_ERR_SSL_WANT_READ when there is no data left to read
- when there is data in our temporary buffer, we read from that
*/
static int on_mbedtls_wants_to_read(void* user, unsigned char* buf, size_t len) {
DTLSSRTPHandshake* hs = static_cast<DTLSSRTPHandshake*>(user);
if (NULL == hs) {
ERROR_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake.");
return -1;
}
/* figure out how much we can read. */
if (hs->buffer.size() == 0) {
return MBEDTLS_ERR_SSL_WANT_READ;
}
size_t nbytes = hs->buffer.size();
if (nbytes > len) {
nbytes = len;
}
/* "read" into the given buffer. */
memcpy(buf, &hs->buffer[0], nbytes);
hs->buffer.erase(hs->buffer.begin(), hs->buffer.begin() + nbytes);
return (int)nbytes;
}
static int on_mbedtls_wants_to_write(void* user, const unsigned char* buf, size_t len) {
DTLSSRTPHandshake* hs = static_cast<DTLSSRTPHandshake*>(user);
if (!hs) {
FAIL_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake.");
return -1;
}
if (!hs->write_callback) {
FAIL_MSG("The `write_callback` member is NULL.");
return -2;
}
int nwritten = (int)len;
if (0 != hs->write_callback(buf, &nwritten)) {
FAIL_MSG("Failed to write some DTLS handshake data.");
return -3;
}
if (nwritten != (int)len) {
FAIL_MSG("The DTLS-SRTP handshake listener MUST write all the data.");
return -4;
}
return nwritten;
}
/* ----------------------------------------- */
static void print_mbedtls_error(int r) {
char buf[1024] = {};
mbedtls_strerror(r, buf, sizeof(buf));
ERROR_MSG("mbedtls error: %s", buf);
}
static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str) {
DONTEVEN_MSG("%s:%04d: %.*s", file, line, strlen(str) - 1, str);
#if LOG_TO_FILE
static std::ofstream ofs;
if (!ofs.is_open()) {
ofs.open("mbedtls.log", std::ios::out);
}
if (!ofs.is_open()) {
return;
}
ofs << str;
ofs.flush();
#endif
}
static std::string mbedtls_err_to_string(int r) {
switch (r) {
case MBEDTLS_ERR_SSL_WANT_READ: { return "MBEDTLS_ERR_SSL_WANT_READ"; }
case MBEDTLS_ERR_SSL_WANT_WRITE: { return "MBEDTLS_ERR_SSL_WANT_WRITE"; }
default: {
print_mbedtls_error(r);
return "UNKNOWN";
}
}
}
/* ---------------------------------------- */

59
lib/dtls_srtp_handshake.h Normal file
View file

@ -0,0 +1,59 @@
#pragma once
#include <stdint.h>
#include <deque>
#include <mbedtls/config.h>
#include <mbedtls/entropy.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/certs.h>
#include <mbedtls/x509.h>
#include <mbedtls/ssl.h>
#include <mbedtls/ssl_cookie.h>
#include <mbedtls/error.h>
#include <mbedtls/debug.h>
#include <mbedtls/timing.h>
/* ----------------------------------------- */
class DTLSSRTPHandshake {
public:
DTLSSRTPHandshake();
int init(mbedtls_x509_crt* certificate, mbedtls_pk_context* privateKey, int(*writeCallback)(const uint8_t* data, int* nbytes)); // writeCallback should return 0 on succes < 0 on error. nbytes holds the number of bytes to be sent and needs to be set to the number of bytes actually sent.
int shutdown();
int parse(const uint8_t* data, size_t nbytes);
bool hasKeyingMaterial();
private:
int extractKeyingMaterial();
int resetSession();
private:
mbedtls_x509_crt* cert; /* Certificate, we do not own the key. Make sure it's kept alive during the livetime of this class instance. */
mbedtls_pk_context* key; /* Private key, we do not own the key. Make sure it's kept alive during the livetime of this class instance. */
mbedtls_entropy_context entropy_ctx;
mbedtls_ctr_drbg_context rand_ctx;
mbedtls_ssl_context ssl_ctx;
mbedtls_ssl_config ssl_conf;
mbedtls_ssl_cookie_ctx cookie_ctx;
mbedtls_timing_delay_context timer_ctx;
public:
int (*write_callback)(const uint8_t* data, int* nbytes);
std::deque<uint8_t> buffer; /* Accessed from BIO callbback. We copy the bytes you pass into `parse()` into this temporary buffer which is read by a trigger to `mbedlts_ssl_handshake()`. */
std::string cipher; /* selected SRTP cipher. */
std::string remote_key;
std::string remote_salt;
std::string local_key;
std::string local_salt;
};
/* ----------------------------------------- */
inline bool DTLSSRTPHandshake::hasKeyingMaterial() {
return (0 != remote_key.size()
&& 0 != remote_salt.size()
&& 0 != local_key.size()
&& 0 != local_salt.size());
}
/* ----------------------------------------- */

569
lib/rtp_fec.cpp Normal file
View file

@ -0,0 +1,569 @@
#include "defines.h"
#include "rtp_fec.h"
#include "rtp.h"
namespace RTP{
/// Based on the `block PT` value, we can either find the
/// contents of the codec payload (e.g. H264, VP8) or a ULPFEC header
/// (RFC 5109). The structure of the ULPFEC data is as follows.
///
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | RTP Header (12 octets or more) |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | FEC Header (10 octets) |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | FEC Level 0 Header |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | FEC Level 0 Payload |
/// | |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | FEC Level 1 Header |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | FEC Level 1 Payload |
/// | |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | Cont. |
/// | |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
///
/// FEC HEADER:
///
/// 0 1 2 3
/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |E|L|P|X| CC |M| PT recovery | SN base |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | TS recovery |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | length recovery |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
///
///
/// FEC LEVEL HEADER
///
/// 0 1 2 3
/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | Protection Length | mask |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// | mask cont. (present only when L = 1) |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
///
PacketFEC::PacketFEC() {
}
PacketFEC::~PacketFEC() {
receivedSeqNums.clear();
coveredSeqNums.clear();
}
bool PacketFEC::initWithREDPacket(const char* data, size_t nbytes) {
if (!data) {
FAIL_MSG("Given fecData pointer is NULL.");
return false;
}
if (nbytes < 23) {
FAIL_MSG("Given fecData is too small. Should be at least: 12 (RTP) + 1 (RED) + 10 (FEC) 23 bytes.");
return false;
}
if (coveredSeqNums.size() != 0) {
FAIL_MSG("It seems we're already initialized; coveredSeqNums already set.");
return false;
}
if (receivedSeqNums.size() != 0) {
FAIL_MSG("It seems we're already initialized; receivedSeqNums is not empty.");
return false;
}
// Decode RED header.
RTP::Packet rtpPkt(data, nbytes);
uint8_t* redHeader = (uint8_t*)(data + rtpPkt.getHsize());
uint8_t moreBlocks = redHeader[0] & 0x80;
if (moreBlocks == 1) {
FAIL_MSG("RED header indicates there are multiple blocks. Haven't seen this before (@todo implement, exiting now).");
// \todo do not EXIT!
return false;
}
// Copy the data, starting at the FEC header (skip RTP + RED header)
size_t numHeaderBytes = rtpPkt.getHsize() + 1;
if (numHeaderBytes > nbytes) {
FAIL_MSG("Invalid FEC packet; too small to contain FEC data.");
return false;
}
fecPacketData.assign(NULL, 0);
fecPacketData.append(data + numHeaderBytes, nbytes - numHeaderBytes);
// Extract the sequence numbers this packet protects.
if (!extractCoveringSequenceNumbers()) {
FAIL_MSG("Failed to extract the protected sequence numbers for this FEC.");
// @todo we probably want to reset our set.
return false;
}
return true;
}
uint8_t PacketFEC::getExtensionFlag() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get extension-flag from the FEC header; fecPacketData member is not set. Not initialized?");
return 0;
}
return ((fecPacketData[0] & 0x80) >> 7);
}
uint8_t PacketFEC::getLongMaskFlag() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get the long-mask-flag from the FEC header. fecPacketData member is not set. Not initialized?");
return 0;
}
return ((fecPacketData[0] & 0x40) >> 6);
}
// Returns 0 (error), 2 or 6, wich are the valid sizes of the mask.
uint8_t PacketFEC::getNumBytesUsedForMask() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get the number of bytes used by the mask. fecPacketData member is not set. Not initialized?");
return 0;
}
if (getLongMaskFlag() == 0) {
return 2;
}
return 6;
}
uint16_t PacketFEC::getSequenceBaseNumber() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get the sequence base number. fecPacketData member is not set. Not initialized?");
return 0;
}
return (uint16_t) (fecPacketData[2] << 8) | fecPacketData[3];
}
char* PacketFEC::getFECHeader() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get fec header. fecPacketData member is not set. Not initialized?");
}
return fecPacketData;
}
char* PacketFEC::getLevel0Header() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get the level 0 header. fecPacketData member is not set. Not initialized?");
return NULL;
}
return (char*)(fecPacketData + 10);
}
char* PacketFEC::getLevel0Payload() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get the level 0 payload. fecPacketData member is not set. Not initialized?");
return NULL;
}
// 10 bytes for FEC header
// 2 bytes for `Protection Length`
// 2 or 6 bytes for `mask`.
return (char*)(fecPacketData + 10 + 2 + getNumBytesUsedForMask());
}
uint16_t PacketFEC::getLevel0ProtectionLength() {
if (fecPacketData.size() == 0) {
FAIL_MSG("Cannot get the level 0 protection length. fecPacketData member is not set. Not initialized?");
return 0;
}
char* level0Header = getLevel0Header();
if (!level0Header) {
FAIL_MSG("Failed to get the level 0 header, cannot get protection length.");
return 0;
}
uint16_t protectionLength = (level0Header[0] << 8) | level0Header[1];
return protectionLength;
}
uint16_t PacketFEC::getLengthRecovery() {
char* fecHeader = getFECHeader();
if (!fecHeader) {
FAIL_MSG("Cannot get the FEC header which we need to get the `length recovery` field. Not initialized?");
return 0;
}
uint16_t lengthRecovery = (fecHeader[8] << 8) | fecHeader[9];
return lengthRecovery;
}
// Based on InsertFecPacket of forward_error_correction.cc from
// Chromium. (used as reference). The `mask` from the
// FEC-level-header can be 2 or 6 bytes long. Whenever a bit is
// set to 1 it means that we have to calculate the sequence
// number for that bit. To calculate the sequence number we
// start with the `SN base` value (base sequence number) and
// use the bit offset to increment the SN-base value. E.g.
// when it's bit 4 and SN-base is 230, it meas that this FEC
// packet protects the media packet with sequence number
// 230. We have to start counting the bit numbers from the
// most-significant-bit (e.g. 1 << 7).
bool PacketFEC::extractCoveringSequenceNumbers() {
if (coveredSeqNums.size() != 0) {
FAIL_MSG("Cannot extract protected sequence numbers; looks like we already did that.");
return false;
}
size_t maskNumBytes = getNumBytesUsedForMask();
if (maskNumBytes != 2 && maskNumBytes != 6) {
FAIL_MSG("Invalid mask size (%u) cannot extract sequence numbers.", maskNumBytes);
return false;
}
char* maskPtr = getLevel0Header();
if (!maskPtr) {
FAIL_MSG("Failed to get the level-0 header ptr. Cannot extract protected sequence numbers.");
return false;
}
uint16_t seqNumBase = getSequenceBaseNumber();
if (seqNumBase == 0) {
WARN_MSG("Base sequence number is 0; it's possible but unlikely.");
}
// Skip the `Protection Length`
maskPtr = maskPtr + 2;
for (uint16_t byteDX = 0; byteDX < maskNumBytes; ++byteDX) {
uint8_t maskByte = maskPtr[byteDX];
for (uint16_t bitDX = 0; bitDX < 8; ++bitDX) {
if (maskByte & (1 << 7 - bitDX)) {
uint16_t seqNum = seqNumBase + (byteDX << 3) + bitDX;
coveredSeqNums.insert(seqNum);
}
}
}
return true;
}
// \todo rename coversSequenceNumber
bool PacketFEC::coversSequenceNumber(uint16_t sn) {
return (coveredSeqNums.count(sn) == 0) ? false : true;
}
void PacketFEC::addReceivedSequenceNumber(uint16_t sn) {
if (false == coversSequenceNumber(sn)) {
FAIL_MSG("Trying to add a received sequence number this instance is not handling (%u).", sn);
return;
}
receivedSeqNums.insert(sn);
}
/// This function can be called to recover a missing packet. A
/// FEC packet is received with a list of media packets it
/// might be able to recover; this PacketFEC is received after
/// we should have received the media packets it's protecting.
///
/// Here we first fill al list with the received sequence
/// numbers that we're protecting; when we're missing one
/// packet this function will try to recover it.
///
/// The `receivedMediaPackets` is the history of media packets
/// that you received and keep in a memory. These are used
/// when XORing when we reconstruct a packet.
void PacketFEC::tryToRecoverMissingPacket(std::map<uint16_t, Packet>& receivedMediaPackets, Packet& reconstructedPacket) {
// Mark all the media packets that we protect and which have
// been received as "received" in our internal list.
std::set<uint16_t>::iterator protIt = coveredSeqNums.begin();
while (protIt != coveredSeqNums.end()) {
if (receivedMediaPackets.count(*protIt) == 1) {
addReceivedSequenceNumber(*protIt);
}
protIt++;
}
// We have received all media packets that we could recover;
// so there is no need for this FEC packet.
// @todo Jaron shall we reuse allocs/PacketFECs?
if (receivedSeqNums.size() == coveredSeqNums.size()) {
return;
}
if (coveredSeqNums.size() != (receivedSeqNums.size() + 1)) {
// missing more then 1 packet. we can only recover when
// one packet is lost.
return;
}
// Find missing sequence number.
uint16_t missingSeqNum = 0;
protIt = coveredSeqNums.begin();
while (protIt != coveredSeqNums.end()) {
if (receivedSeqNums.count(*protIt) == 0) {
missingSeqNum = *protIt;
break;
}
++protIt;
}
if (!coversSequenceNumber(missingSeqNum)) {
WARN_MSG("We cannot recover %u.", missingSeqNum);
return;
}
// Copy FEC into new RTP-header
char* fecHeader = getFECHeader();
if (!fecHeader) {
FAIL_MSG("Failed to get the fec header. Cannot recover.");
return;
}
recoverData.assign(NULL, 0);
recoverData.append(fecHeader, 12);
// Copy FEC into new RTP-payload
char* level0Payload = getLevel0Payload();
if (!level0Payload) {
FAIL_MSG("Failed to get the level-0 payload data (XOR'd media data from FEC packet).");
return;
}
uint16_t level0ProtLen = getLevel0ProtectionLength();
if (level0ProtLen == 0) {
FAIL_MSG("Failed to get the level-0 protection length.");
return;
}
recoverData.append(level0Payload, level0ProtLen);
uint8_t recoverLength[2] = { fecHeader[8], fecHeader[9] };
// XOR headers
protIt = coveredSeqNums.begin();
while (protIt != coveredSeqNums.end()) {
uint16_t seqNum = *protIt;
if (seqNum == missingSeqNum) {
++protIt;
continue;
}
Packet& mediaPacket = receivedMediaPackets[seqNum];
char* mediaData = mediaPacket.ptr();
uint16_t mediaSize = mediaPacket.getPayloadSize();
uint8_t* mediaSizePtr = (uint8_t*)&mediaSize;
WARN_MSG(" => XOR header with %u, size: %u.", seqNum, mediaSize);
// V, P, X, CC, M, PT
recoverData[0] ^= mediaData[0];
recoverData[1] ^= mediaData[1];
// Timestamp
recoverData[4] ^= mediaData[4];
recoverData[5] ^= mediaData[5];
recoverData[6] ^= mediaData[6];
recoverData[7] ^= mediaData[7];
// Length of recovered media packet
recoverLength[0] ^= mediaSizePtr[1];
recoverLength[1] ^= mediaSizePtr[0];
++protIt;
}
uint16_t recoverPayloadSize = ntohs(*(uint16_t*)recoverLength);
// XOR payloads
protIt = coveredSeqNums.begin();
while (protIt != coveredSeqNums.end()) {
uint16_t seqNum = *protIt;
if (seqNum == missingSeqNum) {
++protIt;
continue;
}
Packet& mediaPacket = receivedMediaPackets[seqNum];
char* mediaData = mediaPacket.ptr() + mediaPacket.getHsize();
for (size_t i = 0; i < recoverPayloadSize; ++i) {
recoverData[12 + i] ^= mediaData[i];
}
++protIt;
}
// And setup the reconstructed packet.
reconstructedPacket = Packet(recoverData, recoverPayloadSize);
reconstructedPacket.setSequence(missingSeqNum);
// @todo check what other header fields we need to fix.
}
void FECSorter::addPacket(const Packet &pack){
if (tmpVideoLossPrevention & SDP_LOSS_PREVENTION_ULPFEC) {
packetHistory[pack.getSequence()] = pack;
while (packetHistory.begin()->first < pack.getSequence() - 500){
packetHistory.erase(packetHistory.begin());
}
}
Sorter::addPacket(pack);
}
/// This function will handle RED packets that may be used to
/// encapsulate ULPFEC or simply the codec payload (e.g. H264,
/// VP8). This function is created to handle FEC with
/// WebRTC. When we want to use FEC with WebRTC we have to add
/// both the `a=rtpmap:<ulp-fmt> ulpfec/90000` and
/// `a=rtpmap<red-fmt> red/90000` lines to the SDP. FEC is
/// always used together with RED (RFC 1298). It turns out
/// that with WebRTC the RED only adds one byte after the RTP
/// header (only the `F` and `block PT`, see below)`. The
/// `block PT` is the payload type of the data that
/// follows. This would be `<ulp-fmt>` for FEC data. Though
/// these RED packets may contain FEC or just the media:
/// H264/VP8.
///
/// RED HEADER:
///
/// 0 1 2 3
/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |F| block PT | timestamp offset | block length |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
///
void FECSorter::addREDPacket(char* dat,
unsigned int len,
uint8_t codecPayloadType,
uint8_t REDPayloadType,
uint8_t ULPFECPayloadType)
{
RTP::Packet pkt(dat, len);
if (pkt.getPayloadType() != REDPayloadType) {
FAIL_MSG("Requested to add a RED packet, but it has an invalid payload type.");
return;
}
// Check if the `F` flag is set. Chromium will always set
// this to 0 (at time of writing, check: https://goo.gl/y1eJ6k
uint8_t* REDHeader = (uint8_t*)(dat + pkt.getHsize());
uint8_t moreBlocksAvailable = REDHeader[0] & 0x80;
if (moreBlocksAvailable == 1) {
FAIL_MSG("Not yet received a RED packet that had it's F bit set; @todo implement.");
exit(EXIT_FAILURE);
return;
}
// Extract the `block PT` field which can be the media-pt,
// fec-pt. When it's just media that follows, we move all
// data one byte up and reconstruct a normal media packet.
uint8_t blockPayloadType = REDHeader[0] & 0x7F;
if (blockPayloadType == codecPayloadType) {
memmove(dat + pkt.getHsize(), dat + pkt.getHsize() + 1, len - pkt.getHsize() - 1);
dat[1] &= 0x80;
dat[1] |= codecPayloadType;
RTP::Packet mediaPacket((const char*)dat, len -1);
addPacket(mediaPacket);
return;
}
// When the payloadType equals our ULP/FEC payload type, we
// received a REC packet (RFC 5109) that contains FEC data
// and a list of sequence number that it covers and can
// reconstruct.
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// \todo Jaron, I'm now just generating a `PacketFEC` on the heap
// and we're not managing destruction anywhere atm; I guess
// re-use or destruction needs to be part of the algo that
// is going to deal with FEC.
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
if (blockPayloadType == ULPFECPayloadType) {
WARN_MSG(" => got fec packet: %u", pkt.getSequence());
PacketFEC* fec = new PacketFEC();
if (!fec->initWithREDPacket(dat, len)) {
delete fec;
fec = NULL;
FAIL_MSG("Failed to initialize a `PacketFEC`");
}
fecPackets.push_back(fec);
Packet recreatedPacket;
fec->tryToRecoverMissingPacket(packetHistory, recreatedPacket);
if (recreatedPacket.ptr() != NULL) {
char* pl = recreatedPacket.getPayload();
WARN_MSG(" => reconstructed %u, %02X %02X %02X %02X | %02X %02X %02X %02X", recreatedPacket.getSequence(), pl[0], pl[1], pl[2], pl[3], pl[4], pl[5], pl[6], pl[7]);
addPacket(recreatedPacket);
}
return;
}
FAIL_MSG("Unhandled RED block payload type %u. Check the answer SDP.", blockPayloadType);
}
/// Each FEC packet is capable of recovering a limited amount
/// of media packets. Some experimentation showed that most
/// often one FEC is used to protect somewhere between 2-10
/// media packets. Each FEC packet has a list of sequence
/// number that it can recover when all other media packets
/// have been received except the one that we want to
/// recover. This function returns the FEC packet might be able
/// to recover the given sequence number.
PacketFEC* FECSorter::getFECPacketWhichCoversSequenceNumber(uint16_t sn) {
size_t nfecs = fecPackets.size();
for (size_t i = 0; i < nfecs; ++i) {
PacketFEC* fec = fecPackets[i];
if (fec->coversSequenceNumber(sn)) {
return fec;
}
}
return NULL;
}
void FECPacket::sendRTCP_RR(RTP::FECSorter &sorter, uint32_t mySSRC, uint32_t theirSSRC, void* userData, void callBack(void* userData, const char* payload, uint32_t nbytes)) {
char *rtcpData = (char *)malloc(32);
if (!rtcpData){
FAIL_MSG("Could not allocate 32 bytes. Something is seriously messed up.");
return;
}
if (!(sorter.lostCurrent + sorter.packCurrent)){sorter.packCurrent++;}
rtcpData[0] = 0x81; // version 2, no padding, one receiver report
rtcpData[1] = 201; // receiver report
Bit::htobs(rtcpData + 2, 7); // 7 4-byte words follow the header
Bit::htobl(rtcpData + 4, mySSRC); // set receiver identifier
Bit::htobl(rtcpData + 8, theirSSRC); // set source identifier
rtcpData[12] =
(sorter.lostCurrent * 255) /
(sorter.lostCurrent + sorter.packCurrent); // fraction lost since prev RR
Bit::htob24(rtcpData + 13, sorter.lostTotal); // cumulative packets lost since start
Bit::htobl(rtcpData + 16, sorter.rtpSeq | (sorter.packTotal &
0xFFFF0000ul)); // highest sequence received
Bit::htobl(rtcpData + 20, 0); /// \TODO jitter (diff in timestamp vs packet arrival)
Bit::htobl(rtcpData + 24, 0); /// \TODO last SR (middle 32 bits of last SR or zero)
Bit::htobl(rtcpData + 28, 0); /// \TODO delay since last SR in 2b seconds + 2b fraction
callBack(userData, rtcpData, 32);
sorter.lostCurrent = 0;
sorter.packCurrent = 0;
free(rtcpData);
}
}

100
lib/rtp_fec.h Normal file
View file

@ -0,0 +1,100 @@
#pragma once
#include "rtp.h"
#include "sdp_media.h"
#include "util.h"
#include <set>
namespace RTP{
/// Util class that can be used to retrieve information from a
/// FEC packet. A FEC packet is contains recovery data. This
/// data can be used to reconstruct a media packet. This class
/// was created and tested for the WebRTC implementation where
/// each FEC packet is encapsulated by a RED packet (RFC 1298).
/// A RED packet may contain ordinary payload data -or- FEC
/// data (RFC 5109). We assume that the data given into
/// `initWithREDPacket()` contains FEC data and you did a bit
/// of parsing to figure this out: by checking if the `block
/// PT` from the RED header is the ULPFEC payload type; if so
/// this PacketFEC class can be used.
class PacketFEC{
public:
PacketFEC();
~PacketFEC();
bool initWithREDPacket(
const char *data,
size_t nbytes); /// Initialize using the given data. `data` must point to the first byte of
/// the RTP packet which contains the RED and FEC headers and data.
uint8_t getExtensionFlag(); ///< From fec header: should be 0, see
///< https://tools.ietf.org/html/rfc5109#section-7.3.
uint8_t
getLongMaskFlag(); ///< From fec header: returns 0 when the short mask version is used (16
///< bits), otherwise 1 (48 bits). The mask is used to calculate what
///< sequence numbers are protected, starting at the base sequence number.
uint16_t getSequenceBaseNumber(); ///< From fec header: get the base sequence number. The base
///< sequence number is used together with the mask to
///< determine what the sequence numbers of the media packets
///< are that the fec data protects.
uint8_t getNumBytesUsedForMask(); ///< From fec level header: a fec packet can protected up to
///< 48 media packets. Which sequence numbers are stored using
///< a mask bit string. This returns either 2 or 6.
char *getLevel0Header(); ///< Get a pointer to the start of the fec-level-0 header (contains the
///< protection-length and mask)
char *getLevel0Payload(); /// < Get a pointer to the actual FEC data. This is the XOR'd header
/// and paylaod.
char *getFECHeader(); ///< Get a pointer to the first byte of the FEC header.
uint16_t getLevel0ProtectionLength(); ///< Get the length of the `getLevel0Payload()`.
uint16_t
getLengthRecovery(); ///< Get the `length recovery` value (Little Endian). This value is used
///< while XORing to recover the length of the missing media packet.
bool coversSequenceNumber(uint16_t sn); ///< Returns true when this `PacketFEC` instance is used
///< to protect the given sequence number.
void
addReceivedSequenceNumber(uint16_t sn); ///< Whenever you receive a media packet (complete) call
///< this as we need to know if enough media packets
///< exist that are needed to recover another one.
void tryToRecoverMissingPacket(
std::map<uint16_t, RTP::Packet> &receivedMediaPackets,
Packet &reconstructedPacket); ///< Pass in a `std::map` indexed by sequence number of -all-
///< the media packets that you keep as history. When this
///< `PacketFEC` is capable of recovering a media packet it
///< will fill the packet passed by reference.
private:
bool
extractCoveringSequenceNumbers(); ///< Used internally to fill the `coveredSeqNums` member which
///< tell us what media packets this FEC packet rotects.
public:
Util::ResizeablePointer fecPacketData;
Util::ResizeablePointer recoverData;
std::set<uint16_t>
coveredSeqNums; ///< The sequence numbers of the packets that this FEC protects.
std::set<uint16_t>
receivedSeqNums; ///< We keep track of sequence numbers that were received (at some higher
///< level). We can only recover 1 media packet and this is used to check
///< if this `PacketFEC` instance is capable of recovering anything.
};
class FECSorter : public Sorter{
public:
void addPacket(const Packet &pack);
void addREDPacket(char *dat, unsigned int len, uint8_t codecPayloadType, uint8_t REDPayloadType,
uint8_t ULPFECPayloadType);
PacketFEC *getFECPacketWhichCoversSequenceNumber(uint16_t sn);
uint8_t tmpVideoLossPrevention; ///< TMP used to drop packets for FEC; see output_webrtc.cpp
///< `handleSignalingCommandRemoteOfferForInput()`. This
///< variable should be rmeoved when cleaning up.
private:
std::map<uint16_t, Packet> packetHistory;
std::vector<PacketFEC *> fecPackets;
};
class FECPacket : public Packet{
public:
void sendRTCP_RR(RTP::FECSorter &sorter, uint32_t mySSRC, uint32_t theirSSRC, void *userData,
void callBack(void *userData, const char *payload, uint32_t nbytes));
};
}// namespace RTP

1184
lib/sdp_media.cpp Normal file

File diff suppressed because it is too large Load diff

220
lib/sdp_media.h Normal file
View file

@ -0,0 +1,220 @@
#pragma once
#include <string>
#include <map>
#include <set>
#include <vector>
#include "dtsc.h"
#define SDP_PAYLOAD_TYPE_NONE 9999 /// Define that is used to indicate a payload type is not set.
#define SDP_LOSS_PREVENTION_NONE 0
#define SDP_LOSS_PREVENTION_NACK (1 << 1) /// Use simple NACK based loss prevention. (e.g. send a NACK to pusher of video stream when a packet is lost)
#define SDP_LOSS_PREVENTION_ULPFEC (1 << 2) /// Use FEC (See rtp.cpp, PacketRED). When used we try to add the correct `a=rtpmap` for RED and ULPFEC to the SDP when supported by the offer.
namespace SDP{
/// A MediaFormat stores the infomation that is specific for an
/// encoding. With RTSP there is often just one media format
/// per media line. Though with WebRTC, where an SDP is used to
/// determine a common capability, one media line can contain
/// different formats. These formats are indicated by with the
/// <fmt> attribute of the media line. For each <fmt> there may
/// be one or more custom properties. For each property, like
/// the `encodingName` (e.g. VP8, VP9, H264, etc). we create a
/// new `SDP::MediaFormat` object and store it in the `formats`
/// member of `SDP::Media`.
///
/// When you want to retrieve some specific data and there is a
/// getter function defined for it, you SHOULD use this
/// function as these functions add some extra logic based on
/// the set members.
class MediaFormat{
public:
MediaFormat();
std::string getFormatParameterForName(
const std::string &name) const; ///< Get a parameter which was part of the `a=fmtp:` line.
uint32_t getAudioSampleRate()
const; ///< Returns the audio sample rate. When `audioSampleRate` has been set this will be
///< returned, otherwise we use the `payloadType` to determine the samplerate or
///< return 0 when we fail to determine to samplerate.
uint32_t getAudioNumChannels()
const; ///< Returns the number of audio channels. When `audioNumChannels` has been set this
///< will be returned, otherwise we use the `payloadType` when it's set to determine
///< the samplerate or we return 0 when we can't determine the number of channels.
uint32_t getAudioBitSize()
const; ///< Returns the audio bitsize. When `audioBitSize` has been set this will be
///< returned, othwerise we use the `encodingName` to determine the right
///< `audioBitSize` or 0 when we can't determine the `audioBitSize`
uint32_t
getVideoRate() const; ///< Returns the video time base. When `videoRate` has been set this will
///< be returned, otherwise we use the `encodingName` to determine the
///< right value or 0 when we can't determine the video rate.
uint32_t getVideoOrAudioRate() const; ///< Returns whichever rate has been set.
uint64_t getPayloadType() const; ///< Returns the `payloadType` member.
int32_t
getPacketizationModeForH264(); ///< When this represents a h264 format this will return the
///< packetization mode when it was provided in the SDP
std::string getProfileLevelIdForH264(); ///< When this represents a H264 format, this will return the profile-level-id from the format parameters.
operator bool() const;
public:
uint64_t payloadType; ///< The payload type as set in the media line (the <fmt> is -the-
///< payloadType).
uint64_t associatedPayloadType; ///< From `a=fmtp:<pt> apt=<apt>`; maps this format to another payload type.
int32_t audioSampleRate; ///< Samplerate of the audio type.
int32_t audioNumChannels; ///< Number of audio channels extracted from the `a=fmtp` or set in
///< `setDefaultsForPayloadType()`.
int32_t audioBitSize; ///< Number of bits used in case this is an audio type 8, 16, set in
///< `setDefaultsForCodec()` and `setDefaultsForPayloadType()`.
int32_t videoRate; ///< Video framerate, e.g. 9000
std::string encodingName; ///< Stores the UPPERCASED encoding name from the `a=rtpmap:<payload
///< type> <encoding name>
std::string iceUFrag; ///< From `a=ice-ufrag:<ufrag>, used with WebRTC / STUN.
std::string icePwd; ///< From `a=ice-pwd:<pwd>`, used with WebRTC / STUN
std::string rtpmap; ///< The `a=<rtpmap:...> value; value between brackets.
std::map<std::string, std::string>
formatParameters; ///< Stores the var-val pairs from `a=fmtp:<fmt>` entry e.g. =
///< `packetization-mode=1;profile-level-id=4d0029;sprop-parameter-sets=Z00AKeKQCADDYC3AQEBpB4kRUA==,aO48gA==`
///< */
std::set<std::string>
rtcpFormats; ///< Stores the `fb-val` from the line with `a=rtcp-fb:<fmt> <fb-val>`.
};
class Media{
public:
Media();
bool parseMediaLine(const std::string &sdpLine); ///< Parses `m=` line. Creates a `MediaFormat`
///< entry for each of the found <fmt> values.
bool parseRtpMapLine(
const std::string &sdpLine); ///< Parses `a=rtpmap:` line which contains the some codec
///< specific info. When this line contains the samplerate and
///< number of audio channels they will be extracted.
bool parseRtspControlLine(const std::string &sdpLine); ///< Parses `a=control:`
bool parseFrameRateLine(const std::string &sdpLine); ///< Parses `a=framerate:`
bool parseFormatParametersLine(const std::string &sdpLine); ///< Parses `a=fmtp:<payload-type>`.
bool parseRtcpFeedbackLine(
const std::string &sdpLine); ///< Parses `a=rtcp-fb:<payload-type>`. See RFC4584
bool parseFingerprintLine(
const std::string
&sdpLine); ///< Parses `a=fingerprint:<hash-func> <value>`. See
///< https://tools.ietf.org/html/rfc8122#section-5, used with WebRTC
bool parseSSRCLine(const std::string &sdpLine); ///< Parses `a=ssrc:<ssrc>`.
MediaFormat *getFormatForSdpLine(
const std::string
&sdpLine); ///< Returns the track to which this SDP line applies. This means that the
///< SDP line should be formatteed like: `a=something:[payloadtype]`.
MediaFormat *getFormatForPayloadType(
uint64_t &payloadType); ///< Finds the `MediaFormat` in `formats`. Returns NULL when no
///< format was found for the given payload type. .
MediaFormat *getFormatForEncodingName(
const std::string
&encName); ///< Finds the `MediaFormats in `formats`. Returns NULL when no format was
///< found for the given encoding name. E.g. `VP8`, `VP9`, `H264`
std::vector<SDP::MediaFormat *> getFormatsForEncodingName(const std::string &encName);
std::string getIcePwdForFormat(
const MediaFormat
&fmt); ///< The `a=ice-pwd` can be session global or media specific. This function will
///< check if the `SDP::MediaFormat` has a ice-pwd that we should use.
uint32_t
getSSRC() const; ///< Returns the first SSRC `a=ssrc:<value>` value found for the media.
operator bool() const;
MediaFormat* getRetransMissionFormatForPayloadType(uint64_t pt); ///< When available, it resurns the RTX format that is directly associated with the media (not encapsulated with a RED header). RTX can be combined with FEC in which case it's supposed to be stored in RED packets. The `encName` should be something like H264,VP8; e.g. the format for which you want to get the RTX format.
public:
std::string type; ///< The `media` field of the media line: `m=<media> <port> <proto> <fmt>`,
///< like "video" or "audio"
std::string proto; ///< The `proto` field of the media line: `m=<media> <port> <proto> <fmt>`,
///< like "video" or "audio"
std::string control; ///< From `a=control:` The RTSP control url.
std::string direction; ///< From `a=sendonly`, `a=recvonly` and `a=sendrecv`
std::string iceUFrag; ///< From `a=ice-ufrag:<ufrag>, used with WebRTC / STUN.
std::string icePwd; ///< From `a=ice-pwd:<pwd>`, used with WebRTC / STUN
std::string setupMethod; ///< From `a=setup:<passive, active, actpass>, used with WebRTC / STUN
std::string fingerprintHash; ///< From `a=fingerprint:<hash> <value>`, e.g. sha-256, used with
///< WebRTC / STUN
std::string
fingerprintValue; ///< From `a=fingerprint:<hash> <value>`, the actual fingerprint, used
///< with WebRTC / STUN, see https://tools.ietf.org/html/rfc8122#section-5
std::string mediaID; ///< From `a=mid:<value>`. When generating an WebRTC answer this value must
///< be the same as in the offer.
std::string candidateIP; ///< Used when we generate a WebRTC answer.
uint16_t candidatePort; ///< Used when we generate a WebRTC answer.
uint32_t SSRC; ///< From `a=ssrc:<SSRC> <something>`; the first SSRC that we encountered.
double framerate; ///< From `a=framerate`.
bool supportsRTCPMux; ///< From `a=rtcp-mux`, indicates if it can mux RTP and RTCP on one
///< transport channel.
bool supportsRTCPReducedSize; ///< From `a=rtcp-rsize`, reduced size RTCP packets.
std::string
payloadTypes; ///< From `m=` line, all the payload types as string, separated by space.
std::map<uint64_t, MediaFormat>
formats; ///< Formats indexed by payload type. Payload type is the number in the <fmt>
///< field(s) from the `m=` line.
};
class Session{
public:
bool parseSDP(const std::string &sdp);
Media *getMediaForType(
const std::string &type); ///< Get a `SDP::Media*` for the given type, e.g. `video` or
///< `audio`. Returns NULL when the type was not found.
MediaFormat *getMediaFormatByEncodingName(const std::string &mediaType,
const std::string &encodingName);
bool hasReceiveOnlyMedia(); ///< Returns true when one of the media sections has a `a=recvonly`
///< attribute. This is used to determine if the other peer only
///< wants to receive or also sent data. */
public:
std::vector<SDP::Media> medias; ///< For each `m=` line we create a `SDP::Media` instance. The
///< stream specific infomration is stored in a `MediaFormat`
std::string icePwd; ///< From `a=ice-pwd`, this property can be session-wide or media specific.
///< Used with WebRTC and STUN when calculating the message-integrity.
std::string
iceUFrag; ///< From `a=ice-ufag`, this property can be session-wide or media specific. Used
///< with WebRTC and STUN when calculating the message-integrity.
};
class Answer{
public:
Answer();
bool parseOffer(const std::string &sdp);
bool hasVideo(); ///< Check if the offer has video.
bool hasAudio(); ///< Check if the offer has audio.
bool enableVideo(const std::string &codecName);
bool enableAudio(const std::string &codecName);
void setCandidate(const std::string &ip, uint16_t port);
void setFingerprint(const std::string &fingerprintSha); ///< Set the SHA265 that represents the
///< certificate that is used with DTLS.
void setDirection(const std::string &dir);
bool setupVideoDTSCTrack(DTSC::Track &result);
bool setupAudioDTSCTrack(DTSC::Track &result);
std::string toString();
private:
bool enableMedia(const std::string &type, const std::string &codecName, SDP::Media &outMedia,
SDP::MediaFormat &outFormat);
void addLine(const std::string &fmt, ...);
std::string generateSessionId();
std::string generateIceUFrag(); ///< Generates the `ice-ufrag` value.
std::string generateIcePwd(); ///< Generates the `ice-pwd` value.
std::string generateRandomString(const int len);
std::vector<std::string> splitString(const std::string &str, char delim);
public:
SDP::Session sdpOffer;
SDP::Media answerVideoMedia;
SDP::Media answerAudioMedia;
SDP::MediaFormat answerVideoFormat;
SDP::MediaFormat answerAudioFormat;
bool isAudioEnabled;
bool isVideoEnabled;
std::string candidateIP; ///< We use rtcp-mux and BUNDLE; so only one candidate necessary.
uint16_t candidatePort; ///< We use rtcp-mux and BUNDLE; so only one candidate necessary.
std::string fingerprint;
std::string direction; ///< The direction used when generating the answer SDP string.
std::vector<std::string> output; ///< The lines that are used when adding lines (see `addLine()`
///< for the answer sdp.).
uint8_t videoLossPrevention; ///< See the SDP_LOSS_PREVENTION_* values at the top of this header.
};
}

422
lib/srtp.cpp Normal file
View file

@ -0,0 +1,422 @@
#include <algorithm>
#include "defines.h"
#include "srtp.h"
/* --------------------------------------- */
static std::string srtp_status_to_string(uint32_t status);
/* --------------------------------------- */
SRTPReader::SRTPReader() {
memset((void*)&session, 0x00, sizeof(session));
memset((void*)&policy, 0x00, sizeof(policy));
}
/*
Before initializing the srtp library we shut it down first
because initializing the library twice results in an error.
*/
int SRTPReader::init(const std::string& cipher, const std::string& key, const std::string& salt) {
int r = 0;
srtp_err_status_t status = srtp_err_status_ok;
srtp_profile_t profile;
memset((void*) &profile, 0x00, sizeof(profile));
/* validate input */
if (cipher.empty()) {
FAIL_MSG("Given `cipher` is empty.");
r = -1;
goto error;
}
if (key.empty()) {
FAIL_MSG("Given `key` is empty.");
r = -2;
goto error;
}
if (salt.empty()) {
FAIL_MSG("Given `salt` is empty.");
r = -3;
goto error;
}
/* re-initialize the srtp library. */
status = srtp_shutdown();
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to shutdown the srtp lib %s", srtp_status_to_string(status).c_str());
r = -1;
goto error;
}
status = srtp_init();
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to initialize the SRTP library. %s", srtp_status_to_string(status).c_str());
r = -2;
goto error;
}
/* select the right profile from exchanged cipher */
if ("SRTP_AES128_CM_SHA1_80" == cipher) {
profile = srtp_profile_aes128_cm_sha1_80;
}
else if ("SRTP_AES128_CM_SHA1_32" == cipher) {
profile = srtp_profile_aes128_cm_sha1_32;
}
else {
ERROR_MSG("Unsupported SRTP cipher used: %s.", cipher.c_str());
r = -2;
goto error;
}
/* set the crypto policy using the profile. */
status = srtp_crypto_policy_set_from_profile_for_rtp(&policy.rtp, profile);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to set the crypto policy for RTP for cipher %s.", cipher.c_str());
r = -3;
goto error;
}
status = srtp_crypto_policy_set_from_profile_for_rtcp(&policy.rtcp, profile);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to set the crypto policy for RTCP for cipher %s.", cipher.c_str());
r = -4;
goto error;
}
/* set the keying material. */
std::copy(key.begin(), key.end(), std::back_inserter(key_salt));
std::copy(salt.begin(), salt.end(), std::back_inserter(key_salt));
policy.key = (unsigned char*)&key_salt[0];
/* only unprotecting data for now, so using inbound; and some other settings. */
policy.ssrc.type = ssrc_any_inbound;
policy.window_size = 1024;
policy.allow_repeat_tx = 1;
/* create the srtp session. */
status = srtp_create(&session, &policy);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to initialize our SRTP session. Status: %s. ", srtp_status_to_string(status).c_str());
r = -3;
goto error;
}
error:
if (r < 0) {
shutdown();
}
return r;
}
int SRTPReader::shutdown() {
int r = 0;
srtp_err_status_t status = srtp_dealloc(session);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to cleanly shutdown the SRTP session. Status: %s", srtp_status_to_string(status).c_str());
r -= 5;
}
memset((void*)&policy, 0x00, sizeof(policy));
memset((char*)&session, 0x00, sizeof(session));
return r;
}
/* --------------------------------------- */
int SRTPReader::unprotectRtp(uint8_t* data, int* nbytes) {
if (NULL == data) {
ERROR_MSG("Cannot unprotect the given SRTP, because data is NULL.");
return -1;
}
if (NULL == nbytes) {
ERROR_MSG("Cannot unprotect the given SRTP, becuase nbytes is NULL.");
return -2;
}
if (0 == (*nbytes)) {
ERROR_MSG("Cannot unprotect the given SRTP, because nbytes is 0.");
return -3;
}
if (NULL == policy.key) {
ERROR_MSG("Cannot unprotect the SRTP packet, it seems we're not initialized.");
return -4;
}
srtp_err_status_t status = srtp_unprotect(session, data, nbytes);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to unprotect the given SRTP. %s.", srtp_status_to_string(status).c_str());
return -5;
}
DONTEVEN_MSG("Unprotected SRTP into %d bytes.", *nbytes);
return 0;
}
int SRTPReader::unprotectRtcp(uint8_t* data, int* nbytes) {
if (NULL == data) {
ERROR_MSG("Cannot unprotect the given SRTCP, because data is NULL.");
return -1;
}
if (NULL == nbytes) {
ERROR_MSG("Cannot unprotect the given SRTCP, becuase nbytes is NULL.");
return -2;
}
if (0 == (*nbytes)) {
ERROR_MSG("Cannot unprotect the given SRTCP, because nbytes is 0.");
return -3;
}
if (NULL == policy.key) {
ERROR_MSG("Cannot unprotect the SRTCP packet, it seems we're not initialized.");
return -4;
}
srtp_err_status_t status = srtp_unprotect_rtcp(session, data, nbytes);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to unprotect the given SRTCP. %s.", srtp_status_to_string(status).c_str());
return -5;
}
return 0;
}
/* --------------------------------------- */
SRTPWriter::SRTPWriter() {
memset((void*)&session, 0x00, sizeof(session));
memset((void*)&policy, 0x00, sizeof(policy));
}
/*
Before initializing the srtp library we shut it down first
because initializing the library twice results in an error.
*/
int SRTPWriter::init(const std::string& cipher, const std::string& key, const std::string& salt) {
int r = 0;
srtp_err_status_t status = srtp_err_status_ok;
srtp_profile_t profile;
memset((void*)&profile, 0x00, sizeof(profile));
/* validate input */
if (cipher.empty()) {
FAIL_MSG("Given `cipher` is empty.");
r = -1;
goto error;
}
if (key.empty()) {
FAIL_MSG("Given `key` is empty.");
r = -2;
goto error;
}
if (salt.empty()) {
FAIL_MSG("Given `salt` is empty.");
r = -3;
goto error;
}
/* re-initialize the srtp library. */
status = srtp_shutdown();
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to shutdown the srtp lib %s", srtp_status_to_string(status).c_str());
r = -1;
goto error;
}
status = srtp_init();
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to initialize the SRTP library. %s", srtp_status_to_string(status).c_str());
r = -2;
goto error;
}
/* select the exchanged cipher */
if ("SRTP_AES128_CM_SHA1_80" == cipher) {
profile = srtp_profile_aes128_cm_sha1_80;
}
else if ("SRTP_AES128_CM_SHA1_32" == cipher) {
profile = srtp_profile_aes128_cm_sha1_32;
}
else {
ERROR_MSG("Unsupported SRTP cipher used: %s.", cipher.c_str());
r = -2;
goto error;
}
/* set the crypto policy using the profile. */
status = srtp_crypto_policy_set_from_profile_for_rtp(&policy.rtp, profile);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to set the crypto policy for RTP for cipher %s.", cipher.c_str());
r = -3;
goto error;
}
status = srtp_crypto_policy_set_from_profile_for_rtcp(&policy.rtcp, profile);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to set the crypto policy for RTCP for cipher %s.", cipher.c_str());
r = -4;
goto error;
}
/* set the keying material. */
std::copy(key.begin(), key.end(), std::back_inserter(key_salt));
std::copy(salt.begin(), salt.end(), std::back_inserter(key_salt));
policy.key = (unsigned char*)&key_salt[0];
/* only unprotecting data for now, so using inbound; and some other settings. */
policy.ssrc.type = ssrc_any_outbound;
policy.window_size = 128;
policy.allow_repeat_tx = 0;
/* create the srtp session. */
status = srtp_create(&session, &policy);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to initialize our SRTP session. Status: %s. ", srtp_status_to_string(status).c_str());
r = -3;
goto error;
}
error:
if (r < 0) {
shutdown();
}
return r;
}
int SRTPWriter::shutdown() {
int r = 0;
srtp_err_status_t status = srtp_dealloc(session);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to cleanly shutdown the SRTP session. Status: %s", srtp_status_to_string(status).c_str());
r -= 5;
}
memset((char*)&policy, 0x00, sizeof(policy));
memset((char*)&session, 0x00, sizeof(session));
return r;
}
/* --------------------------------------- */
int SRTPWriter::protectRtp(uint8_t* data, int* nbytes) {
if (NULL == data) {
ERROR_MSG("Cannot protect the RTP packet because given data is NULL.");
return -1;
}
if (NULL == nbytes) {
ERROR_MSG("Cannot protect the RTP packet because the given nbytes is NULL.");
return -2;
}
if ((*nbytes) <= 0) {
ERROR_MSG("Cannot protect the RTP packet because the given nbytes has a value <= 0.");
return -3;
}
if (NULL == policy.key) {
ERROR_MSG("Cannot protect the RTP packet because we're not initialized.");
return -4;
}
srtp_err_status_t status = srtp_protect(session, (void*)data, nbytes);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to protect the RTP packet. %s.", srtp_status_to_string(status).c_str());
return -5;
}
return 0;
}
/*
Make sure that `data` has `SRTP_MAX_TRAILER_LEN + 4` number
of bytes at the into which libsrtp can write the
authentication tag
*/
int SRTPWriter::protectRtcp(uint8_t* data, int* nbytes) {
if (NULL == data) {
ERROR_MSG("Cannot protect the RTCP packet because given data is NULL.");
return -1;
}
if (NULL == nbytes) {
ERROR_MSG("Cannot protect the RTCP packet because nbytes is NULL.");
return -2;
}
if ((*nbytes) <= 0) {
ERROR_MSG("Cannot protect the RTCP packet because *nbytes is <= 0.");
return -3;
}
if (NULL == policy.key) {
ERROR_MSG("Not initialized cannot protect the RTCP packet.");
return -4;
}
srtp_err_status_t status = srtp_protect_rtcp(session, (void*)data, nbytes);
if (srtp_err_status_ok != status) {
ERROR_MSG("Failed to protect the RTCP packet. %s.", srtp_status_to_string(status).c_str());
return -3;
}
return 0;
}
/* --------------------------------------- */
static std::string srtp_status_to_string(uint32_t status) {
switch (status) {
case srtp_err_status_ok: { return "srtp_err_status_ok"; }
case srtp_err_status_fail: { return "srtp_err_status_fail"; }
case srtp_err_status_bad_param: { return "srtp_err_status_bad_param"; }
case srtp_err_status_alloc_fail: { return "srtp_err_status_alloc_fail"; }
case srtp_err_status_dealloc_fail: { return "srtp_err_status_dealloc_fail"; }
case srtp_err_status_init_fail: { return "srtp_err_status_init_fail"; }
case srtp_err_status_terminus: { return "srtp_err_status_terminus"; }
case srtp_err_status_auth_fail: { return "srtp_err_status_auth_fail"; }
case srtp_err_status_cipher_fail: { return "srtp_err_status_cipher_fail"; }
case srtp_err_status_replay_fail: { return "srtp_err_status_replay_fail"; }
case srtp_err_status_replay_old: { return "srtp_err_status_replay_old"; }
case srtp_err_status_algo_fail: { return "srtp_err_status_algo_fail"; }
case srtp_err_status_no_such_op: { return "srtp_err_status_no_such_op"; }
case srtp_err_status_no_ctx: { return "srtp_err_status_no_ctx"; }
case srtp_err_status_cant_check: { return "srtp_err_status_cant_check"; }
case srtp_err_status_key_expired: { return "srtp_err_status_key_expired"; }
case srtp_err_status_socket_err: { return "srtp_err_status_socket_err"; }
case srtp_err_status_signal_err: { return "srtp_err_status_signal_err"; }
case srtp_err_status_nonce_bad: { return "srtp_err_status_nonce_bad"; }
case srtp_err_status_read_fail: { return "srtp_err_status_read_fail"; }
case srtp_err_status_write_fail: { return "srtp_err_status_write_fail"; }
case srtp_err_status_parse_err: { return "srtp_err_status_parse_err"; }
case srtp_err_status_encode_err: { return "srtp_err_status_encode_err"; }
case srtp_err_status_semaphore_err: { return "srtp_err_status_semaphore_err"; }
case srtp_err_status_pfkey_err: { return "srtp_err_status_pfkey_err"; }
case srtp_err_status_bad_mki: { return "srtp_err_status_bad_mki"; }
case srtp_err_status_pkt_idx_old: { return "srtp_err_status_pkt_idx_old"; }
case srtp_err_status_pkt_idx_adv: { return "srtp_err_status_pkt_idx_adv"; }
default: { return "UNKNOWN"; }
}
}
/* --------------------------------------- */

43
lib/srtp.h Normal file
View file

@ -0,0 +1,43 @@
#pragma once
#include <stdint.h>
#include <string>
#include <srtp2/srtp.h>
#define SRTP_PARSER_MASTER_KEY_LEN 16
#define SRTP_PARSER_MASTER_SALT_LEN 14
#define SRTP_PARSER_MASTER_LEN (SRTP_PARSER_MASTER_KEY_LEN + SRTP_PARSER_MASTER_SALT_LEN)
/* --------------------------------------- */
class SRTPReader {
public:
SRTPReader();
int init(const std::string& cipher, const std::string& key, const std::string& salt);
int shutdown();
int unprotectRtp(uint8_t* data, int* nbytes); /* `nbytes` should contain the number of bytes in `data`. On success `nbytes` will hold the number of bytes of the decoded RTP packet. */
int unprotectRtcp(uint8_t* data, int* nbytes); /* `nbytes` should contains the number of bytes in `data`. On success `nbytes` will hold the number of bytes the decoded RTCP packet. */
private:
srtp_t session;
srtp_policy_t policy;
std::vector<uint8_t> key_salt; /* Combination of key + salt which is used to unprotect the SRTP/SRTCP data. */
};
/* --------------------------------------- */
class SRTPWriter {
public:
SRTPWriter();
int init(const std::string& cipher, const std::string& key, const std::string& salt);
int shutdown();
int protectRtp(uint8_t* data, int* nbytes);
int protectRtcp(uint8_t* data, int* nbytes);
private:
srtp_t session;
srtp_policy_t policy;
std::vector<uint8_t> key_salt; /* Combination of key + salt which is used to protect the SRTP/SRTCP data. */
};
/* --------------------------------------- */

1051
lib/stun.cpp Normal file

File diff suppressed because it is too large Load diff

250
lib/stun.h Normal file
View file

@ -0,0 +1,250 @@
#pragma once
#include <stdint.h>
#include <netinet/in.h>
#include <vector>
#include <string>
/* --------------------------------------- */
#define STUN_IP4 0x01
#define STUN_IP6 0x02
#define STUN_MSG_TYPE_NONE 0x0000
#define STUN_MSG_TYPE_BINDING_REQUEST 0x0001
#define STUN_MSG_TYPE_BINDING_RESPONSE_SUCCESS 0x0101
#define STUN_MSG_TYPE_BINDING_RESPONSE_ERROR 0x0111
#define STUN_MSG_TYPE_BINDING_INDICATION 0x0011
#define STUN_ATTR_TYPE_NONE 0x0000
#define STUN_ATTR_TYPE_MAPPED_ADDR 0x0001
#define STUN_ATTR_TYPE_CHANGE_REQ 0x0003
#define STUN_ATTR_TYPE_USERNAME 0x0006
#define STUN_ATTR_TYPE_MESSAGE_INTEGRITY 0x0008
#define STUN_ATTR_TYPE_ERR_CODE 0x0009
#define STUN_ATTR_TYPE_UNKNOWN_ATTRIBUTES 0x000a
#define STUN_ATTR_TYPE_CHANNEL_NUMBER 0x000c
#define STUN_ATTR_TYPE_LIFETIME 0x000d
#define STUN_ATTR_TYPE_XOR_PEER_ADDR 0x0012
#define STUN_ATTR_TYPE_DATA 0x0013
#define STUN_ATTR_TYPE_REALM 0x0014
#define STUN_ATTR_TYPE_NONCE 0x0015
#define STUN_ATTR_TYPE_XOR_RELAY_ADDRESS 0x0016
#define STUN_ATTR_TYPE_REQ_ADDRESS_FAMILY 0x0017
#define STUN_ATTR_TYPE_EVEN_PORT 0x0018
#define STUN_ATTR_TYPE_REQUESTED_TRANSPORT 0x0019
#define STUN_ATTR_TYPE_DONT_FRAGMENT 0x001a
#define STUN_ATTR_TYPE_XOR_MAPPED_ADDRESS 0x0020
#define STUN_ATTR_TYPE_RESERVATION_TOKEN 0x0022
#define STUN_ATTR_TYPE_PRIORITY 0x0024
#define STUN_ATTR_TYPE_USE_CANDIDATE 0x0025
#define STUN_ATTR_TYPE_PADDING 0x0026
#define STUN_ATTR_TYPE_RESPONSE_PORT 0x0027
#define STUN_ATTR_TYPE_SOFTWARE 0x8022
#define STUN_ATTR_TYPE_ALTERNATE_SERVER 0x8023
#define STUN_ATTR_TYPE_FINGERPRINT 0x8028
#define STUN_ATTR_TYPE_ICE_CONTROLLED 0x8029
#define STUN_ATTR_TYPE_ICE_CONTROLLING 0x802a
#define STUN_ATTR_TYPE_RESPONSE_ORIGIN 0x802b
#define STUN_ATTR_TYPE_OTHER_ADDRESS 0x802c
/* --------------------------------------- */
std::string stun_message_type_to_string(uint16_t type);
std::string stun_attribute_type_to_string(uint16_t type);
std::string stun_family_type_to_string(uint8_t type);
/*
Compute the hmac-sha1 over message.
uint8_t* message: the data over which we compute the hmac sha
uint32_t nbytes: the number of bytse in message
std::string key: key to use for hmac
uint8_t* output: we write the sha1 into this buffer.
*/
int stun_compute_hmac_sha1(uint8_t* message, uint32_t nbytes, std::string key, uint8_t* output);
/*
Compute the Message-Integrity of a stun message.
This will not change the given buffer.
std::vector<uint8_t>& buffer: the buffer that contains a valid stun message
std::string key: key to use for hmac
uint8_t* output: will be filled with the correct hmac-sha1 of that represents the integrity message value.
*/
int stun_compute_message_integrity(std::vector<uint8_t>& buffer, std::string key, uint8_t* output);
/*
Compute the fingerprint value for the stun message.
This will not change the given buffer.
std::vector<uint8_t>& buffer: the buffer that contains a valid stun message.
uint32_t& result: will be set to the calculated crc value.
*/
int stun_compute_fingerprint(std::vector<uint8_t>& buffer, uint32_t& result);
/* --------------------------------------- */
/* https://tools.ietf.org/html/rfc5389#section-15.10 */
class StunAttribSoftware {
public:
char* value;
};
class StunAttribFingerprint {
public:
uint32_t value;
};
/* https://tools.ietf.org/html/rfc5389#section-15.4 */
class StunAttribMessageIntegrity {
public:
uint8_t* sha1;
};
/* https://tools.ietf.org/html/rfc5245#section-19.1 */
class StunAttribPriority {
public:
uint32_t value;
};
/* https://tools.ietf.org/html/rfc5245#section-19.1 */
class StunAttribIceControllling {
public:
uint64_t tie_breaker;
};
/* https://tools.ietf.org/html/rfc3489#section-11.2.6 */
class StunAttribUsername {
public:
char* value; /* Must use `length` member of attribute that indicates the number of valid bytes in the username. */
};
/* https://tools.ietf.org/html/rfc5389#section-15.2 */
class StunAttribXorMappedAddress {
public:
uint8_t family;
uint16_t port;
uint8_t ip[16];
};
/* --------------------------------------- */
class StunAttribute {
public:
StunAttribute();
void print();
public:
uint16_t type;
uint16_t length;
union {
StunAttribXorMappedAddress xor_address;
StunAttribUsername username;
StunAttribIceControllling ice_controlling;
StunAttribPriority priority;
StunAttribSoftware software;
StunAttribMessageIntegrity message_integrity;
StunAttribFingerprint fingerprint;
};
};
/* --------------------------------------- */
class StunMessage {
public:
StunMessage();
void setType(uint16_t type);
void setTransactionId(uint32_t a, uint32_t b, uint32_t c);
void print();
void addAttribute(StunAttribute& attr);
void removeAttributes();
StunAttribute* getAttributeByType(uint16_t type);
public:
uint16_t type;
uint16_t length;
uint32_t cookie;
uint32_t transaction_id[3];
std::vector<StunAttribute> attributes;
};
/* --------------------------------------- */
class StunReader {
public:
StunReader();
int parse(uint8_t* data, size_t nbytes, size_t& nparsed, StunMessage& msg); /* `nparsed` and `msg` are filled. */
private:
int parseXorMappedAddress(StunAttribute& attr);
int parseUsername(StunAttribute& attr);
int parseIceControlling(StunAttribute& attr);
int parsePriority(StunAttribute& attr);
int parseSoftware(StunAttribute& attr);
int parseMessageIntegrity(StunAttribute& attr);
int parseFingerprint(StunAttribute& attr);
uint8_t readU8();
uint16_t readU16();
uint32_t readU32();
uint64_t readU64();
private:
uint8_t* buffer_data;
size_t buffer_size;
size_t read_dx;
};
/* --------------------------------------- */
class StunWriter {
public:
StunWriter();
/* write header and finalize. call for each stun message */
int begin(StunMessage& msg, uint8_t paddingByte = 0x00); /* I've added the padding byte here so that we can use the examples that can be found here https://tools.ietf.org/html/rfc5769#section-2.2 as they use 0x20 or 0x00 as the padding byte which is correct as you are free to use w/e padding byte you want. */
int end();
/* write attributes */
int writeXorMappedAddress(sockaddr_in addr);
int writeXorMappedAddress(uint8_t family, uint16_t port, uint32_t ip);
int writeXorMappedAddress(uint8_t family, uint16_t port, const std::string& ip);
int writeUsername(const std::string& username);
int writeSoftware(const std::string& software);
int writeMessageIntegrity(const std::string& password); /* When using WebRtc this is the ice-upwd of the other agent. */
int writeFingerprint(); /* Must be the last attribute in the message. When adding a fingerprint, make sure that it is added after the message-integrity (when you also use a message-integrity). */
/* get buffer */
uint8_t* getBufferPtr();
size_t getBufferSize();
private:
void writeU8(uint8_t v);
void writeU16(uint16_t v);
void writeU32(uint32_t v);
void writeU64(uint64_t v);
void rewriteU16(size_t dx, uint16_t v);
void rewriteU32(size_t dx, uint32_t v);
void writeString(const std::string& str);
void writePadding();
int convertIp4StringToInt(const std::string& ip, uint32_t& result);
private:
std::vector<uint8_t> buffer;
uint8_t padding_byte;
};
/* --------------------------------------- */
inline uint8_t* StunWriter::getBufferPtr() {
if (0 == buffer.size()) {
return NULL;
}
return &buffer[0];
}
inline size_t StunWriter::getBufferSize() {
return buffer.size();
}
/* --------------------------------------- */