SSL socket class, downloadertest application, HTTP::Downloader support for HTTPS connections, authentication, proxies and POST requests
This commit is contained in:
parent
ce9aae3095
commit
5e3df09831
6 changed files with 467 additions and 44 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -59,4 +59,5 @@ rules.ninja
|
||||||
.ninja_deps
|
.ninja_deps
|
||||||
aes_ctr128
|
aes_ctr128
|
||||||
/embed/testing
|
/embed/testing
|
||||||
|
*test
|
||||||
|
|
||||||
|
|
|
@ -72,6 +72,9 @@ endif()
|
||||||
if (DEFINED BIGMETA )
|
if (DEFINED BIGMETA )
|
||||||
add_definitions(-DBIGMETA=1)
|
add_definitions(-DBIGMETA=1)
|
||||||
endif()
|
endif()
|
||||||
|
if (NOT DEFINED NOSSL )
|
||||||
|
add_definitions(-DSSL=1)
|
||||||
|
endif()
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
# Build Variables - Thread Names #
|
# Build Variables - Thread Names #
|
||||||
|
@ -177,6 +180,9 @@ target_link_libraries(mist
|
||||||
-lpthread
|
-lpthread
|
||||||
${LIBRT}
|
${LIBRT}
|
||||||
)
|
)
|
||||||
|
if (NOT DEFINED NOSSL )
|
||||||
|
target_link_libraries(mist mbedtls mbedx509 mbedcrypto)
|
||||||
|
endif()
|
||||||
install(
|
install(
|
||||||
FILES ${libHeaders}
|
FILES ${libHeaders}
|
||||||
DESTINATION include/mist
|
DESTINATION include/mist
|
||||||
|
|
|
@ -4,6 +4,19 @@
|
||||||
|
|
||||||
namespace HTTP{
|
namespace HTTP{
|
||||||
|
|
||||||
|
Downloader::Downloader(){
|
||||||
|
progressCallback = 0;
|
||||||
|
connectedPort = 0;
|
||||||
|
ssl = false;
|
||||||
|
proxied = false;
|
||||||
|
char *p = getenv("http_proxy");
|
||||||
|
if (p){
|
||||||
|
proxyUrl = HTTP::URL(p);
|
||||||
|
proxied = true;
|
||||||
|
INFO_MSG("Proxying through %s", proxyUrl.getUrl().c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a reference to the internal HTTP::Parser body element
|
/// Returns a reference to the internal HTTP::Parser body element
|
||||||
std::string &Downloader::data(){return H.body;}
|
std::string &Downloader::data(){return H.body;}
|
||||||
|
|
||||||
|
@ -39,58 +52,98 @@ namespace HTTP{
|
||||||
Parser &Downloader::getHTTP(){return H;}
|
Parser &Downloader::getHTTP(){return H;}
|
||||||
|
|
||||||
/// Returns a reference to the internal Socket::Connection class instance.
|
/// Returns a reference to the internal Socket::Connection class instance.
|
||||||
Socket::Connection &Downloader::getSocket(){return S;}
|
Socket::Connection &Downloader::getSocket(){
|
||||||
|
#ifdef SSL
|
||||||
|
if (ssl){return S_SSL;}
|
||||||
|
#endif
|
||||||
|
return S;
|
||||||
|
}
|
||||||
|
|
||||||
/// Sends a request for the given URL, does no waiting.
|
/// Sends a request for the given URL, does no waiting.
|
||||||
void Downloader::doRequest(const HTTP::URL &link){
|
void Downloader::doRequest(const HTTP::URL &link, const std::string &method, const std::string &body){
|
||||||
if (link.protocol != "http"){
|
if (!canRequest(link)){return;}
|
||||||
FAIL_MSG("Protocol not supported: %s", link.protocol.c_str());
|
bool needSSL = (link.protocol == "https");
|
||||||
return;
|
|
||||||
}
|
|
||||||
INFO_MSG("Retrieving %s", link.getUrl().c_str());
|
INFO_MSG("Retrieving %s", link.getUrl().c_str());
|
||||||
H.Clean();
|
H.Clean();
|
||||||
// Reconnect if needed
|
// Reconnect if needed
|
||||||
if (!S || link.host != connectedHost || link.getPort() != connectedPort){
|
if (!proxied || needSSL){
|
||||||
S.close();
|
if (!getSocket() || link.host != connectedHost || link.getPort() != connectedPort ||
|
||||||
connectedHost = link.host;
|
needSSL != ssl){
|
||||||
connectedPort = link.getPort();
|
getSocket().close();
|
||||||
S = Socket::Connection(connectedHost, connectedPort, true);
|
connectedHost = link.host;
|
||||||
}
|
connectedPort = link.getPort();
|
||||||
H.url = "/" + link.path;
|
#ifdef SSL
|
||||||
if (link.args.size()){H.url += "?" + link.args;}
|
if (needSSL){
|
||||||
if (link.port.size()){
|
S_SSL = Socket::SSLConnection(connectedHost, connectedPort, true);
|
||||||
H.SetHeader("Host", link.host + ":" + link.port);
|
}else{
|
||||||
|
S = Socket::Connection(connectedHost, connectedPort, true);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
S = Socket::Connection(connectedHost, connectedPort, true);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
}else{
|
}else{
|
||||||
H.SetHeader("Host", link.host);
|
if (!getSocket() || proxyUrl.host != connectedHost || proxyUrl.getPort() != connectedPort ||
|
||||||
|
needSSL != ssl){
|
||||||
|
getSocket().close();
|
||||||
|
connectedHost = proxyUrl.host;
|
||||||
|
connectedPort = proxyUrl.getPort();
|
||||||
|
S = Socket::Connection(connectedHost, connectedPort, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ssl = needSSL;
|
||||||
|
if (!getSocket()){
|
||||||
|
return; // socket is closed
|
||||||
|
}
|
||||||
|
if (proxied && !ssl){
|
||||||
|
H.url = link.getProxyUrl();
|
||||||
|
if (proxyUrl.port.size()){
|
||||||
|
H.SetHeader("Host", proxyUrl.host + ":" + proxyUrl.port);
|
||||||
|
}else{
|
||||||
|
H.SetHeader("Host", proxyUrl.host);
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
H.url = "/" + link.path;
|
||||||
|
if (link.args.size()){H.url += "?" + link.args;}
|
||||||
|
if (link.port.size()){
|
||||||
|
H.SetHeader("Host", link.host + ":" + link.port);
|
||||||
|
}else{
|
||||||
|
H.SetHeader("Host", link.host);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (method.size()){
|
||||||
|
H.method = method;
|
||||||
}
|
}
|
||||||
H.SetHeader("User-Agent", "MistServer " PACKAGE_VERSION);
|
H.SetHeader("User-Agent", "MistServer " PACKAGE_VERSION);
|
||||||
H.SetHeader("X-Version", PACKAGE_VERSION);
|
H.SetHeader("X-Version", PACKAGE_VERSION);
|
||||||
H.SetHeader("Accept", "*/*");
|
H.SetHeader("Accept", "*/*");
|
||||||
|
if (authStr.size() && (link.user.size() || link.pass.size())){
|
||||||
|
H.auth(link.user, link.pass, authStr);
|
||||||
|
}
|
||||||
|
if (proxyAuthStr.size() && (proxyUrl.user.size() || proxyUrl.pass.size())){
|
||||||
|
H.auth(proxyUrl.user, proxyUrl.pass, proxyAuthStr, "Proxy-Authorization");
|
||||||
|
}
|
||||||
if (extraHeaders.size()){
|
if (extraHeaders.size()){
|
||||||
for (std::map<std::string, std::string>::iterator it = extraHeaders.begin();
|
for (std::map<std::string, std::string>::iterator it = extraHeaders.begin();
|
||||||
it != extraHeaders.end(); ++it){
|
it != extraHeaders.end(); ++it){
|
||||||
H.SetHeader(it->first, it->second);
|
H.SetHeader(it->first, it->second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
H.SendRequest(S);
|
H.SendRequest(getSocket(), body);
|
||||||
H.Clean();
|
H.Clean();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Downloads the given URL into 'H', returns true on success.
|
/// Downloads the given URL into 'H', returns true on success.
|
||||||
/// Makes at most 5 attempts, and will wait no longer than 5 seconds without receiving data.
|
/// Makes at most 5 attempts, and will wait no longer than 5 seconds without receiving data.
|
||||||
bool Downloader::get(const HTTP::URL &link, uint8_t maxRecursiveDepth){
|
bool Downloader::get(const HTTP::URL &link, uint8_t maxRecursiveDepth){
|
||||||
if (!link.host.size()){return false;}
|
if (!canRequest(link)){return false;}
|
||||||
if (link.protocol != "http"){
|
|
||||||
FAIL_MSG("Protocol not supported: %s", link.protocol.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
unsigned int loop = 6; // max 5 attempts
|
unsigned int loop = 6; // max 5 attempts
|
||||||
while (--loop){// loop while we are unsuccessful
|
while (--loop){// loop while we are unsuccessful
|
||||||
doRequest(link);
|
doRequest(link);
|
||||||
uint64_t reqTime = Util::bootSecs();
|
uint64_t reqTime = Util::bootSecs();
|
||||||
while (S && Util::bootSecs() < reqTime + 5){
|
while (getSocket() && Util::bootSecs() < reqTime + 5){
|
||||||
// No data? Wait for a second or so.
|
// No data? Wait for a second or so.
|
||||||
if (!S.spool()){
|
if (!getSocket().spool()){
|
||||||
if (progressCallback != 0){
|
if (progressCallback != 0){
|
||||||
if (!progressCallback()){
|
if (!progressCallback()){
|
||||||
WARN_MSG("Download aborted by callback");
|
WARN_MSG("Download aborted by callback");
|
||||||
|
@ -101,16 +154,17 @@ namespace HTTP{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// Data! Check if we can parse it...
|
// Data! Check if we can parse it...
|
||||||
if (H.Read(S)){
|
if (H.Read(getSocket())){
|
||||||
if (getStatusCode() >= 300 && getStatusCode() < 400){
|
if (shouldContinue()){
|
||||||
// follow redirect
|
|
||||||
std::string location = getHeader("Location");
|
|
||||||
if (maxRecursiveDepth == 0){
|
if (maxRecursiveDepth == 0){
|
||||||
FAIL_MSG("Maximum redirect depth reached: %s", location.c_str());
|
FAIL_MSG("Maximum recursion depth reached");
|
||||||
return false;
|
return false;
|
||||||
|
}
|
||||||
|
if (!canContinue(link)){return false;}
|
||||||
|
if (getStatusCode() >= 300 && getStatusCode() < 400){
|
||||||
|
return get(link.link(getHeader("Location")), --maxRecursiveDepth);
|
||||||
}else{
|
}else{
|
||||||
FAIL_MSG("Following redirect to %s", location.c_str());
|
return get(link, --maxRecursiveDepth);
|
||||||
return get(link.link(location), maxRecursiveDepth--);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true; // Success!
|
return true; // Success!
|
||||||
|
@ -118,7 +172,7 @@ namespace HTTP{
|
||||||
// reset the 5 second timeout
|
// reset the 5 second timeout
|
||||||
reqTime = Util::bootSecs();
|
reqTime = Util::bootSecs();
|
||||||
}
|
}
|
||||||
if (S){
|
if (getSocket()){
|
||||||
FAIL_MSG("Timeout while retrieving %s", link.getUrl().c_str());
|
FAIL_MSG("Timeout while retrieving %s", link.getUrl().c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -127,5 +181,124 @@ namespace HTTP{
|
||||||
FAIL_MSG("Could not retrieve %s", link.getUrl().c_str());
|
FAIL_MSG("Could not retrieve %s", link.getUrl().c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
bool Downloader::post(const HTTP::URL &link, const std::string &payload, bool sync, uint8_t maxRecursiveDepth){
|
||||||
|
if (!canRequest(link)){return false;}
|
||||||
|
unsigned int loop = 6; // max 5 attempts
|
||||||
|
while (--loop){// loop while we are unsuccessful
|
||||||
|
doRequest(link, "POST", payload);
|
||||||
|
//Not synced? Ignore the response and immediately return false.
|
||||||
|
if (!sync){return false;}
|
||||||
|
uint64_t reqTime = Util::bootSecs();
|
||||||
|
while (getSocket() && Util::bootSecs() < reqTime + 5){
|
||||||
|
// No data? Wait for a second or so.
|
||||||
|
if (!getSocket().spool()){
|
||||||
|
if (progressCallback != 0){
|
||||||
|
if (!progressCallback()){
|
||||||
|
WARN_MSG("Download aborted by callback");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Util::sleep(250);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Data! Check if we can parse it...
|
||||||
|
if (H.Read(getSocket())){
|
||||||
|
if (shouldContinue()){
|
||||||
|
if (maxRecursiveDepth == 0){
|
||||||
|
FAIL_MSG("Maximum recursion depth reached");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!canContinue(link)){return false;}
|
||||||
|
if (getStatusCode() >= 300 && getStatusCode() < 400){
|
||||||
|
return post(link.link(getHeader("Location")), payload, sync, --maxRecursiveDepth);
|
||||||
|
}else{
|
||||||
|
return post(link, payload, sync, --maxRecursiveDepth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true; // Success!
|
||||||
|
}
|
||||||
|
// reset the 5 second timeout
|
||||||
|
reqTime = Util::bootSecs();
|
||||||
|
}
|
||||||
|
if (getSocket()){
|
||||||
|
FAIL_MSG("Timeout while retrieving %s", link.getUrl().c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Util::sleep(500); // wait a bit before retrying
|
||||||
|
}
|
||||||
|
FAIL_MSG("Could not retrieve %s", link.getUrl().c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Downloader::canRequest(const HTTP::URL &link){
|
||||||
|
if (!link.host.size()){return false;}
|
||||||
|
if (link.protocol != "http" && link.protocol != "https"){
|
||||||
|
FAIL_MSG("Protocol not supported: %s", link.protocol.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#ifndef SSL
|
||||||
|
if (link.protocol == "https"){
|
||||||
|
FAIL_MSG("Protocol not supported: %s", link.protocol.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Downloader::shouldContinue(){
|
||||||
|
if (H.hasHeader("Set-Cookie")){
|
||||||
|
std::string cookie = H.GetHeader("Set-Cookie");
|
||||||
|
setHeader("Cookie", cookie.substr(0, cookie.find(';')));
|
||||||
|
}
|
||||||
|
uint32_t sCode = getStatusCode();
|
||||||
|
if (sCode == 401 || sCode == 407 || (sCode >= 300 && sCode < 400)){
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Downloader::canContinue(const HTTP::URL &link){
|
||||||
|
if (getStatusCode() == 401){
|
||||||
|
// retry with authentication
|
||||||
|
if (H.hasHeader("WWW-Authenticate")){authStr = H.GetHeader("WWW-Authenticate");}
|
||||||
|
if (H.hasHeader("Www-Authenticate")){authStr = H.GetHeader("Www-Authenticate");}
|
||||||
|
if (!authStr.size()){
|
||||||
|
FAIL_MSG("Authentication required but no WWW-Authenticate header present");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!link.user.size() && !link.pass.size()){
|
||||||
|
FAIL_MSG("Authentication required but not included in URL");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
FAIL_MSG("Authenticating...");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (getStatusCode() == 407){
|
||||||
|
// retry with authentication
|
||||||
|
if (H.hasHeader("Proxy-Authenticate")){
|
||||||
|
proxyAuthStr = H.GetHeader("Proxy-Authenticate");
|
||||||
|
}
|
||||||
|
if (!proxyAuthStr.size()){
|
||||||
|
FAIL_MSG("Proxy authentication required but no Proxy-Authenticate header present");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!proxyUrl.user.size() && !proxyUrl.pass.size()){
|
||||||
|
FAIL_MSG("Proxy authentication required but not included in URL");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
FAIL_MSG("Authenticating proxy...");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (getStatusCode() >= 300 && getStatusCode() < 400){
|
||||||
|
// follow redirect
|
||||||
|
std::string location = getHeader("Location");
|
||||||
|
if (!location.size()){return false;}
|
||||||
|
INFO_MSG("Following redirect to %s", location.c_str());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
}// namespace HTTP
|
||||||
|
|
||||||
|
|
|
@ -4,18 +4,22 @@
|
||||||
namespace HTTP{
|
namespace HTTP{
|
||||||
class Downloader{
|
class Downloader{
|
||||||
public:
|
public:
|
||||||
Downloader(){progressCallback = 0;}
|
Downloader();
|
||||||
std::string &data();
|
std::string &data();
|
||||||
void doRequest(const HTTP::URL &link);
|
void doRequest(const HTTP::URL &link, const std::string &method="", const std::string &body="");
|
||||||
bool get(const std::string &link);
|
bool get(const std::string &link);
|
||||||
bool get(const HTTP::URL &link, uint8_t maxRecursiveDepth = 6);
|
bool get(const HTTP::URL &link, uint8_t maxRecursiveDepth = 6);
|
||||||
|
bool post(const HTTP::URL &link, const std::string &payload, bool sync = true, uint8_t maxRecursiveDepth = 6);
|
||||||
std::string getHeader(const std::string &headerName);
|
std::string getHeader(const std::string &headerName);
|
||||||
std::string &getStatusText();
|
std::string &getStatusText();
|
||||||
uint32_t getStatusCode();
|
uint32_t getStatusCode();
|
||||||
bool isOk();
|
bool isOk(); ///< True if the request was successful.
|
||||||
|
bool shouldContinue(); ///<True if the request should be followed-up with another. E.g. redirect or authenticate.
|
||||||
|
bool canContinue(const HTTP::URL &link);///<True if the request is able to continue, false if there is a state error or some such.
|
||||||
bool (*progressCallback)(); ///< Called every time the socket stalls, up to 4X per second.
|
bool (*progressCallback)(); ///< Called every time the socket stalls, up to 4X per second.
|
||||||
void setHeader(const std::string &name, const std::string &val);
|
void setHeader(const std::string &name, const std::string &val);
|
||||||
void clearHeaders();
|
void clearHeaders();
|
||||||
|
bool canRequest(const HTTP::URL &link);
|
||||||
Parser &getHTTP();
|
Parser &getHTTP();
|
||||||
Socket::Connection &getSocket();
|
Socket::Connection &getSocket();
|
||||||
|
|
||||||
|
@ -25,6 +29,14 @@ namespace HTTP{
|
||||||
uint32_t connectedPort; ///< Currently connected port number
|
uint32_t connectedPort; ///< Currently connected port number
|
||||||
Parser H; ///< HTTP parser for downloader
|
Parser H; ///< HTTP parser for downloader
|
||||||
Socket::Connection S; ///< TCP socket for downloader
|
Socket::Connection S; ///< TCP socket for downloader
|
||||||
|
#ifdef SSL
|
||||||
|
Socket::SSLConnection S_SSL; ///< SSL socket for downloader
|
||||||
|
#endif
|
||||||
|
bool ssl; ///< True if ssl is currently in use.
|
||||||
|
std::string authStr; ///< Most recently seen WWW-Authenticate request
|
||||||
|
std::string proxyAuthStr; ///< Most recently seen Proxy-Authenticate request
|
||||||
|
bool proxied; ///< True if proxy server is configured.
|
||||||
|
HTTP::URL proxyUrl; ///< Set to the URL of the configured proxy.
|
||||||
};
|
};
|
||||||
}
|
}// namespace HTTP
|
||||||
|
|
||||||
|
|
201
lib/socket.cpp
201
lib/socket.cpp
|
@ -811,6 +811,207 @@ bool Socket::Connection::isLocal(){
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef SSL
|
||||||
|
Socket::SSLConnection::SSLConnection() : Socket::Connection::Connection(){
|
||||||
|
isConnected = false;
|
||||||
|
server_fd = 0;
|
||||||
|
ssl = 0;
|
||||||
|
conf = 0;
|
||||||
|
ctr_drbg = 0;
|
||||||
|
entropy = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void my_debug(void *ctx, int level, const char *file, int line, const char *str){
|
||||||
|
((void)level);
|
||||||
|
fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str);
|
||||||
|
fflush((FILE *)ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
Socket::SSLConnection::SSLConnection(std::string hostname, int port, bool nonblock) : Socket::Connection(){
|
||||||
|
mbedtls_debug_set_threshold(0);
|
||||||
|
isConnected = true;
|
||||||
|
server_fd = new mbedtls_net_context;
|
||||||
|
ssl = new mbedtls_ssl_context;
|
||||||
|
conf = new mbedtls_ssl_config;
|
||||||
|
ctr_drbg = new mbedtls_ctr_drbg_context;
|
||||||
|
entropy = new mbedtls_entropy_context;
|
||||||
|
mbedtls_net_init(server_fd);
|
||||||
|
mbedtls_ssl_init(ssl);
|
||||||
|
mbedtls_ssl_config_init(conf);
|
||||||
|
mbedtls_ctr_drbg_init(ctr_drbg);
|
||||||
|
mbedtls_entropy_init(entropy);
|
||||||
|
DONTEVEN_MSG("SSL init");
|
||||||
|
if (mbedtls_ctr_drbg_seed(ctr_drbg, mbedtls_entropy_func, entropy, (const unsigned char*)"meow", 4) != 0){
|
||||||
|
FAIL_MSG("SSL socket init failed");
|
||||||
|
close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
DONTEVEN_MSG("SSL connect");
|
||||||
|
int ret = 0;
|
||||||
|
if ((ret = mbedtls_net_connect(server_fd, hostname.c_str(), JSON::Value((long long)port).asString().c_str(), MBEDTLS_NET_PROTO_TCP)) != 0){
|
||||||
|
FAIL_MSG(" failed\n ! mbedtls_net_connect returned %d\n\n", ret);
|
||||||
|
close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if ((ret = mbedtls_ssl_config_defaults(conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM,
|
||||||
|
MBEDTLS_SSL_PRESET_DEFAULT)) != 0){
|
||||||
|
FAIL_MSG(" failed\n ! mbedtls_ssl_config_defaults returned %d\n\n", ret);
|
||||||
|
close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
mbedtls_ssl_conf_authmode(conf, MBEDTLS_SSL_VERIFY_NONE);
|
||||||
|
mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, ctr_drbg);
|
||||||
|
mbedtls_ssl_conf_dbg(conf, my_debug, stderr );
|
||||||
|
if ((ret = mbedtls_ssl_setup(ssl, conf)) != 0){
|
||||||
|
char estr[200];
|
||||||
|
mbedtls_strerror(ret, estr, 200);
|
||||||
|
FAIL_MSG("SSL setup error %d: %s", ret, estr);
|
||||||
|
close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if ((ret = mbedtls_ssl_set_hostname(ssl, hostname.c_str())) != 0){
|
||||||
|
FAIL_MSG(" failed\n ! mbedtls_ssl_set_hostname returned %d\n\n", ret);
|
||||||
|
close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
mbedtls_ssl_set_bio(ssl, server_fd, mbedtls_net_send, mbedtls_net_recv, NULL);
|
||||||
|
while ((ret = mbedtls_ssl_handshake(ssl)) != 0){
|
||||||
|
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE){
|
||||||
|
char estr[200];
|
||||||
|
mbedtls_strerror(ret, estr, 200);
|
||||||
|
FAIL_MSG("SSL handshake error %d: %s", ret, estr);
|
||||||
|
close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Blocking = true;
|
||||||
|
if (nonblock){
|
||||||
|
setBlocking(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Socket::SSLConnection::close(){
|
||||||
|
DONTEVEN_MSG("SSL close");
|
||||||
|
if (server_fd){
|
||||||
|
mbedtls_net_free(server_fd);
|
||||||
|
delete server_fd;
|
||||||
|
server_fd = 0;
|
||||||
|
}
|
||||||
|
if (ssl){
|
||||||
|
mbedtls_ssl_free(ssl);
|
||||||
|
delete ssl;
|
||||||
|
ssl = 0;
|
||||||
|
}
|
||||||
|
if (conf){
|
||||||
|
mbedtls_ssl_config_free(conf);
|
||||||
|
delete conf;
|
||||||
|
conf = 0;
|
||||||
|
}
|
||||||
|
if (ctr_drbg){
|
||||||
|
mbedtls_ctr_drbg_free(ctr_drbg);
|
||||||
|
delete ctr_drbg;
|
||||||
|
ctr_drbg = 0;
|
||||||
|
}
|
||||||
|
if (entropy){
|
||||||
|
mbedtls_entropy_free(entropy);
|
||||||
|
delete entropy;
|
||||||
|
entropy = 0;
|
||||||
|
}
|
||||||
|
isConnected = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Incremental read call. This function tries to read len bytes to the buffer from the socket,
|
||||||
|
/// returning the amount of bytes it actually read.
|
||||||
|
/// \param buffer Location of the buffer to read to.
|
||||||
|
/// \param len Amount of bytes to read.
|
||||||
|
/// \param flags Flags to use in the recv call. Ignored on fake sockets.
|
||||||
|
/// \returns The amount of bytes actually read.
|
||||||
|
int Socket::SSLConnection::iread(void *buffer, int len, int flags){
|
||||||
|
DONTEVEN_MSG("SSL iread");
|
||||||
|
if (!connected() || len < 1){return 0;}
|
||||||
|
int r;
|
||||||
|
/// \TODO Flags ignored... Bad.
|
||||||
|
r = mbedtls_ssl_read(ssl, (unsigned char*)buffer, len);
|
||||||
|
if (r < 0){
|
||||||
|
char estr[200];
|
||||||
|
mbedtls_strerror(r, estr, 200);
|
||||||
|
INFO_MSG("Read returns %d: %s", r, estr);
|
||||||
|
}
|
||||||
|
if (r < 0){
|
||||||
|
switch (errno){
|
||||||
|
case MBEDTLS_ERR_SSL_WANT_WRITE: return 0; break;
|
||||||
|
case MBEDTLS_ERR_SSL_WANT_READ: return 0; break;
|
||||||
|
case EWOULDBLOCK: return 0; break;
|
||||||
|
case EINTR: return 0; break;
|
||||||
|
default:
|
||||||
|
Error = true;
|
||||||
|
INSANE_MSG("Could not iread data! Error: %s", strerror(errno));
|
||||||
|
close();
|
||||||
|
return 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (r == 0){
|
||||||
|
DONTEVEN_MSG("Socket closed by remote");
|
||||||
|
close();
|
||||||
|
}
|
||||||
|
down += r;
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Incremental write call. This function tries to write len bytes to the socket from the buffer,
|
||||||
|
/// returning the amount of bytes it actually wrote.
|
||||||
|
/// \param buffer Location of the buffer to write from.
|
||||||
|
/// \param len Amount of bytes to write.
|
||||||
|
/// \returns The amount of bytes actually written.
|
||||||
|
unsigned int Socket::SSLConnection::iwrite(const void *buffer, int len){
|
||||||
|
DONTEVEN_MSG("SSL iwrite");
|
||||||
|
if (!connected() || len < 1){return 0;}
|
||||||
|
int r;
|
||||||
|
r = mbedtls_ssl_write(ssl, (const unsigned char*)buffer, len);
|
||||||
|
if (r < 0){
|
||||||
|
char estr[200];
|
||||||
|
mbedtls_strerror(r, estr, 200);
|
||||||
|
INFO_MSG("Write returns %d: %s", r, estr);
|
||||||
|
}
|
||||||
|
if (r < 0){
|
||||||
|
switch (errno){
|
||||||
|
case MBEDTLS_ERR_SSL_WANT_WRITE: return 0; break;
|
||||||
|
case MBEDTLS_ERR_SSL_WANT_READ: return 0; break;
|
||||||
|
case EWOULDBLOCK: return 0; break;
|
||||||
|
default:
|
||||||
|
Error = true;
|
||||||
|
INSANE_MSG("Could not iwrite data! Error: %s", strerror(errno));
|
||||||
|
close();
|
||||||
|
return 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (r == 0 && (sock >= 0)){
|
||||||
|
DONTEVEN_MSG("Socket closed by remote");
|
||||||
|
close();
|
||||||
|
}
|
||||||
|
up += r;
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Socket::SSLConnection::connected() const{
|
||||||
|
return isConnected;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Socket::SSLConnection::setBlocking(bool blocking){
|
||||||
|
if (blocking != Blocking){return;}
|
||||||
|
if (blocking){
|
||||||
|
mbedtls_net_set_block(server_fd);
|
||||||
|
Blocking = true;
|
||||||
|
}else{
|
||||||
|
mbedtls_net_set_nonblock(server_fd);
|
||||||
|
Blocking = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
/// Create a new base Server. The socket is never connected, and a placeholder for later connections.
|
/// Create a new base Server. The socket is never connected, and a placeholder for later connections.
|
||||||
Socket::Server::Server(){
|
Socket::Server::Server(){
|
||||||
sock = -1;
|
sock = -1;
|
||||||
|
|
42
lib/socket.h
42
lib/socket.h
|
@ -16,6 +16,15 @@
|
||||||
#include <sys/un.h>
|
#include <sys/un.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#ifdef SSL
|
||||||
|
#include "mbedtls/net.h"
|
||||||
|
#include "mbedtls/ssl.h"
|
||||||
|
#include "mbedtls/entropy.h"
|
||||||
|
#include "mbedtls/ctr_drbg.h"
|
||||||
|
#include "mbedtls/debug.h"
|
||||||
|
#include "mbedtls/error.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
// for being friendly with Socket::Connection down below
|
// for being friendly with Socket::Connection down below
|
||||||
namespace Buffer{
|
namespace Buffer{
|
||||||
class user;
|
class user;
|
||||||
|
@ -54,7 +63,7 @@ namespace Socket{
|
||||||
|
|
||||||
/// This class is for easy communicating through sockets, either TCP or Unix.
|
/// This class is for easy communicating through sockets, either TCP or Unix.
|
||||||
class Connection{
|
class Connection{
|
||||||
private:
|
protected:
|
||||||
int sock; ///< Internally saved socket number.
|
int sock; ///< Internally saved socket number.
|
||||||
int pipes[2]; ///< Internally saved file descriptors for pipe socket simulation.
|
int pipes[2]; ///< Internally saved file descriptors for pipe socket simulation.
|
||||||
std::string remotehost; ///< Stores remote host address.
|
std::string remotehost; ///< Stores remote host address.
|
||||||
|
@ -63,8 +72,8 @@ namespace Socket{
|
||||||
uint64_t down;
|
uint64_t down;
|
||||||
long long int conntime;
|
long long int conntime;
|
||||||
Buffer downbuffer; ///< Stores temporary data coming in.
|
Buffer downbuffer; ///< Stores temporary data coming in.
|
||||||
int iread(void *buffer, int len, int flags = 0); ///< Incremental read call.
|
virtual int iread(void *buffer, int len, int flags = 0); ///< Incremental read call.
|
||||||
unsigned int iwrite(const void *buffer, int len); ///< Incremental write call.
|
virtual unsigned int iwrite(const void *buffer, int len); ///< Incremental write call.
|
||||||
bool iread(Buffer &buffer, int flags = 0); ///< Incremental write call that is compatible with Socket::Buffer.
|
bool iread(Buffer &buffer, int flags = 0); ///< Incremental write call that is compatible with Socket::Buffer.
|
||||||
bool iwrite(std::string &buffer); ///< Write call that is compatible with std::string.
|
bool iwrite(std::string &buffer); ///< Write call that is compatible with std::string.
|
||||||
public:
|
public:
|
||||||
|
@ -77,9 +86,9 @@ namespace Socket{
|
||||||
Connection(std::string adres, bool nonblock = false); ///< Create a new Unix Socket.
|
Connection(std::string adres, bool nonblock = false); ///< Create a new Unix Socket.
|
||||||
Connection(int write, int read); ///< Simulate a socket using two file descriptors.
|
Connection(int write, int read); ///< Simulate a socket using two file descriptors.
|
||||||
// generic methods
|
// generic methods
|
||||||
void close(); ///< Close connection.
|
virtual void close(); ///< Close connection.
|
||||||
void drop(); ///< Close connection without shutdown.
|
void drop(); ///< Close connection without shutdown.
|
||||||
void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false).
|
virtual void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false).
|
||||||
bool isBlocking(); ///< Check if this socket is blocking (true) or nonblocking (false).
|
bool isBlocking(); ///< Check if this socket is blocking (true) or nonblocking (false).
|
||||||
std::string getHost() const; ///< Gets hostname for connection, if available.
|
std::string getHost() const; ///< Gets hostname for connection, if available.
|
||||||
std::string getBinHost();
|
std::string getBinHost();
|
||||||
|
@ -87,7 +96,7 @@ namespace Socket{
|
||||||
int getSocket(); ///< Returns internal socket number.
|
int getSocket(); ///< Returns internal socket number.
|
||||||
int getPureSocket(); ///< Returns non-piped internal socket number.
|
int getPureSocket(); ///< Returns non-piped internal socket number.
|
||||||
std::string getError(); ///< Returns a string describing the last error that occured.
|
std::string getError(); ///< Returns a string describing the last error that occured.
|
||||||
bool connected() const; ///< Returns the connected-state for this socket.
|
virtual bool connected() const; ///< Returns the connected-state for this socket.
|
||||||
bool isAddress(const std::string &addr);
|
bool isAddress(const std::string &addr);
|
||||||
bool isLocal(); ///< Returns true if remote address is a local address.
|
bool isLocal(); ///< Returns true if remote address is a local address.
|
||||||
// buffered i/o methods
|
// buffered i/o methods
|
||||||
|
@ -114,6 +123,27 @@ namespace Socket{
|
||||||
operator bool() const;
|
operator bool() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef SSL
|
||||||
|
/// Version of Socket::Connection that uses mbedtls for SSL
|
||||||
|
class SSLConnection : public Connection{
|
||||||
|
public:
|
||||||
|
SSLConnection();
|
||||||
|
SSLConnection(std::string hostname, int port, bool nonblock); ///< Create a new TCP socket.
|
||||||
|
void close(); ///< Close connection.
|
||||||
|
bool connected() const; ///< Returns the connected-state for this socket.
|
||||||
|
void setBlocking(bool blocking); ///< Set this socket to be blocking (true) or nonblocking (false).
|
||||||
|
protected:
|
||||||
|
bool isConnected;
|
||||||
|
int iread(void *buffer, int len, int flags = 0); ///< Incremental read call.
|
||||||
|
unsigned int iwrite(const void *buffer, int len); ///< Incremental write call.
|
||||||
|
mbedtls_net_context * server_fd;
|
||||||
|
mbedtls_entropy_context * entropy;
|
||||||
|
mbedtls_ctr_drbg_context * ctr_drbg;
|
||||||
|
mbedtls_ssl_context * ssl;
|
||||||
|
mbedtls_ssl_config * conf;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
/// This class is for easily setting up listening socket, either TCP or Unix.
|
/// This class is for easily setting up listening socket, either TCP or Unix.
|
||||||
class Server{
|
class Server{
|
||||||
private:
|
private:
|
||||||
|
|
Loading…
Add table
Reference in a new issue