From e6489f2d6abbe08ce95aa5e75228677fb11cd728 Mon Sep 17 00:00:00 2001 From: Thulinma Date: Wed, 28 Oct 2020 16:15:56 +0100 Subject: [PATCH] WebRTC certificate improvement --- lib/certificate.cpp | 36 ++++++++++++--------- lib/certificate.h | 7 ++-- src/output/output_webrtc.cpp | 62 +++++++++++++++++++++++++++++++++--- src/output/output_webrtc.h | 2 +- 4 files changed, 83 insertions(+), 24 deletions(-) diff --git a/lib/certificate.cpp b/lib/certificate.cpp index a2c4d9d1..b504d55a 100644 --- a/lib/certificate.cpp +++ b/lib/certificate.cpp @@ -3,9 +3,9 @@ #include -Certificate::Certificate() : rsa_ctx(NULL){ - memset((void *)&cert, 0x00, sizeof(cert)); - memset((void *)&key, 0x00, sizeof(key)); +Certificate::Certificate(){ + mbedtls_pk_init(&key); + mbedtls_x509_crt_init(&cert); } int Certificate::init(const std::string &countryName, const std::string &organization, @@ -14,6 +14,7 @@ int Certificate::init(const std::string &countryName, const std::string &organiz mbedtls_ctr_drbg_context rand_ctx ={}; mbedtls_entropy_context entropy_ctx ={}; mbedtls_x509write_cert write_cert ={}; + mbedtls_rsa_context *rsa_ctx; const char *personalisation = "mbedtls-self-signed-key"; std::string subject_name = "C=" + countryName + ",O=" + organization + ",CN=" + commonName; @@ -58,7 +59,6 @@ int Certificate::init(const std::string &countryName, const std::string &organiz } // initialize the public key context - mbedtls_pk_init(&key); r = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); if (0 != r){ FAIL_MSG("Faild to initialize the PK context."); @@ -66,6 +66,8 @@ int Certificate::init(const std::string &countryName, const std::string &organiz goto error; } + //This call returns a reference to the existing RSA context inside the key. + //Hence, it does not need to be cleaned up later. rsa_ctx = mbedtls_pk_rsa(key); if (NULL == rsa_ctx){ FAIL_MSG("Failed to get the RSA context from from the public key context (key)."); @@ -181,8 +183,6 @@ int Certificate::init(const std::string &countryName, const std::string &organiz // char. @todo there must be a way to convert the write // struct into a `mbedtls_x509_cert` w/o calling this parse // function. - mbedtls_x509_crt_init(&cert); - r = mbedtls_x509_crt_parse(&cert, (const unsigned char *)buf, strlen((char *)buf) + 1); if (0 != r){ mbedtls_strerror(r, (char *)buf, sizeof(buf)); @@ -198,24 +198,31 @@ error: mbedtls_entropy_free(&entropy_ctx); mbedtls_x509write_crt_free(&write_cert); mbedtls_mpi_free(&serial_mpi); - - if (r < 0){shutdown();} - return r; } -int Certificate::shutdown(){ - rsa_ctx = NULL; +Certificate::~Certificate(){ mbedtls_pk_free(&key); mbedtls_x509_crt_free(&cert); - return 0; } -std::string Certificate::getFingerprintSha256(){ +/// Loads a single file into the certificate. Returns true on success. +bool Certificate::loadCert(const std::string & certFile){ + if (!certFile.size()){return true;} + return mbedtls_x509_crt_parse_file(&cert, certFile.c_str()) == 0; +} +/// Loads a single key. Returns true on success. +bool Certificate::loadKey(const std::string & keyFile){ + if (!keyFile.size()){return true;} + return mbedtls_pk_parse_keyfile(&key, keyFile.c_str(), 0) == 0; +} + +/// Calculates SHA256 fingerprint over the loaded certificate(s) +/// Returns the fingerprint as hex-string. +std::string Certificate::getFingerprintSha256() const{ uint8_t fingerprint_raw[32] ={}; uint8_t fingerprint_hex[128] ={}; - mbedtls_sha256(cert.raw.p, cert.raw.len, fingerprint_raw, 0); for (int i = 0; i < 32; ++i){ @@ -223,7 +230,6 @@ std::string Certificate::getFingerprintSha256(){ } fingerprint_hex[32 * 3] = '\0'; - std::string result = std::string((char *)fingerprint_hex + 1, (32 * 3) - 1); return result; } diff --git a/lib/certificate.h b/lib/certificate.h index 0b717810..a8ccf747 100644 --- a/lib/certificate.h +++ b/lib/certificate.h @@ -23,12 +23,13 @@ class Certificate{ public: Certificate(); + bool loadCert(const std::string & certFile); + bool loadKey(const std::string & certFile); int init(const std::string &countryName, const std::string &organization, const std::string &commonName); - int shutdown(); - std::string getFingerprintSha256(); + ~Certificate(); + std::string getFingerprintSha256() const; public: mbedtls_x509_crt cert; mbedtls_pk_context key; /* key context, stores private and public key. */ - mbedtls_rsa_context *rsa_ctx; /* rsa context, stored in key_ctx */ }; diff --git a/src/output/output_webrtc.cpp b/src/output/output_webrtc.cpp index 6e1d2846..2dbd4de1 100644 --- a/src/output/output_webrtc.cpp +++ b/src/output/output_webrtc.cpp @@ -5,6 +5,7 @@ #include #include #include // ifaddr, listing ip addresses. +#include namespace Mist{ @@ -79,16 +80,55 @@ namespace Mist{ volkswagenMode = false; syncedNTPClock = false; - if (cert.init("NL", "webrtc", "webrtc") != 0){ - onFail("Failed to create the certificate.", true); - return; + + JSON::Value & certOpt = config->getOption("cert", true); + JSON::Value & keyOpt = config->getOption("key", true); + + //Attempt to read certificate config from other connectors + if (certOpt.size() < 2 || keyOpt.size() < 2){ + Util::DTSCShmReader rProto(SHM_PROTO); + DTSC::Scan prtcls = rProto.getScan(); + unsigned int pro_cnt = prtcls.getSize(); + for (unsigned int i = 0; i < pro_cnt; ++i){ + if (prtcls.getIndice(i).hasMember("key") && prtcls.getIndice(i).hasMember("cert")){ + std::string conn = prtcls.getIndice(i).getMember("connector").asString(); + INFO_MSG("No cert/key configured for WebRTC explicitly, reading from %s connector config", conn.c_str()); + JSON::Value newCert = prtcls.getIndice(i).getMember("cert").asJSON(); + certOpt.shrink(0); + jsonForEach(newCert, k){certOpt.append(*k);} + keyOpt.shrink(0); + keyOpt.append(prtcls.getIndice(i).getMember("key").asJSON()); + break; + } + } } + + if (certOpt.size() < 2 || keyOpt.size() < 2){ + if (cert.init("NL", "webrtc", "webrtc") != 0){ + onFail("Failed to create the certificate.", true); + return; + } + }else{ + // Read certificate chain(s) + jsonForEach(certOpt, it){ + if (!cert.loadCert(it->asStringRef())){ + WARN_MSG("Could not load any certificates from file: %s", it->asStringRef().c_str()); + } + } + + // Read key + if (!cert.loadKey(config->getString("key"))){ + FAIL_MSG("Could not load any keys from file: %s", config->getString("key").c_str()); + return; + } + } + if (dtlsHandshake.init(&cert.cert, &cert.key, onDTLSHandshakeWantsToWriteCallback) != 0){ onFail("Failed to initialize the dtls-srtp handshake helper.", true); return; } - sdpAnswer.setFingerprint(cert.getFingerprintSha256()); + classPointer = this; setBlocking(false); @@ -107,7 +147,6 @@ namespace Mist{ if (dtlsHandshake.shutdown() != 0){ FAIL_MSG("Failed to cleanly shutdown the dtls handshake."); } - if (cert.shutdown() != 0){FAIL_MSG("Failed to cleanly shutdown the certificate.");} } // Initialize the WebRTC output. This is where we define what @@ -218,6 +257,19 @@ namespace Mist{ capa["optional"]["losttimeoutmobile"]["type"] = "uint"; capa["optional"]["losttimeoutmobile"]["default"] = 90; + capa["optional"]["cert"]["name"] = "Certificate"; + capa["optional"]["cert"]["help"] = "(Root) certificate(s) file(s) to append to chain"; + capa["optional"]["cert"]["option"] = "--cert"; + capa["optional"]["cert"]["short"] = "C"; + capa["optional"]["cert"]["default"] = ""; + capa["optional"]["cert"]["type"] = "str"; + capa["optional"]["key"]["name"] = "Key"; + capa["optional"]["key"]["help"] = "Private key for SSL"; + capa["optional"]["key"]["option"] = "--key"; + capa["optional"]["key"]["short"] = "K"; + capa["optional"]["key"]["default"] = ""; + capa["optional"]["key"]["type"] = "str"; + config->addOptionsFromCapabilities(capa); } diff --git a/src/output/output_webrtc.h b/src/output/output_webrtc.h index 687da87e..d7c5f054 100644 --- a/src/output/output_webrtc.h +++ b/src/output/output_webrtc.h @@ -140,7 +140,7 @@ namespace Mist{ int onDTLSHandshakeWantsToWrite(const uint8_t *data, int *nbytes); void onRTPSorterHasPacket(size_t tid, const RTP::Packet &pkt); void onDTSCConverterHasPacket(const DTSC::Packet &pkt); - void onDTSCConverterHasInitData(const uint64_t trackID, const std::string &initData); + void onDTSCConverterHasInitData(const size_t trackID, const std::string &initData); void onRTPPacketizerHasRTPPacket(const char *data, size_t nbytes); void onRTPPacketizerHasRTCPPacket(const char *data, uint32_t nbytes);