Merge branch 'development' into LTS_development
# Conflicts: # CMakeLists.txt
This commit is contained in:
		
						commit
						17baf864d1
					
				
					 6 changed files with 467 additions and 44 deletions
				
			
		
							
								
								
									
										1
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							|  | @ -59,4 +59,5 @@ rules.ninja | |||
| .ninja_deps | ||||
| aes_ctr128 | ||||
| /embed/testing | ||||
| *test | ||||
| 
 | ||||
|  |  | |||
|  | @ -80,6 +80,9 @@ endif() | |||
| if (DEFINED BIGMETA ) | ||||
|   add_definitions(-DBIGMETA=1) | ||||
| endif() | ||||
| if (NOT DEFINED NOSSL ) | ||||
|   add_definitions(-DSSL=1) | ||||
| endif() | ||||
| if (DEFINED DATASIZE ) | ||||
|   add_definitions(-DSHM_DATASIZE=${DATASIZE}) | ||||
| endif() | ||||
|  | @ -232,6 +235,9 @@ target_link_libraries(mist | |||
|   -lpthread  | ||||
|   ${LIBRT} | ||||
| ) | ||||
| if (NOT DEFINED NOSSL ) | ||||
|   target_link_libraries(mist mbedtls mbedx509 mbedcrypto) | ||||
| endif() | ||||
| install( | ||||
|   FILES ${libHeaders} | ||||
|   DESTINATION include/mist | ||||
|  |  | |||
|  | @ -4,6 +4,19 @@ | |||
| 
 | ||||
| namespace HTTP{ | ||||
| 
 | ||||
|   Downloader::Downloader(){ | ||||
|     progressCallback = 0; | ||||
|     connectedPort = 0; | ||||
|     ssl = false; | ||||
|     proxied = false; | ||||
|     char *p = getenv("http_proxy"); | ||||
|     if (p){ | ||||
|       proxyUrl = HTTP::URL(p); | ||||
|       proxied = true; | ||||
|       INFO_MSG("Proxying through %s", proxyUrl.getUrl().c_str()); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /// Returns a reference to the internal HTTP::Parser body element
 | ||||
|   std::string &Downloader::data(){return H.body;} | ||||
| 
 | ||||
|  | @ -39,23 +52,57 @@ namespace HTTP{ | |||
|   Parser &Downloader::getHTTP(){return H;} | ||||
| 
 | ||||
|   /// Returns a reference to the internal Socket::Connection class instance.
 | ||||
|   Socket::Connection &Downloader::getSocket(){return S;} | ||||
|   Socket::Connection &Downloader::getSocket(){ | ||||
| #ifdef SSL | ||||
|     if (ssl){return S_SSL;} | ||||
| #endif | ||||
|     return S; | ||||
|   } | ||||
| 
 | ||||
|   /// Sends a request for the given URL, does no waiting.
 | ||||
|   void Downloader::doRequest(const HTTP::URL &link){ | ||||
|     if (link.protocol != "http"){ | ||||
|       FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); | ||||
|       return; | ||||
|     } | ||||
|   void Downloader::doRequest(const HTTP::URL &link, const std::string &method, const std::string &body){ | ||||
|     if (!canRequest(link)){return;} | ||||
|     bool needSSL = (link.protocol == "https"); | ||||
|     INFO_MSG("Retrieving %s", link.getUrl().c_str()); | ||||
|     H.Clean(); | ||||
|     // Reconnect if needed
 | ||||
|     if (!S || link.host != connectedHost || link.getPort() != connectedPort){ | ||||
|       S.close(); | ||||
|     if (!proxied || needSSL){ | ||||
|       if (!getSocket() || link.host != connectedHost || link.getPort() != connectedPort || | ||||
|           needSSL != ssl){ | ||||
|         getSocket().close(); | ||||
|         connectedHost = link.host; | ||||
|         connectedPort = link.getPort(); | ||||
| #ifdef SSL | ||||
|         if (needSSL){ | ||||
|           S_SSL = Socket::SSLConnection(connectedHost, connectedPort, true); | ||||
|         }else{ | ||||
|           S = Socket::Connection(connectedHost, connectedPort, true); | ||||
|         } | ||||
| #else | ||||
|         S = Socket::Connection(connectedHost, connectedPort, true); | ||||
| #endif | ||||
|       } | ||||
|     }else{ | ||||
|       if (!getSocket() || proxyUrl.host != connectedHost || proxyUrl.getPort() != connectedPort || | ||||
|           needSSL != ssl){ | ||||
|         getSocket().close(); | ||||
|         connectedHost = proxyUrl.host; | ||||
|         connectedPort = proxyUrl.getPort(); | ||||
|         S = Socket::Connection(connectedHost, connectedPort, true); | ||||
|       } | ||||
|     } | ||||
|     ssl = needSSL; | ||||
|     if (!getSocket()){ | ||||
|       return; // socket is closed
 | ||||
|     } | ||||
|     if (proxied && !ssl){ | ||||
|       H.url = link.getProxyUrl(); | ||||
|       if (proxyUrl.port.size()){ | ||||
|         H.SetHeader("Host", proxyUrl.host + ":" + proxyUrl.port); | ||||
|       }else{ | ||||
|         H.SetHeader("Host", proxyUrl.host); | ||||
|       } | ||||
|     }else{ | ||||
|       H.url = "/" + link.path; | ||||
|       if (link.args.size()){H.url += "?" + link.args;} | ||||
|       if (link.port.size()){ | ||||
|  | @ -63,34 +110,40 @@ namespace HTTP{ | |||
|       }else{ | ||||
|         H.SetHeader("Host", link.host); | ||||
|       } | ||||
|     } | ||||
|     if (method.size()){ | ||||
|       H.method = method; | ||||
|     } | ||||
|     H.SetHeader("User-Agent", "MistServer " PACKAGE_VERSION); | ||||
|     H.SetHeader("X-Version", PACKAGE_VERSION); | ||||
|     H.SetHeader("Accept", "*/*"); | ||||
|     if (authStr.size() && (link.user.size() || link.pass.size())){ | ||||
|       H.auth(link.user, link.pass, authStr); | ||||
|     } | ||||
|     if (proxyAuthStr.size() && (proxyUrl.user.size() || proxyUrl.pass.size())){ | ||||
|       H.auth(proxyUrl.user, proxyUrl.pass, proxyAuthStr, "Proxy-Authorization"); | ||||
|     } | ||||
|     if (extraHeaders.size()){ | ||||
|       for (std::map<std::string, std::string>::iterator it = extraHeaders.begin(); | ||||
|            it != extraHeaders.end(); ++it){ | ||||
|         H.SetHeader(it->first, it->second); | ||||
|       } | ||||
|     } | ||||
|     H.SendRequest(S); | ||||
|     H.SendRequest(getSocket(), body); | ||||
|     H.Clean(); | ||||
|   } | ||||
| 
 | ||||
|   /// Downloads the given URL into 'H', returns true on success.
 | ||||
|   /// Makes at most 5 attempts, and will wait no longer than 5 seconds without receiving data.
 | ||||
|   bool Downloader::get(const HTTP::URL &link, uint8_t maxRecursiveDepth){ | ||||
|     if (!link.host.size()){return false;} | ||||
|     if (link.protocol != "http"){ | ||||
|       FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); | ||||
|       return false; | ||||
|     } | ||||
|     if (!canRequest(link)){return false;} | ||||
|     unsigned int loop = 6; // max 5 attempts
 | ||||
|     while (--loop){// loop while we are unsuccessful
 | ||||
|       doRequest(link); | ||||
|       uint64_t reqTime = Util::bootSecs(); | ||||
|       while (S && Util::bootSecs() < reqTime + 5){ | ||||
|       while (getSocket() && Util::bootSecs() < reqTime + 5){ | ||||
|         // No data? Wait for a second or so.
 | ||||
|         if (!S.spool()){ | ||||
|         if (!getSocket().spool()){ | ||||
|           if (progressCallback != 0){ | ||||
|             if (!progressCallback()){ | ||||
|               WARN_MSG("Download aborted by callback"); | ||||
|  | @ -101,16 +154,17 @@ namespace HTTP{ | |||
|           continue; | ||||
|         } | ||||
|         // Data! Check if we can parse it...
 | ||||
|         if (H.Read(S)){ | ||||
|           if (getStatusCode() >= 300 && getStatusCode() < 400){ | ||||
|             // follow redirect
 | ||||
|             std::string location = getHeader("Location"); | ||||
|         if (H.Read(getSocket())){ | ||||
|           if (shouldContinue()){ | ||||
|             if (maxRecursiveDepth == 0){ | ||||
|               FAIL_MSG("Maximum redirect depth reached: %s", location.c_str()); | ||||
|               FAIL_MSG("Maximum recursion depth reached"); | ||||
|               return false; | ||||
|             } | ||||
|             if (!canContinue(link)){return false;} | ||||
|             if (getStatusCode() >= 300 && getStatusCode() < 400){ | ||||
|               return get(link.link(getHeader("Location")), --maxRecursiveDepth); | ||||
|             }else{ | ||||
|               FAIL_MSG("Following redirect to %s", location.c_str()); | ||||
|               return get(link.link(location), maxRecursiveDepth--); | ||||
|               return get(link, --maxRecursiveDepth); | ||||
|             } | ||||
|           } | ||||
|           return true; // Success!
 | ||||
|  | @ -118,7 +172,7 @@ namespace HTTP{ | |||
|         // reset the 5 second timeout
 | ||||
|         reqTime = Util::bootSecs(); | ||||
|       } | ||||
|       if (S){ | ||||
|       if (getSocket()){ | ||||
|         FAIL_MSG("Timeout while retrieving %s", link.getUrl().c_str()); | ||||
|         return false; | ||||
|       } | ||||
|  | @ -127,5 +181,124 @@ namespace HTTP{ | |||
|     FAIL_MSG("Could not retrieve %s", link.getUrl().c_str()); | ||||
|     return false; | ||||
|   } | ||||
|    | ||||
|   bool Downloader::post(const HTTP::URL &link, const std::string &payload, bool sync, uint8_t maxRecursiveDepth){ | ||||
|     if (!canRequest(link)){return false;} | ||||
|     unsigned int loop = 6; // max 5 attempts
 | ||||
|     while (--loop){// loop while we are unsuccessful
 | ||||
|       doRequest(link, "POST", payload); | ||||
|       //Not synced? Ignore the response and immediately return false.
 | ||||
|       if (!sync){return false;} | ||||
|       uint64_t reqTime = Util::bootSecs(); | ||||
|       while (getSocket() && Util::bootSecs() < reqTime + 5){ | ||||
|         // No data? Wait for a second or so.
 | ||||
|         if (!getSocket().spool()){ | ||||
|           if (progressCallback != 0){ | ||||
|             if (!progressCallback()){ | ||||
|               WARN_MSG("Download aborted by callback"); | ||||
|               return false; | ||||
|             } | ||||
|           } | ||||
|           Util::sleep(250); | ||||
|           continue; | ||||
|         } | ||||
|         // Data! Check if we can parse it...
 | ||||
|         if (H.Read(getSocket())){ | ||||
|           if (shouldContinue()){ | ||||
|             if (maxRecursiveDepth == 0){ | ||||
|               FAIL_MSG("Maximum recursion depth reached"); | ||||
|               return false; | ||||
|             } | ||||
|             if (!canContinue(link)){return false;} | ||||
|             if (getStatusCode() >= 300 && getStatusCode() < 400){ | ||||
|               return post(link.link(getHeader("Location")), payload, sync, --maxRecursiveDepth); | ||||
|             }else{ | ||||
|               return post(link, payload, sync, --maxRecursiveDepth); | ||||
|             } | ||||
|           } | ||||
|           return true; // Success!
 | ||||
|         } | ||||
|         // reset the 5 second timeout
 | ||||
|         reqTime = Util::bootSecs(); | ||||
|       } | ||||
|       if (getSocket()){ | ||||
|         FAIL_MSG("Timeout while retrieving %s", link.getUrl().c_str()); | ||||
|         return false; | ||||
|       } | ||||
|       Util::sleep(500); // wait a bit before retrying
 | ||||
|     } | ||||
|     FAIL_MSG("Could not retrieve %s", link.getUrl().c_str()); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   bool Downloader::canRequest(const HTTP::URL &link){ | ||||
|     if (!link.host.size()){return false;} | ||||
|     if (link.protocol != "http" && link.protocol != "https"){ | ||||
|       FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); | ||||
|       return false; | ||||
|     } | ||||
| #ifndef SSL | ||||
|     if (link.protocol == "https"){ | ||||
|       FAIL_MSG("Protocol not supported: %s", link.protocol.c_str()); | ||||
|       return false; | ||||
|     } | ||||
| #endif | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   bool Downloader::shouldContinue(){ | ||||
|     if (H.hasHeader("Set-Cookie")){ | ||||
|       std::string cookie = H.GetHeader("Set-Cookie"); | ||||
|       setHeader("Cookie", cookie.substr(0, cookie.find(';'))); | ||||
|     } | ||||
|     uint32_t sCode = getStatusCode(); | ||||
|     if (sCode == 401 || sCode == 407 || (sCode >= 300 && sCode < 400)){ | ||||
|       return true; | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   bool Downloader::canContinue(const HTTP::URL &link){ | ||||
|     if (getStatusCode() == 401){ | ||||
|       // retry with authentication
 | ||||
|       if (H.hasHeader("WWW-Authenticate")){authStr = H.GetHeader("WWW-Authenticate");} | ||||
|       if (H.hasHeader("Www-Authenticate")){authStr = H.GetHeader("Www-Authenticate");} | ||||
|       if (!authStr.size()){ | ||||
|         FAIL_MSG("Authentication required but no WWW-Authenticate header present"); | ||||
|         return false; | ||||
|       } | ||||
|       if (!link.user.size() && !link.pass.size()){ | ||||
|         FAIL_MSG("Authentication required but not included in URL"); | ||||
|         return false; | ||||
|       } | ||||
|       FAIL_MSG("Authenticating..."); | ||||
|       return true; | ||||
|     } | ||||
|     if (getStatusCode() == 407){ | ||||
|       // retry with authentication
 | ||||
|       if (H.hasHeader("Proxy-Authenticate")){ | ||||
|         proxyAuthStr = H.GetHeader("Proxy-Authenticate"); | ||||
|       } | ||||
|       if (!proxyAuthStr.size()){ | ||||
|         FAIL_MSG("Proxy authentication required but no Proxy-Authenticate header present"); | ||||
|         return false; | ||||
|       } | ||||
|       if (!proxyUrl.user.size() && !proxyUrl.pass.size()){ | ||||
|         FAIL_MSG("Proxy authentication required but not included in URL"); | ||||
|         return false; | ||||
|       } | ||||
|       FAIL_MSG("Authenticating proxy..."); | ||||
|       return true; | ||||
|     } | ||||
|     if (getStatusCode() >= 300 && getStatusCode() < 400){ | ||||
|       // follow redirect
 | ||||
|       std::string location = getHeader("Location"); | ||||
|       if (!location.size()){return false;} | ||||
|       INFO_MSG("Following redirect to %s", location.c_str()); | ||||
|       return true; | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
| }// namespace HTTP
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -4,18 +4,22 @@ | |||
| namespace HTTP{ | ||||
|   class Downloader{ | ||||
|   public: | ||||
|     Downloader(){progressCallback = 0;} | ||||
|     Downloader(); | ||||
|     std::string &data(); | ||||
|     void doRequest(const HTTP::URL &link); | ||||
|     void doRequest(const HTTP::URL &link, const std::string &method="", const std::string &body=""); | ||||
|     bool get(const std::string &link); | ||||
|     bool get(const HTTP::URL &link, uint8_t maxRecursiveDepth = 6); | ||||
|     bool post(const HTTP::URL &link, const std::string &payload, bool sync = true, uint8_t maxRecursiveDepth = 6); | ||||
|     std::string getHeader(const std::string &headerName); | ||||
|     std::string &getStatusText(); | ||||
|     uint32_t getStatusCode(); | ||||
|     bool isOk(); | ||||
|     bool isOk(); ///< True if the request was successful.
 | ||||
|     bool shouldContinue(); ///<True if the request should be followed-up with another. E.g. redirect or authenticate.
 | ||||
|     bool canContinue(const HTTP::URL &link);///<True if the request is able to continue, false if there is a state error or some such.
 | ||||
|     bool (*progressCallback)(); ///< Called every time the socket stalls, up to 4X per second.
 | ||||
|     void setHeader(const std::string &name, const std::string &val); | ||||
|     void clearHeaders(); | ||||
|     bool canRequest(const HTTP::URL &link); | ||||
|     Parser &getHTTP(); | ||||
|     Socket::Connection &getSocket(); | ||||
| 
 | ||||
|  | @ -25,6 +29,14 @@ namespace HTTP{ | |||
|     uint32_t connectedPort;                          ///< Currently connected port number
 | ||||
|     Parser H;                                        ///< HTTP parser for downloader
 | ||||
|     Socket::Connection S;                            ///< TCP socket for downloader
 | ||||
| #ifdef SSL | ||||
|     Socket::SSLConnection S_SSL; ///< SSL socket for downloader
 | ||||
| #endif | ||||
|     bool ssl;                 ///< True if ssl is currently in use.
 | ||||
|     std::string authStr;      ///< Most recently seen WWW-Authenticate request
 | ||||
|     std::string proxyAuthStr; ///< Most recently seen Proxy-Authenticate request
 | ||||
|     bool proxied;             ///< True if proxy server is configured.
 | ||||
|     HTTP::URL proxyUrl;       ///< Set to the URL of the configured proxy.
 | ||||
|   }; | ||||
| } | ||||
| }// namespace HTTP
 | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										201
									
								
								lib/socket.cpp
									
										
									
									
									
								
							
							
						
						
									
										201
									
								
								lib/socket.cpp
									
										
									
									
									
								
							|  | @ -811,6 +811,207 @@ bool Socket::Connection::isLocal(){ | |||
|   return false; | ||||
| } | ||||
| 
 | ||||
| #ifdef SSL | ||||
| Socket::SSLConnection::SSLConnection() : Socket::Connection::Connection(){ | ||||
|   isConnected = false; | ||||
|   server_fd = 0; | ||||
|   ssl = 0; | ||||
|   conf = 0; | ||||
|   ctr_drbg = 0; | ||||
|   entropy = 0; | ||||
| } | ||||
| 
 | ||||
| static void my_debug(void *ctx, int level, const char *file, int line, const char *str){ | ||||
|   ((void)level); | ||||
|   fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str); | ||||
|   fflush((FILE *)ctx); | ||||
| } | ||||
| 
 | ||||
| Socket::SSLConnection::SSLConnection(std::string hostname, int port, bool nonblock) : Socket::Connection(){ | ||||
|   mbedtls_debug_set_threshold(0); | ||||
|   isConnected = true; | ||||
|   server_fd = new mbedtls_net_context; | ||||
|   ssl = new mbedtls_ssl_context; | ||||
|   conf = new mbedtls_ssl_config; | ||||
|   ctr_drbg = new mbedtls_ctr_drbg_context; | ||||
|   entropy = new mbedtls_entropy_context; | ||||
|   mbedtls_net_init(server_fd); | ||||
|   mbedtls_ssl_init(ssl); | ||||
|   mbedtls_ssl_config_init(conf); | ||||
|   mbedtls_ctr_drbg_init(ctr_drbg); | ||||
|   mbedtls_entropy_init(entropy); | ||||
|   DONTEVEN_MSG("SSL init"); | ||||
|   if (mbedtls_ctr_drbg_seed(ctr_drbg, mbedtls_entropy_func, entropy, (const unsigned char*)"meow", 4) != 0){ | ||||
|     FAIL_MSG("SSL socket init failed"); | ||||
|     close(); | ||||
|     return; | ||||
|   } | ||||
|   DONTEVEN_MSG("SSL connect"); | ||||
|   int ret = 0; | ||||
|   if ((ret = mbedtls_net_connect(server_fd, hostname.c_str(), JSON::Value((long long)port).asString().c_str(), MBEDTLS_NET_PROTO_TCP)) != 0){ | ||||
|     FAIL_MSG(" failed\n  ! mbedtls_net_connect returned %d\n\n", ret); | ||||
|     close(); | ||||
|     return; | ||||
|   } | ||||
|   if ((ret = mbedtls_ssl_config_defaults(conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, | ||||
|                                          MBEDTLS_SSL_PRESET_DEFAULT)) != 0){ | ||||
|     FAIL_MSG(" failed\n  ! mbedtls_ssl_config_defaults returned %d\n\n", ret); | ||||
|     close(); | ||||
|     return; | ||||
|   } | ||||
|   mbedtls_ssl_conf_authmode(conf, MBEDTLS_SSL_VERIFY_NONE); | ||||
|   mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, ctr_drbg); | ||||
|   mbedtls_ssl_conf_dbg(conf, my_debug, stderr ); | ||||
|   if ((ret = mbedtls_ssl_setup(ssl, conf)) != 0){ | ||||
|       char estr[200]; | ||||
|       mbedtls_strerror(ret, estr, 200); | ||||
|       FAIL_MSG("SSL setup error %d: %s", ret, estr); | ||||
|       close(); | ||||
|       return; | ||||
|   } | ||||
|   if ((ret = mbedtls_ssl_set_hostname(ssl, hostname.c_str())) != 0){ | ||||
|     FAIL_MSG(" failed\n  ! mbedtls_ssl_set_hostname returned %d\n\n", ret); | ||||
|     close(); | ||||
|     return; | ||||
|   } | ||||
|   mbedtls_ssl_set_bio(ssl, server_fd, mbedtls_net_send, mbedtls_net_recv, NULL); | ||||
|   while ((ret = mbedtls_ssl_handshake(ssl)) != 0){ | ||||
|     if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE){ | ||||
|       char estr[200]; | ||||
|       mbedtls_strerror(ret, estr, 200); | ||||
|       FAIL_MSG("SSL handshake error %d: %s", ret, estr); | ||||
|       close(); | ||||
|       return; | ||||
|     } | ||||
|   } | ||||
|   Blocking = true; | ||||
|   if (nonblock){ | ||||
|     setBlocking(false); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| void Socket::SSLConnection::close(){ | ||||
|   DONTEVEN_MSG("SSL close"); | ||||
|   if (server_fd){ | ||||
|     mbedtls_net_free(server_fd); | ||||
|     delete server_fd; | ||||
|     server_fd = 0; | ||||
|   } | ||||
|   if (ssl){ | ||||
|     mbedtls_ssl_free(ssl); | ||||
|     delete ssl; | ||||
|     ssl = 0; | ||||
|   } | ||||
|   if (conf){ | ||||
|     mbedtls_ssl_config_free(conf); | ||||
|     delete conf; | ||||
|     conf = 0; | ||||
|   } | ||||
|   if (ctr_drbg){ | ||||
|     mbedtls_ctr_drbg_free(ctr_drbg); | ||||
|     delete ctr_drbg; | ||||
|     ctr_drbg = 0; | ||||
|   } | ||||
|   if (entropy){ | ||||
|     mbedtls_entropy_free(entropy); | ||||
|     delete entropy; | ||||
|     entropy = 0; | ||||
|   } | ||||
|   isConnected = false; | ||||
| } | ||||
| 
 | ||||
| /// Incremental read call. This function tries to read len bytes to the buffer from the socket,
 | ||||
| /// returning the amount of bytes it actually read.
 | ||||
| /// \param buffer Location of the buffer to read to.
 | ||||
| /// \param len Amount of bytes to read.
 | ||||
| /// \param flags Flags to use in the recv call. Ignored on fake sockets.
 | ||||
| /// \returns The amount of bytes actually read.
 | ||||
| int Socket::SSLConnection::iread(void *buffer, int len, int flags){ | ||||
|   DONTEVEN_MSG("SSL iread"); | ||||
|   if (!connected() || len < 1){return 0;} | ||||
|   int r; | ||||
|   /// \TODO Flags ignored... Bad.
 | ||||
|   r = mbedtls_ssl_read(ssl, (unsigned char*)buffer, len); | ||||
|   if (r < 0){ | ||||
|     char estr[200]; | ||||
|     mbedtls_strerror(r, estr, 200); | ||||
|     INFO_MSG("Read returns %d: %s", r, estr); | ||||
|   } | ||||
|   if (r < 0){ | ||||
|     switch (errno){ | ||||
|     case MBEDTLS_ERR_SSL_WANT_WRITE: return 0; break; | ||||
|     case MBEDTLS_ERR_SSL_WANT_READ: return 0; break; | ||||
|     case EWOULDBLOCK: return 0; break; | ||||
|     case EINTR: return 0; break; | ||||
|     default: | ||||
|       Error = true; | ||||
|       INSANE_MSG("Could not iread data! Error: %s", strerror(errno)); | ||||
|       close(); | ||||
|       return 0; | ||||
|       break; | ||||
|     } | ||||
|   } | ||||
|   if (r == 0){ | ||||
|     DONTEVEN_MSG("Socket closed by remote"); | ||||
|     close(); | ||||
|   } | ||||
|   down += r; | ||||
|   return r; | ||||
| } | ||||
| 
 | ||||
| /// Incremental write call. This function tries to write len bytes to the socket from the buffer,
 | ||||
| /// returning the amount of bytes it actually wrote.
 | ||||
| /// \param buffer Location of the buffer to write from.
 | ||||
| /// \param len Amount of bytes to write.
 | ||||
| /// \returns The amount of bytes actually written.
 | ||||
| unsigned int Socket::SSLConnection::iwrite(const void *buffer, int len){ | ||||
|   DONTEVEN_MSG("SSL iwrite"); | ||||
|   if (!connected() || len < 1){return 0;} | ||||
|   int r; | ||||
|   r = mbedtls_ssl_write(ssl, (const unsigned char*)buffer, len); | ||||
|   if (r < 0){ | ||||
|     char estr[200]; | ||||
|     mbedtls_strerror(r, estr, 200); | ||||
|     INFO_MSG("Write returns %d: %s", r, estr); | ||||
|   } | ||||
|   if (r < 0){ | ||||
|     switch (errno){ | ||||
|     case MBEDTLS_ERR_SSL_WANT_WRITE: return 0; break; | ||||
|     case MBEDTLS_ERR_SSL_WANT_READ: return 0; break; | ||||
|     case EWOULDBLOCK: return 0; break; | ||||
|     default: | ||||
|       Error = true; | ||||
|       INSANE_MSG("Could not iwrite data! Error: %s", strerror(errno)); | ||||
|       close(); | ||||
|       return 0; | ||||
|       break; | ||||
|     } | ||||
|   } | ||||
|   if (r == 0 && (sock >= 0)){ | ||||
|     DONTEVEN_MSG("Socket closed by remote"); | ||||
|     close(); | ||||
|   } | ||||
|   up += r; | ||||
|   return r; | ||||
| } | ||||
| 
 | ||||
| bool Socket::SSLConnection::connected() const{ | ||||
|   return isConnected; | ||||
| } | ||||
| 
 | ||||
| void Socket::SSLConnection::setBlocking(bool blocking){ | ||||
|   if (blocking != Blocking){return;} | ||||
|   if (blocking){ | ||||
|     mbedtls_net_set_block(server_fd); | ||||
|     Blocking = true; | ||||
|   }else{ | ||||
|     mbedtls_net_set_nonblock(server_fd); | ||||
|     Blocking = false; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
| 
 | ||||
| /// Create a new base Server. The socket is never connected, and a placeholder for later connections.
 | ||||
| Socket::Server::Server(){ | ||||
|   sock = -1; | ||||
|  |  | |||
							
								
								
									
										42
									
								
								lib/socket.h
									
										
									
									
									
								
							
							
						
						
									
										42
									
								
								lib/socket.h
									
										
									
									
									
								
							|  | @ -16,6 +16,15 @@ | |||
| #include <sys/un.h> | ||||
| #include <unistd.h> | ||||
| 
 | ||||
| #ifdef SSL | ||||
| #include "mbedtls/net.h" | ||||
| #include "mbedtls/ssl.h" | ||||
| #include "mbedtls/entropy.h" | ||||
| #include "mbedtls/ctr_drbg.h" | ||||
| #include "mbedtls/debug.h" | ||||
| #include "mbedtls/error.h" | ||||
| #endif | ||||
| 
 | ||||
| // for being friendly with Socket::Connection down below
 | ||||
| namespace Buffer{ | ||||
|   class user; | ||||
|  | @ -54,7 +63,7 @@ namespace Socket{ | |||
| 
 | ||||
|   /// This class is for easy communicating through sockets, either TCP or Unix.
 | ||||
|   class Connection{ | ||||
|   private: | ||||
|   protected: | ||||
|     int sock;               ///< Internally saved socket number.
 | ||||
|     int pipes[2];           ///< Internally saved file descriptors for pipe socket simulation.
 | ||||
|     std::string remotehost; ///< Stores remote host address.
 | ||||
|  | @ -63,8 +72,8 @@ namespace Socket{ | |||
|     uint64_t down; | ||||
|     long long int conntime; | ||||
|     Buffer downbuffer;                                ///< Stores temporary data coming in.
 | ||||
|     int iread(void *buffer, int len, int flags = 0);  ///< Incremental read call.
 | ||||
|     unsigned int iwrite(const void *buffer, int len); ///< Incremental write call.
 | ||||
|     virtual int iread(void *buffer, int len, int flags = 0);  ///< Incremental read call.
 | ||||
|     virtual unsigned int iwrite(const void *buffer, int len); ///< Incremental write call.
 | ||||
|     bool iread(Buffer &buffer, int flags = 0);        ///< Incremental write call that is compatible with Socket::Buffer.
 | ||||
|     bool iwrite(std::string &buffer);                 ///< Write call that is compatible with std::string.
 | ||||
|   public: | ||||
|  | @ -77,9 +86,9 @@ namespace Socket{ | |||
|     Connection(std::string adres, bool nonblock = false);      ///< Create a new Unix Socket.
 | ||||
|     Connection(int write, int read);                           ///< Simulate a socket using two file descriptors.
 | ||||
|     // generic methods
 | ||||
|     void close();                    ///< Close connection.
 | ||||
|     virtual void close();                    ///< Close connection.
 | ||||
|     void drop();                     ///< Close connection without shutdown.
 | ||||
|     void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false).
 | ||||
|     virtual void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false).
 | ||||
|     bool isBlocking();               ///< Check if this socket is blocking (true) or nonblocking (false).
 | ||||
|     std::string getHost() const;     ///< Gets hostname for connection, if available.
 | ||||
|     std::string getBinHost(); | ||||
|  | @ -87,7 +96,7 @@ namespace Socket{ | |||
|     int getSocket();                ///< Returns internal socket number.
 | ||||
|     int getPureSocket();            ///< Returns non-piped internal socket number.
 | ||||
|     std::string getError();         ///< Returns a string describing the last error that occured.
 | ||||
|     bool connected() const;         ///< Returns the connected-state for this socket.
 | ||||
|     virtual bool connected() const;         ///< Returns the connected-state for this socket.
 | ||||
|     bool isAddress(const std::string &addr); | ||||
|     bool isLocal(); ///< Returns true if remote address is a local address.
 | ||||
|     // buffered i/o methods
 | ||||
|  | @ -114,6 +123,27 @@ namespace Socket{ | |||
|     operator bool() const; | ||||
|   }; | ||||
| 
 | ||||
| #ifdef SSL | ||||
|   /// Version of Socket::Connection that uses mbedtls for SSL
 | ||||
|   class SSLConnection : public Connection{ | ||||
|     public: | ||||
|       SSLConnection(); | ||||
|       SSLConnection(std::string hostname, int port, bool nonblock); ///< Create a new TCP socket.
 | ||||
|       void close();                    ///< Close connection.
 | ||||
|       bool connected() const;         ///< Returns the connected-state for this socket.
 | ||||
|       void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false).
 | ||||
|     protected: | ||||
|       bool isConnected; | ||||
|       int iread(void *buffer, int len, int flags = 0);  ///< Incremental read call.
 | ||||
|       unsigned int iwrite(const void *buffer, int len); ///< Incremental write call.
 | ||||
|       mbedtls_net_context * server_fd; | ||||
|       mbedtls_entropy_context * entropy; | ||||
|       mbedtls_ctr_drbg_context * ctr_drbg; | ||||
|       mbedtls_ssl_context * ssl; | ||||
|       mbedtls_ssl_config * conf; | ||||
|   }; | ||||
| #endif | ||||
| 
 | ||||
|   /// This class is for easily setting up listening socket, either TCP or Unix.
 | ||||
|   class Server{ | ||||
|   private: | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Thulinma
						Thulinma