420 lines
12 KiB
C++
420 lines
12 KiB
C++
#include <algorithm>
|
|
#include "defines.h"
|
|
#include "dtls_srtp_handshake.h"
|
|
|
|
/* Set LOG_TO_FILE to 1 to additionally append the mbedtls debug output to "mbedtls.log" (see print_mbedtls_debug_message). */
|
|
#define LOG_TO_FILE 0
|
|
#if LOG_TO_FILE
|
|
# include <fstream>
|
|
#endif
|
|
|
|
/* ----------------------------------------- */
|
|
|
|
/* Logs the human readable description of an mbedtls error code. */
static void print_mbedtls_error(int code);

/* Debug callback handed to mbedtls via mbedtls_ssl_conf_dbg(). */
static void print_mbedtls_debug_message(void* ctx, int level, const char* file, int line, const char* msg);

/* BIO receive callback: mbedtls wants to read bytes (served from the handshake's internal buffer). */
static int on_mbedtls_wants_to_read(void* user, unsigned char* buf, size_t len);

/* BIO send callback: mbedtls wants to send bytes (forwarded to the user supplied write callback). */
static int on_mbedtls_wants_to_write(void* user, const unsigned char* buf, size_t len);

/* Maps an mbedtls return code to a printable name. */
static std::string mbedtls_err_to_string(int code);
|
|
|
|
/* ----------------------------------------- */
|
|
|
|
/*
  Starts with no write callback and no certificate/key. Every mbedtls
  context is zero-filled here; init() initializes them and shutdown()
  frees them.
*/
DTLSSRTPHandshake::DTLSSRTPHandshake()
  : write_callback(NULL)
  , cert(NULL)
  , key(NULL)
{
  /* random number generation state */
  memset(&entropy_ctx, 0, sizeof(entropy_ctx));
  memset(&rand_ctx, 0, sizeof(rand_ctx));

  /* TLS session + its configuration */
  memset(&ssl_ctx, 0, sizeof(ssl_ctx));
  memset(&ssl_conf, 0, sizeof(ssl_conf));

  /* DTLS specific helpers */
  memset(&cookie_ctx, 0, sizeof(cookie_ctx));
  memset(&timer_ctx, 0, sizeof(timer_ctx));
}
|
|
|
|
/*
  Configures mbedtls as a DTLS server with the DTLS-SRTP extension enabled.

  certificate   - our certificate (chain); stored as a borrowed pointer,
                  it is not copied and must outlive this object.
  privateKey    - private key belonging to `certificate`; also borrowed.
  writeCallback - invoked (through on_mbedtls_wants_to_write) whenever
                  mbedtls wants to send handshake data.

  Returns 0 on success. On failure shutdown() has already been called and
  one of the following codes is returned:
    -3  no write callback        -40 setting SRTP profiles failed
    -5  no certificate           -50 setting own cert/key failed
    -10 no private key           -60 cookie setup failed
    -20 seeding the rng failed   -70 mbedtls_ssl_setup failed
    -30 config defaults failed   -80 setting the transport id failed
*/
int DTLSSRTPHandshake::init(mbedtls_x509_crt* certificate,
                            mbedtls_pk_context* privateKey,
                            int(*writeCallback)(const uint8_t* data, int* nbytes)
                            )
{
  int r = 0;

  /* SRTP protection profiles we offer, in order of preference.
     Static storage duration on purpose: the upstream mbedtls DTLS-SRTP API
     documents that it records the pointer to this table inside `ssl_conf`
     instead of copying it, so a stack array would dangle once init()
     returns. The table is never modified. */
  static mbedtls_ssl_srtp_profile srtp_profiles[] = {
    MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80,
    MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32
  };

  if (!writeCallback) {
    FAIL_MSG("No writeCallack function given.");
    r = -3;
    goto error;
  }

  if (!certificate) {
    FAIL_MSG("Given certificate is null.");
    r = -5;
    goto error;
  }

  if (!privateKey) {
    FAIL_MSG("Given key is null.");
    r = -10;
    goto error;
  }

  cert = certificate;
  key = privateKey;

  /* init the contexts */
  mbedtls_entropy_init(&entropy_ctx);
  mbedtls_ctr_drbg_init(&rand_ctx);
  mbedtls_ssl_init(&ssl_ctx);
  mbedtls_ssl_config_init(&ssl_conf);
  mbedtls_ssl_cookie_init(&cookie_ctx);

  /* seed and setup the random number generator ("mist-srtp" is the personalization string) */
  r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx, (const unsigned char*)"mist-srtp", 9);
  if (0 != r) {
    print_mbedtls_error(r);
    r = -20;
    goto error;
  }

  /* load defaults into our ssl_conf: we are the DTLS (datagram) server */
  r = mbedtls_ssl_config_defaults(&ssl_conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM, MBEDTLS_SSL_PRESET_DEFAULT);
  if (0 != r) {
    print_mbedtls_error(r);
    r = -30;
    goto error;
  }

  /* the peer certificate is not verified (MBEDTLS_SSL_VERIFY_NONE) */
  mbedtls_ssl_conf_authmode(&ssl_conf, MBEDTLS_SSL_VERIFY_NONE);
  mbedtls_ssl_conf_rng(&ssl_conf, mbedtls_ctr_drbg_random, &rand_ctx);
  mbedtls_ssl_conf_dbg(&ssl_conf, print_mbedtls_debug_message, stdout);
  mbedtls_ssl_conf_ca_chain(&ssl_conf, cert, NULL);
  mbedtls_ssl_conf_cert_profile(&ssl_conf, &mbedtls_x509_crt_profile_default);
  /* takes no context argument, i.e. this threshold applies to all of mbedtls in the process */
  mbedtls_debug_set_threshold(10);

  /* enable SRTP */
  r = mbedtls_ssl_conf_dtls_srtp_protection_profiles(&ssl_conf, srtp_profiles, sizeof(srtp_profiles) / sizeof(srtp_profiles[0]));
  if (0 != r) {
    print_mbedtls_error(r);
    r = -40;
    goto error;
  }

  /* register our certificate chain + private key as the server's own credentials */
  r = mbedtls_ssl_conf_own_cert(&ssl_conf, cert, key);
  if (0 != r) {
    print_mbedtls_error(r);
    r = -50;
    goto error;
  }

  /* cookie setup (e.g. to prevent ddos). */
  r = mbedtls_ssl_cookie_setup(&cookie_ctx, mbedtls_ctr_drbg_random, &rand_ctx);
  if (0 != r) {
    print_mbedtls_error(r);
    r = -60;
    goto error;
  }

  /* register callbacks for dtls cookies (server only). */
  mbedtls_ssl_conf_dtls_cookies(&ssl_conf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &cookie_ctx);

  /* setup the ssl context for use. note that ssl_conf will be referenced internally by the context and therefore should be kept around. */
  r = mbedtls_ssl_setup(&ssl_ctx, &ssl_conf);
  if (0 != r) {
    print_mbedtls_error(r);
    r = -70;
    goto error;
  }

  /* set bio handlers: all I/O goes through `buffer` (read) and `write_callback` (write) */
  mbedtls_ssl_set_bio(&ssl_ctx, (void*)this, on_mbedtls_wants_to_write, on_mbedtls_wants_to_read, NULL);

  /* set temp id, just adds some extra randomness */
  {
    std::string remote_id = "mist";
    r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char*)remote_id.c_str(), remote_id.size());
    if (0 != r) {
      print_mbedtls_error(r);
      r = -80;
      goto error;
    }
  }

  /* set timer callbacks (DTLS retransmission timing) */
  mbedtls_ssl_set_timer_cb(&ssl_ctx, &timer_ctx, mbedtls_timing_set_delay, mbedtls_timing_get_delay);

  write_callback = writeCallback;

error:

  if (r < 0) {
    shutdown();
  }

  return r;
}
|
|
|
|
/*
  Releases the mbedtls state set up by init().

  The certificate and key pointers are forgotten but NOT freed, because
  they are owned by the user of this class. Pending input, the negotiated
  cipher name and the SRTP keys/salts are cleared. Always returns 0.
*/
int DTLSSRTPHandshake::shutdown() {

  /* forget the borrowed credentials */
  key = NULL;
  cert = NULL;

  /* wipe everything derived from (or waiting for) the handshake */
  remote_key.clear();
  local_key.clear();
  remote_salt.clear();
  local_salt.clear();
  cipher.clear();
  buffer.clear();

  /* release the mbedtls contexts; the ssl context references ssl_conf, so it is freed first */
  mbedtls_entropy_free(&entropy_ctx);
  mbedtls_ctr_drbg_free(&rand_ctx);
  mbedtls_ssl_free(&ssl_ctx);
  mbedtls_ssl_config_free(&ssl_conf);
  mbedtls_ssl_cookie_free(&cookie_ctx);

  return 0;
}
|
|
|
|
/* ----------------------------------------- */
|
|
|
|
int DTLSSRTPHandshake::parse(const uint8_t* data, size_t nbytes) {
|
|
|
|
if (NULL == data) {
|
|
ERROR_MSG("Given `data` is NULL.");
|
|
return -1;
|
|
}
|
|
|
|
if (0 == nbytes) {
|
|
ERROR_MSG("Given nbytes is 0.");
|
|
return -2;
|
|
}
|
|
|
|
if (MBEDTLS_SSL_HANDSHAKE_OVER == ssl_ctx.state) {
|
|
ERROR_MSG("Already finished the handshake.");
|
|
return -3;
|
|
}
|
|
|
|
/* copy incoming data into a temporary buffer which is read via our `bio` read function. */
|
|
int r = 0;
|
|
std::copy(data, data + nbytes, std::back_inserter(buffer));
|
|
|
|
do {
|
|
|
|
r = mbedtls_ssl_handshake(&ssl_ctx);
|
|
|
|
switch (r) {
|
|
/* 0 = handshake done. */
|
|
case 0: {
|
|
if (0 != extractKeyingMaterial()) {
|
|
ERROR_MSG("Failed to extract keying material after handshake was done.");
|
|
return -2;
|
|
}
|
|
return 0;
|
|
}
|
|
/* see the dtls server example; this is used to prevent certain attacks (ddos) */
|
|
case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: {
|
|
if (0 != resetSession()) {
|
|
ERROR_MSG("Failed to reset the session which is necessary when we need to verify the HELLO.");
|
|
return -3;
|
|
}
|
|
break;
|
|
}
|
|
case MBEDTLS_ERR_SSL_WANT_READ: {
|
|
DONTEVEN_MSG("mbedtls wants a bit more data before it can continue parsing the DTLS handshake.");
|
|
break;
|
|
}
|
|
default: {
|
|
ERROR_MSG("A serious mbedtls error occured.");
|
|
print_mbedtls_error(r);
|
|
return -2;
|
|
}
|
|
}
|
|
}
|
|
while (MBEDTLS_ERR_SSL_WANT_WRITE == r);
|
|
|
|
return 0;
|
|
}
|
|
|
|
/* ----------------------------------------- */
|
|
|
|
int DTLSSRTPHandshake::resetSession() {
|
|
|
|
std::string remote_id = "mist"; /* @todo for now we hardcoded this... */
|
|
int r = 0;
|
|
|
|
r = mbedtls_ssl_session_reset(&ssl_ctx);
|
|
if (0 != r) {
|
|
print_mbedtls_error(r);
|
|
return -1;
|
|
}
|
|
|
|
r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char*)remote_id.c_str(), remote_id.size());
|
|
if (0 != r) {
|
|
print_mbedtls_error(r);
|
|
return -2;
|
|
}
|
|
|
|
buffer.clear();
|
|
|
|
return 0;
|
|
}
|
|
|
|
/*
|
|
master key is 128 bits => 16 bytes.
|
|
master salt is 112 bits => 14 bytes
|
|
*/
|
|
int DTLSSRTPHandshake::extractKeyingMaterial() {
|
|
|
|
int r = 0;
|
|
uint8_t keying_material[MBEDTLS_DTLS_SRTP_MAX_KEY_MATERIAL_LENGTH] = {};
|
|
size_t keying_material_len = sizeof(keying_material);
|
|
|
|
r = mbedtls_ssl_get_dtls_srtp_key_material(&ssl_ctx, keying_material, &keying_material_len);
|
|
if (0 != r) {
|
|
print_mbedtls_error(r);
|
|
return -1;
|
|
}
|
|
|
|
/* @todo following code is for server mode only */
|
|
mbedtls_ssl_srtp_profile srtp_profile = mbedtls_ssl_get_dtls_srtp_protection_profile(&ssl_ctx);
|
|
switch (srtp_profile) {
|
|
case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80: {
|
|
cipher = "SRTP_AES128_CM_SHA1_80";
|
|
break;
|
|
}
|
|
case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32: {
|
|
cipher = "SRTP_AES128_CM_SHA1_32";
|
|
break;
|
|
}
|
|
default: {
|
|
ERROR_MSG("Unhandled SRTP profile, cannot extract keying material.");
|
|
return -6;
|
|
}
|
|
}
|
|
|
|
remote_key.assign((char*)(&keying_material[0]) + 0, 16);
|
|
local_key.assign((char*)(&keying_material[0]) + 16, 16);
|
|
remote_salt.assign((char*)(&keying_material[0]) + 32, 14);
|
|
local_salt.assign((char*)(&keying_material[0]) + 46, 14);
|
|
|
|
DONTEVEN_MSG("Extracted the DTLS SRTP keying material with cipher %s.", cipher.c_str());
|
|
DONTEVEN_MSG("Remote DTLS SRTP key size is %zu.", remote_key.size());
|
|
DONTEVEN_MSG("Remote DTLS SRTP salt size is %zu.", remote_salt.size());
|
|
DONTEVEN_MSG("Local DTLS SRTP key size is %zu.", local_key.size());
|
|
DONTEVEN_MSG("Local DTLS SRTP salt size is %zu.", local_salt.size());
|
|
|
|
return 0;
|
|
}
|
|
|
|
/* ----------------------------------------- */
|
|
|
|
/*
|
|
|
|
This function is called by mbedtls whenever it wants to read
|
|
some data. The documentation states the following: "For DTLS,
|
|
you need to provide either a non-NULL f_recv_timeout
|
|
callback, or a f_recv that doesn't block." As this
|
|
implementation is completely decoupled from any I/O and uses
|
|
a "push" model instead of a "pull" model we have to copy new
|
|
input bytes into a temporary buffer (see parse), but we act
|
|
as if we were using a non-blocking socket, which means:
|
|
|
|
- we return MBETLS_ERR_SSL_WANT_READ when there is no data left to read
|
|
- when there is data in our temporary buffer, we read from that
|
|
|
|
*/
|
|
static int on_mbedtls_wants_to_read(void* user, unsigned char* buf, size_t len) {
|
|
|
|
DTLSSRTPHandshake* hs = static_cast<DTLSSRTPHandshake*>(user);
|
|
if (NULL == hs) {
|
|
ERROR_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake.");
|
|
return -1;
|
|
}
|
|
|
|
/* figure out how much we can read. */
|
|
if (hs->buffer.size() == 0) {
|
|
return MBEDTLS_ERR_SSL_WANT_READ;
|
|
}
|
|
|
|
size_t nbytes = hs->buffer.size();
|
|
if (nbytes > len) {
|
|
nbytes = len;
|
|
}
|
|
|
|
/* "read" into the given buffer. */
|
|
memcpy(buf, &hs->buffer[0], nbytes);
|
|
hs->buffer.erase(hs->buffer.begin(), hs->buffer.begin() + nbytes);
|
|
|
|
return (int)nbytes;
|
|
}
|
|
|
|
static int on_mbedtls_wants_to_write(void* user, const unsigned char* buf, size_t len) {
|
|
|
|
DTLSSRTPHandshake* hs = static_cast<DTLSSRTPHandshake*>(user);
|
|
if (!hs) {
|
|
FAIL_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake.");
|
|
return -1;
|
|
}
|
|
|
|
if (!hs->write_callback) {
|
|
FAIL_MSG("The `write_callback` member is NULL.");
|
|
return -2;
|
|
}
|
|
|
|
int nwritten = (int)len;
|
|
if (0 != hs->write_callback(buf, &nwritten)) {
|
|
FAIL_MSG("Failed to write some DTLS handshake data.");
|
|
return -3;
|
|
}
|
|
|
|
if (nwritten != (int)len) {
|
|
FAIL_MSG("The DTLS-SRTP handshake listener MUST write all the data.");
|
|
return -4;
|
|
}
|
|
|
|
return nwritten;
|
|
}
|
|
|
|
/* ----------------------------------------- */
|
|
|
|
static void print_mbedtls_error(int r) {
|
|
char buf[1024] = {};
|
|
mbedtls_strerror(r, buf, sizeof(buf));
|
|
ERROR_MSG("mbedtls error: %s", buf);
|
|
}
|
|
|
|
/*
  Debug callback registered with mbedtls_ssl_conf_dbg() in init().

  `ctx` and `level` are unused. Debug lines are expected to end with '\n'
  (the optional file log below writes them verbatim); that trailing newline
  is stripped before the line is passed to DONTEVEN_MSG. When LOG_TO_FILE
  is enabled, the unmodified line is also appended to "mbedtls.log".
*/
static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str) {

  (void)ctx;
  (void)level;

  if (NULL == str) {
    return;
  }

  /* `%.*s` takes an `int` precision. The previous `strlen(str) - 1` passed a
     size_t through varargs (undefined behaviour), underflowed for an empty
     string, and chopped the last character even when it was not a newline. */
  int len = (int)strlen(str);
  if (len > 0 && '\n' == str[len - 1]) {
    --len;
  }
  DONTEVEN_MSG("%s:%04d: %.*s", file, line, len, str);

#if LOG_TO_FILE
  static std::ofstream ofs;
  if (!ofs.is_open()) {
    ofs.open("mbedtls.log", std::ios::out);
  }
  if (!ofs.is_open()) {
    return;
  }
  ofs << str;
  ofs.flush();
#endif
}
|
|
|
|
static std::string mbedtls_err_to_string(int r) {
|
|
switch (r) {
|
|
case MBEDTLS_ERR_SSL_WANT_READ: { return "MBEDTLS_ERR_SSL_WANT_READ"; }
|
|
case MBEDTLS_ERR_SSL_WANT_WRITE: { return "MBEDTLS_ERR_SSL_WANT_WRITE"; }
|
|
default: {
|
|
print_mbedtls_error(r);
|
|
return "UNKNOWN";
|
|
}
|
|
}
|
|
}
|
|
|
|
/* ---------------------------------------- */
|
|
|
|
|