#include "defines.h"
#include "dtls_srtp_handshake.h"
#include <algorithm>
#include <iterator>
#include <string>
#include <string.h>

/* Write mbedtls into a log file. */
#define LOG_TO_FILE 0
#if LOG_TO_FILE
#include <fstream>
#endif

/* ----------------------------------------- */

static void print_mbedtls_error(int r);
static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str);
static int on_mbedtls_wants_to_read(void *user, unsigned char *buf, size_t len); /* Called when mbedtls wants to read data from e.g. a socket. */
static int on_mbedtls_wants_to_write(void *user, const unsigned char *buf, size_t len); /* Called when mbedtls wants to write data to e.g. a socket. */

/* ----------------------------------------- */

/*
  Zero all mbedtls contexts up front. init() calls shutdown() on every
  failure path, including failures that happen before the mbedtls_*_init()
  calls, so shutdown() must always see a well-defined (zeroed) state.
*/
DTLSSRTPHandshake::DTLSSRTPHandshake() : cert(NULL), key(NULL), write_callback(NULL){
  memset((void *)&entropy_ctx, 0x00, sizeof(entropy_ctx));
  memset((void *)&rand_ctx, 0x00, sizeof(rand_ctx));
  memset((void *)&ssl_ctx, 0x00, sizeof(ssl_ctx));
  memset((void *)&ssl_conf, 0x00, sizeof(ssl_conf));
  memset((void *)&cookie_ctx, 0x00, sizeof(cookie_ctx));
  memset((void *)&timer_ctx, 0x00, sizeof(timer_ctx));
}

/*
  Configure this object as a DTLS *server* that negotiates SRTP
  (AES128_CM_HMAC_SHA1_80 or AES128_CM_HMAC_SHA1_32).

  @param certificate   Certificate presented to the peer; not owned by this class.
  @param privateKey    Private key for `certificate`; not owned by this class.
  @param writeCallback Receives outgoing handshake bytes (via the mbedtls BIO
                       write hook). It gets the byte count in `*nbytes`, must
                       return 0 on success and must write all bytes.
  @return 0 on success, or a negative step-specific code on failure. In that
          case shutdown() has already been called.
*/
int DTLSSRTPHandshake::init(mbedtls_x509_crt *certificate, mbedtls_pk_context *privateKey,
                            int (*writeCallback)(const uint8_t *data, int *nbytes)){
  int r = 0;
  mbedtls_ssl_srtp_profile srtp_profiles[] ={MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80,
                                             MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32};

  if (!writeCallback){
    FAIL_MSG("No writeCallack function given.");
    r = -3;
    goto error;
  }
  if (!certificate){
    FAIL_MSG("Given certificate is null.");
    r = -5;
    goto error;
  }
  if (!privateKey){
    FAIL_MSG("Given key is null.");
    r = -10;
    goto error;
  }

  cert = certificate;
  key = privateKey;

  /* init the contexts */
  mbedtls_entropy_init(&entropy_ctx);
  mbedtls_ctr_drbg_init(&rand_ctx);
  mbedtls_ssl_init(&ssl_ctx);
  mbedtls_ssl_config_init(&ssl_conf);
  mbedtls_ssl_cookie_init(&cookie_ctx);

  /* seed and setup the random number generator */
  r = mbedtls_ctr_drbg_seed(&rand_ctx, mbedtls_entropy_func, &entropy_ctx,
                            (const unsigned char *)"mist-srtp", 9);
  if (0 != r){
    print_mbedtls_error(r);
    r = -20;
    goto error;
  }

  /* load defaults into our ssl_conf */
  r = mbedtls_ssl_config_defaults(&ssl_conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM,
                                  MBEDTLS_SSL_PRESET_DEFAULT);
  if (0 != r){
    print_mbedtls_error(r);
    r = -30;
    goto error;
  }

  mbedtls_ssl_conf_authmode(&ssl_conf, MBEDTLS_SSL_VERIFY_NONE);
  mbedtls_ssl_conf_rng(&ssl_conf, mbedtls_ctr_drbg_random, &rand_ctx);
  mbedtls_ssl_conf_dbg(&ssl_conf, print_mbedtls_debug_message, stdout);
  mbedtls_ssl_conf_ca_chain(&ssl_conf, cert, NULL);
  mbedtls_ssl_conf_cert_profile(&ssl_conf, &mbedtls_x509_crt_profile_default);
  mbedtls_debug_set_threshold(10);

  /* enable SRTP */
  r = mbedtls_ssl_conf_dtls_srtp_protection_profiles(&ssl_conf, srtp_profiles,
                                                     sizeof(srtp_profiles) / sizeof(srtp_profiles[0]));
  if (0 != r){
    print_mbedtls_error(r);
    r = -40;
    goto error;
  }

  /* set our certificate chain + key, so we can verify the client-hello signed data */
  r = mbedtls_ssl_conf_own_cert(&ssl_conf, cert, key);
  if (0 != r){
    print_mbedtls_error(r);
    r = -50;
    goto error;
  }

  /* cookie setup (e.g. to prevent ddos). */
  r = mbedtls_ssl_cookie_setup(&cookie_ctx, mbedtls_ctr_drbg_random, &rand_ctx);
  if (0 != r){
    print_mbedtls_error(r);
    r = -60;
    goto error;
  }

  /* register callbacks for dtls cookies (server only). */
  mbedtls_ssl_conf_dtls_cookies(&ssl_conf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &cookie_ctx);

  /* setup the ssl context for use. note that ssl_conf will be referenced
     internally by the context and therefore should be kept around. */
  r = mbedtls_ssl_setup(&ssl_ctx, &ssl_conf);
  if (0 != r){
    print_mbedtls_error(r);
    r = -70;
    goto error;
  }

  /* set bio handlers; `this` is handed back to the hooks as their user pointer. */
  mbedtls_ssl_set_bio(&ssl_ctx, (void *)this, on_mbedtls_wants_to_write, on_mbedtls_wants_to_read, NULL);

  /* set temp id, just adds some extra randomness. The extra scope keeps the
     `goto error` statements above from jumping over the initialization of `remote_id`. */
  {
    std::string remote_id = "mist";
    r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char *)remote_id.c_str(), remote_id.size());
    if (0 != r){
      print_mbedtls_error(r);
      r = -80;
      goto error;
    }
  }

  /* set timer callbacks */
  mbedtls_ssl_set_timer_cb(&ssl_ctx, &timer_ctx, mbedtls_timing_set_delay, mbedtls_timing_get_delay);

  write_callback = writeCallback;

error:
  if (r < 0){shutdown();}

  return r;
}

/*
  Release everything set up by init(). Safe to call after a failed init()
  because the constructor zeroes all contexts.

  @return always 0.
*/
int DTLSSRTPHandshake::shutdown(){

  /* cleanup the refs from the settings. */
  cert = NULL;
  key = NULL;
  write_callback = NULL;
  buffer.clear();
  cipher.clear();
  remote_key.clear();
  remote_salt.clear();
  local_key.clear();
  local_salt.clear();

  /* free our contexts; we do not free `cert` and `key` as they are owned by
     the user of this class. */
  mbedtls_entropy_free(&entropy_ctx);
  mbedtls_ctr_drbg_free(&rand_ctx);
  mbedtls_ssl_free(&ssl_ctx);
  mbedtls_ssl_config_free(&ssl_conf);
  mbedtls_ssl_cookie_free(&cookie_ctx);

  return 0;
}

/* ----------------------------------------- */

/*
  Feed one chunk of received DTLS data into the handshake.

  The bytes are appended to `buffer`, which the BIO read hook drains while
  mbedtls_ssl_handshake() runs. Outgoing data is emitted through the write
  callback given to init().

  @return 0 when the data was consumed. That covers three cases: mbedtls
          wants more input, a HelloVerify reset happened, or the handshake
          finished and the SRTP keys were extracted. Otherwise:
          -1 for NULL data, -2 for an empty input, a fatal mbedtls error or
          failed key extraction, and -3 when the handshake was already over
          or the session reset failed.
*/
int DTLSSRTPHandshake::parse(const uint8_t *data, size_t nbytes){

  if (NULL == data){
    ERROR_MSG("Given `data` is NULL.");
    return -1;
  }

  if (0 == nbytes){
    ERROR_MSG("Given nbytes is 0.");
    return -2;
  }

  if (MBEDTLS_SSL_HANDSHAKE_OVER == ssl_ctx.state){
    ERROR_MSG("Already finished the handshake.");
    return -3;
  }

  /* copy incoming data into a temporary buffer which is read via our `bio` read function. */
  int r = 0;
  std::copy(data, data + nbytes, std::back_inserter(buffer));

  do{

    r = mbedtls_ssl_handshake(&ssl_ctx);

    switch (r){
      /* 0 = handshake done. */
      case 0:{
        if (0 != extractKeyingMaterial()){
          ERROR_MSG("Failed to extract keying material after handshake was done.");
          return -2;
        }
        return 0;
      }
      /* see the dtls server example; this is used to prevent certain attacks (ddos) */
      case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:{
        if (0 != resetSession()){
          ERROR_MSG(
              "Failed to reset the session which is necessary when we need to verify the HELLO.");
          return -3;
        }
        break;
      }
      case MBEDTLS_ERR_SSL_WANT_READ:{
        DONTEVEN_MSG(
            "mbedtls wants a bit more data before it can continue parsing the DTLS handshake.");
        break;
      }
      /* Not an error: the write side wants to be driven again. The loop
         condition below re-runs mbedtls_ssl_handshake() for this case, so it
         must not fall into `default`. */
      case MBEDTLS_ERR_SSL_WANT_WRITE:{
        break;
      }
      default:{
        ERROR_MSG("A serious mbedtls error occured.");
        print_mbedtls_error(r);
        return -2;
      }
    }

  }while (MBEDTLS_ERR_SSL_WANT_WRITE == r);

  return 0;
}

/* ----------------------------------------- */

/*
  Called from parse() when mbedtls reports MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED.
  Resets the SSL session, re-applies the client transport id and drops any
  buffered input so the next (cookie-carrying) ClientHello starts clean.

  @return 0 on success, -1 if the session reset failed, -2 if setting the
          transport id failed.
*/
int DTLSSRTPHandshake::resetSession(){

  std::string remote_id = "mist"; /* @todo for now we hardcoded this... */
  int r = 0;

  r = mbedtls_ssl_session_reset(&ssl_ctx);
  if (0 != r){
    print_mbedtls_error(r);
    return -1;
  }

  r = mbedtls_ssl_set_client_transport_id(&ssl_ctx, (const unsigned char *)remote_id.c_str(), remote_id.size());
  if (0 != r){
    print_mbedtls_error(r);
    return -2;
  }

  buffer.clear();

  return 0;
}

/* master key is 128 bits => 16 bytes.
master salt is 112 bits => 14 bytes */ int DTLSSRTPHandshake::extractKeyingMaterial(){ int r = 0; uint8_t keying_material[MBEDTLS_DTLS_SRTP_MAX_KEY_MATERIAL_LENGTH] ={}; size_t keying_material_len = sizeof(keying_material); r = mbedtls_ssl_get_dtls_srtp_key_material(&ssl_ctx, keying_material, &keying_material_len); if (0 != r){ print_mbedtls_error(r); return -1; } /* @todo following code is for server mode only */ mbedtls_ssl_srtp_profile srtp_profile = mbedtls_ssl_get_dtls_srtp_protection_profile(&ssl_ctx); switch (srtp_profile){ case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_80:{ cipher = "SRTP_AES128_CM_SHA1_80"; break; } case MBEDTLS_SRTP_AES128_CM_HMAC_SHA1_32:{ cipher = "SRTP_AES128_CM_SHA1_32"; break; } default:{ ERROR_MSG("Unhandled SRTP profile, cannot extract keying material."); return -6; } } remote_key.assign((char *)(&keying_material[0]) + 0, 16); local_key.assign((char *)(&keying_material[0]) + 16, 16); remote_salt.assign((char *)(&keying_material[0]) + 32, 14); local_salt.assign((char *)(&keying_material[0]) + 46, 14); DONTEVEN_MSG("Extracted the DTLS SRTP keying material with cipher %s.", cipher.c_str()); DONTEVEN_MSG("Remote DTLS SRTP key size is %zu.", remote_key.size()); DONTEVEN_MSG("Remote DTLS SRTP salt size is %zu.", remote_salt.size()); DONTEVEN_MSG("Local DTLS SRTP key size is %zu.", local_key.size()); DONTEVEN_MSG("Local DTLS SRTP salt size is %zu.", local_salt.size()); return 0; } /* ----------------------------------------- */ /* This function is called by mbedtls whenever it wants to read some data. The documentation states the following: "For DTLS, you need to provide either a non-NULL f_recv_timeout callback, or a f_recv that doesn't block." 
As this implementation is completely decoupled from any I/O and uses a "push" model instead of a "pull" model we have to copy new input bytes into a temporary buffer (see parse), but we act as if we were using a non-blocking socket, which means: - we return MBETLS_ERR_SSL_WANT_READ when there is no data left to read - when there is data in our temporary buffer, we read from that */ static int on_mbedtls_wants_to_read(void *user, unsigned char *buf, size_t len){ DTLSSRTPHandshake *hs = static_cast(user); if (NULL == hs){ ERROR_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake."); return -1; } /* figure out how much we can read. */ if (hs->buffer.size() == 0){return MBEDTLS_ERR_SSL_WANT_READ;} size_t nbytes = hs->buffer.size(); if (nbytes > len){nbytes = len;} /* "read" into the given buffer. */ memcpy(buf, &hs->buffer[0], nbytes); hs->buffer.erase(hs->buffer.begin(), hs->buffer.begin() + nbytes); return (int)nbytes; } static int on_mbedtls_wants_to_write(void *user, const unsigned char *buf, size_t len){ DTLSSRTPHandshake *hs = static_cast(user); if (!hs){ FAIL_MSG("Failed to cast the user pointer into a DTLSSRTPHandshake."); return -1; } if (!hs->write_callback){ FAIL_MSG("The `write_callback` member is NULL."); return -2; } int nwritten = (int)len; if (0 != hs->write_callback(buf, &nwritten)){ FAIL_MSG("Failed to write some DTLS handshake data."); return -3; } if (nwritten != (int)len){ FAIL_MSG("The DTLS-SRTP handshake listener MUST write all the data."); return -4; } return nwritten; } /* ----------------------------------------- */ static void print_mbedtls_error(int r){ char buf[1024] ={}; mbedtls_strerror(r, buf, sizeof(buf)); ERROR_MSG("mbedtls error: %s", buf); } static void print_mbedtls_debug_message(void *ctx, int level, const char *file, int line, const char *str){ DONTEVEN_MSG("%s:%04d: %.*s", file, line, (int)strlen(str) - 1, str); #if LOG_TO_FILE static std::ofstream ofs; if (!ofs.is_open()){ofs.open("mbedtls.log", std::ios::out);} if 
(!ofs.is_open()){return;} ofs << str; ofs.flush(); #endif } /* ---------------------------------------- */