diff --git a/.gitignore b/.gitignore index ab582394..c454f45e 100644 --- a/.gitignore +++ b/.gitignore @@ -59,4 +59,5 @@ rules.ninja .ninja_deps aes_ctr128 /embed/testing +*test diff --git a/CMakeLists.txt b/CMakeLists.txt index 3855c5c6..3db56499 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,9 @@ endif() if (DEFINED BIGMETA ) add_definitions(-DBIGMETA=1) endif() +if (NOT DEFINED NOSSL ) + add_definitions(-DSSL=1) +endif() ######################################## # Build Variables - Thread Names # @@ -177,6 +180,9 @@ target_link_libraries(mist -lpthread ${LIBRT} ) +if (NOT DEFINED NOSSL ) + target_link_libraries(mist mbedtls mbedx509 mbedcrypto) +endif() install( FILES ${libHeaders} DESTINATION include/mist diff --git a/lib/downloader.cpp b/lib/downloader.cpp index 2bebdeea..6f76736d 100644 --- a/lib/downloader.cpp +++ b/lib/downloader.cpp @@ -4,6 +4,19 @@ namespace HTTP{ + Downloader::Downloader(){ + progressCallback = 0; + connectedPort = 0; + ssl = false; + proxied = false; + char *p = getenv("http_proxy"); + if (p){ + proxyUrl = HTTP::URL(p); + proxied = true; + INFO_MSG("Proxying through %s", proxyUrl.getUrl().c_str()); + } + } + /// Returns a reference to the internal HTTP::Parser body element std::string &Downloader::data(){return H.body;} @@ -39,58 +52,98 @@ namespace HTTP{ Parser &Downloader::getHTTP(){return H;} /// Returns a reference to the internal Socket::Connection class instance. - Socket::Connection &Downloader::getSocket(){return S;} + Socket::Connection &Downloader::getSocket(){ +#ifdef SSL + if (ssl){return S_SSL;} +#endif + return S; + } /// Sends a request for the given URL, does no waiting. - void Downloader::doRequest(const HTTP::URL &link){ - if (link.protocol != "http"){ - FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); - return; - } + void Downloader::doRequest(const HTTP::URL &link, const std::string &method, const std::string &body){ + if (!canRequest(link)){return;} + bool needSSL = (link.protocol == "https"); INFO_MSG("Retrieving %s", link.getUrl().c_str()); H.Clean(); // Reconnect if needed - if (!S || link.host != connectedHost || link.getPort() != connectedPort){ - S.close(); - connectedHost = link.host; - connectedPort = link.getPort(); - S = Socket::Connection(connectedHost, connectedPort, true); - } - H.url = "/" + link.path; - if (link.args.size()){H.url += "?" + link.args;} - if (link.port.size()){ - H.SetHeader("Host", link.host + ":" + link.port); + if (!proxied || needSSL){ + if (!getSocket() || link.host != connectedHost || link.getPort() != connectedPort || + needSSL != ssl){ + getSocket().close(); + connectedHost = link.host; + connectedPort = link.getPort(); +#ifdef SSL + if (needSSL){ + S_SSL = Socket::SSLConnection(connectedHost, connectedPort, true); + }else{ + S = Socket::Connection(connectedHost, connectedPort, true); + } +#else + S = Socket::Connection(connectedHost, connectedPort, true); +#endif + } }else{ - H.SetHeader("Host", link.host); + if (!getSocket() || proxyUrl.host != connectedHost || proxyUrl.getPort() != connectedPort || + needSSL != ssl){ + getSocket().close(); + connectedHost = proxyUrl.host; + connectedPort = proxyUrl.getPort(); + S = Socket::Connection(connectedHost, connectedPort, true); + } + } + ssl = needSSL; + if (!getSocket()){ + return; // socket is closed + } + if (proxied && !ssl){ + H.url = link.getProxyUrl(); + if (proxyUrl.port.size()){ + H.SetHeader("Host", proxyUrl.host + ":" + proxyUrl.port); + }else{ + H.SetHeader("Host", proxyUrl.host); + } + }else{ + H.url = "/" + link.path; + if (link.args.size()){H.url += "?" + link.args;} + if (link.port.size()){ + H.SetHeader("Host", link.host + ":" + link.port); + }else{ + H.SetHeader("Host", link.host); + } + } + if (method.size()){ + H.method = method; } H.SetHeader("User-Agent", "MistServer " PACKAGE_VERSION); H.SetHeader("X-Version", PACKAGE_VERSION); H.SetHeader("Accept", "*/*"); + if (authStr.size() && (link.user.size() || link.pass.size())){ + H.auth(link.user, link.pass, authStr); + } + if (proxyAuthStr.size() && (proxyUrl.user.size() || proxyUrl.pass.size())){ + H.auth(proxyUrl.user, proxyUrl.pass, proxyAuthStr, "Proxy-Authorization"); + } if (extraHeaders.size()){ for (std::map::iterator it = extraHeaders.begin(); it != extraHeaders.end(); ++it){ H.SetHeader(it->first, it->second); } } - H.SendRequest(S); + H.SendRequest(getSocket(), body); H.Clean(); } /// Downloads the given URL into 'H', returns true on success. /// Makes at most 5 attempts, and will wait no longer than 5 seconds without receiving data. bool Downloader::get(const HTTP::URL &link, uint8_t maxRecursiveDepth){ - if (!link.host.size()){return false;} - if (link.protocol != "http"){ - FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); - return false; - } + if (!canRequest(link)){return false;} unsigned int loop = 6; // max 5 attempts while (--loop){// loop while we are unsuccessful doRequest(link); uint64_t reqTime = Util::bootSecs(); - while (S && Util::bootSecs() < reqTime + 5){ + while (getSocket() && Util::bootSecs() < reqTime + 5){ // No data? Wait for a second or so. - if (!S.spool()){ + if (!getSocket().spool()){ if (progressCallback != 0){ if (!progressCallback()){ WARN_MSG("Download aborted by callback"); @@ -101,16 +154,17 @@ namespace HTTP{ continue; } // Data! Check if we can parse it... - if (H.Read(S)){ - if (getStatusCode() >= 300 && getStatusCode() < 400){ - // follow redirect - std::string location = getHeader("Location"); + if (H.Read(getSocket())){ + if (shouldContinue()){ if (maxRecursiveDepth == 0){ - FAIL_MSG("Maximum redirect depth reached: %s", location.c_str()); + FAIL_MSG("Maximum recursion depth reached"); return false; + } + if (!canContinue(link)){return false;} + if (getStatusCode() >= 300 && getStatusCode() < 400){ + return get(link.link(getHeader("Location")), --maxRecursiveDepth); }else{ - FAIL_MSG("Following redirect to %s", location.c_str()); - return get(link.link(location), maxRecursiveDepth--); + return get(link, --maxRecursiveDepth); } } return true; // Success! @@ -118,7 +172,56 @@ namespace HTTP{ // reset the 5 second timeout reqTime = Util::bootSecs(); } - if (S){ + if (getSocket()){ + FAIL_MSG("Timeout while retrieving %s", link.getUrl().c_str()); + return false; + } + Util::sleep(500); // wait a bit before retrying + } + FAIL_MSG("Could not retrieve %s", link.getUrl().c_str()); + return false; + } + + bool Downloader::post(const HTTP::URL &link, const std::string &payload, bool sync, uint8_t maxRecursiveDepth){ + if (!canRequest(link)){return false;} + unsigned int loop = 6; // max 5 attempts + while (--loop){// loop while we are unsuccessful + doRequest(link, "POST", payload); + //Not synced? Ignore the response and immediately return false. + if (!sync){return false;} + uint64_t reqTime = Util::bootSecs(); + while (getSocket() && Util::bootSecs() < reqTime + 5){ + // No data? Wait for a second or so. + if (!getSocket().spool()){ + if (progressCallback != 0){ + if (!progressCallback()){ + WARN_MSG("Download aborted by callback"); + return false; + } + } + Util::sleep(250); + continue; + } + // Data! Check if we can parse it... + if (H.Read(getSocket())){ + if (shouldContinue()){ + if (maxRecursiveDepth == 0){ + FAIL_MSG("Maximum recursion depth reached"); + return false; + } + if (!canContinue(link)){return false;} + if (getStatusCode() >= 300 && getStatusCode() < 400){ + return post(link.link(getHeader("Location")), payload, sync, --maxRecursiveDepth); + }else{ + return post(link, payload, sync, --maxRecursiveDepth); + } + } + return true; // Success! + } + // reset the 5 second timeout + reqTime = Util::bootSecs(); + } + if (getSocket()){ FAIL_MSG("Timeout while retrieving %s", link.getUrl().c_str()); return false; } @@ -127,5 +230,75 @@ namespace HTTP{ FAIL_MSG("Could not retrieve %s", link.getUrl().c_str()); return false; } -} + + bool Downloader::canRequest(const HTTP::URL &link){ + if (!link.host.size()){return false;} + if (link.protocol != "http" && link.protocol != "https"){ + FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); + return false; + } +#ifndef SSL + if (link.protocol == "https"){ + FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); + return false; + } +#endif + return true; + } + + bool Downloader::shouldContinue(){ + if (H.hasHeader("Set-Cookie")){ + std::string cookie = H.GetHeader("Set-Cookie"); + setHeader("Cookie", cookie.substr(0, cookie.find(';'))); + } + uint32_t sCode = getStatusCode(); + if (sCode == 401 || sCode == 407 || (sCode >= 300 && sCode < 400)){ + return true; + } + return false; + } + + bool Downloader::canContinue(const HTTP::URL &link){ + if (getStatusCode() == 401){ + // retry with authentication + if (H.hasHeader("WWW-Authenticate")){authStr = H.GetHeader("WWW-Authenticate");} + if (H.hasHeader("Www-Authenticate")){authStr = H.GetHeader("Www-Authenticate");} + if (!authStr.size()){ + FAIL_MSG("Authentication required but no WWW-Authenticate header present"); + return false; + } + if (!link.user.size() && !link.pass.size()){ + FAIL_MSG("Authentication required but not included in URL"); + return false; + } + FAIL_MSG("Authenticating..."); + return true; + } + if (getStatusCode() == 407){ + // retry with authentication + if (H.hasHeader("Proxy-Authenticate")){ + proxyAuthStr = H.GetHeader("Proxy-Authenticate"); + } + if (!proxyAuthStr.size()){ + FAIL_MSG("Proxy authentication required but no Proxy-Authenticate header present"); + return false; + } + if (!proxyUrl.user.size() && !proxyUrl.pass.size()){ + FAIL_MSG("Proxy authentication required but not included in URL"); + return false; + } + FAIL_MSG("Authenticating proxy..."); + return true; + } + if (getStatusCode() >= 300 && getStatusCode() < 400){ + // follow redirect + std::string location = getHeader("Location"); + if (!location.size()){return false;} + INFO_MSG("Following redirect to %s", location.c_str()); + return true; + } + return false; + } + +}// namespace HTTP diff --git a/lib/downloader.h b/lib/downloader.h index a209e170..85f7f835 100644 --- a/lib/downloader.h +++ b/lib/downloader.h @@ -4,18 +4,22 @@ namespace HTTP{ class Downloader{ public: - Downloader(){progressCallback = 0;} + Downloader(); std::string &data(); - void doRequest(const HTTP::URL &link); + void doRequest(const HTTP::URL &link, const std::string &method="", const std::string &body=""); bool get(const std::string &link); bool get(const HTTP::URL &link, uint8_t maxRecursiveDepth = 6); + bool post(const HTTP::URL &link, const std::string &payload, bool sync = true, uint8_t maxRecursiveDepth = 6); std::string getHeader(const std::string &headerName); std::string &getStatusText(); uint32_t getStatusCode(); - bool isOk(); + bool isOk(); ///< True if the request was successful. + bool shouldContinue(); ///= 0)){ + DONTEVEN_MSG("Socket closed by remote"); + close(); + } + up += r; + return r; +} + +bool Socket::SSLConnection::connected() const{ + return isConnected; +} + +void Socket::SSLConnection::setBlocking(bool blocking){ + if (blocking != Blocking){return;} + if (blocking){ + mbedtls_net_set_block(server_fd); + Blocking = true; + }else{ + mbedtls_net_set_nonblock(server_fd); + Blocking = false; + } +} + +#endif + /// Create a new base Server. The socket is never connected, and a placeholder for later connections. Socket::Server::Server(){ sock = -1; diff --git a/lib/socket.h b/lib/socket.h index 1a4f8e48..73c391c3 100644 --- a/lib/socket.h +++ b/lib/socket.h @@ -16,6 +16,15 @@ #include #include +#ifdef SSL +#include "mbedtls/net.h" +#include "mbedtls/ssl.h" +#include "mbedtls/entropy.h" +#include "mbedtls/ctr_drbg.h" +#include "mbedtls/debug.h" +#include "mbedtls/error.h" +#endif + // for being friendly with Socket::Connection down below namespace Buffer{ class user; @@ -54,7 +63,7 @@ namespace Socket{ /// This class is for easy communicating through sockets, either TCP or Unix. class Connection{ - private: + protected: int sock; ///< Internally saved socket number. int pipes[2]; ///< Internally saved file descriptors for pipe socket simulation. std::string remotehost; ///< Stores remote host address. @@ -63,8 +72,8 @@ namespace Socket{ uint64_t down; long long int conntime; Buffer downbuffer; ///< Stores temporary data coming in. - int iread(void *buffer, int len, int flags = 0); ///< Incremental read call. - unsigned int iwrite(const void *buffer, int len); ///< Incremental write call. + virtual int iread(void *buffer, int len, int flags = 0); ///< Incremental read call. + virtual unsigned int iwrite(const void *buffer, int len); ///< Incremental write call. bool iread(Buffer &buffer, int flags = 0); ///< Incremental write call that is compatible with Socket::Buffer. bool iwrite(std::string &buffer); ///< Write call that is compatible with std::string. public: @@ -77,9 +86,9 @@ namespace Socket{ Connection(std::string adres, bool nonblock = false); ///< Create a new Unix Socket. Connection(int write, int read); ///< Simulate a socket using two file descriptors. // generic methods - void close(); ///< Close connection. + virtual void close(); ///< Close connection. void drop(); ///< Close connection without shutdown. - void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false). + virtual void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false). bool isBlocking(); ///< Check if this socket is blocking (true) or nonblocking (false). std::string getHost() const; ///< Gets hostname for connection, if available. std::string getBinHost(); @@ -87,7 +96,7 @@ namespace Socket{ int getSocket(); ///< Returns internal socket number. int getPureSocket(); ///< Returns non-piped internal socket number. std::string getError(); ///< Returns a string describing the last error that occured. - bool connected() const; ///< Returns the connected-state for this socket. + virtual bool connected() const; ///< Returns the connected-state for this socket. bool isAddress(const std::string &addr); bool isLocal(); ///< Returns true if remote address is a local address. // buffered i/o methods @@ -114,6 +123,27 @@ namespace Socket{ operator bool() const; }; +#ifdef SSL + /// Version of Socket::Connection that uses mbedtls for SSL + class SSLConnection : public Connection{ + public: + SSLConnection(); + SSLConnection(std::string hostname, int port, bool nonblock); ///< Create a new TCP socket. + void close(); ///< Close connection. + bool connected() const; ///< Returns the connected-state for this socket. + void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false). + protected: + bool isConnected; + int iread(void *buffer, int len, int flags = 0); ///< Incremental read call. + unsigned int iwrite(const void *buffer, int len); ///< Incremental write call. + mbedtls_net_context * server_fd; + mbedtls_entropy_context * entropy; + mbedtls_ctr_drbg_context * ctr_drbg; + mbedtls_ssl_context * ssl; + mbedtls_ssl_config * conf; + }; +#endif + /// This class is for easily setting up listening socket, either TCP or Unix. class Server{ private: