mistserver/lib/downloader.cpp

#include "downloader.h"
#include "defines.h"
#include "encode.h"
#include "timing.h"
namespace HTTP{
  Downloader::Downloader(){
    progressCallback = 0;
    connectedPort = 0;
    dataTimeout = 5;
    retryCount = 5;
    ssl = false;
    proxied = false;
    char *p = getenv("http_proxy");
    if (p){
      proxyUrl = HTTP::URL(p);
      proxied = true;
      MEDIUM_MSG("Proxying through %s", proxyUrl.getUrl().c_str());
    }
  }

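  // Illustration of the proxy detection above (a sketch; address and program name are
  // placeholders): setting http_proxy in the environment before starting the process routes
  // plain-text HTTP requests through that proxy, while HTTPS requests still connect directly
  // (see doRequest() below).
  //
  //   http_proxy=http://127.0.0.1:3128 ./my_program
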
  /// Returns a reference to the internal HTTP::Parser body element
  std::string &Downloader::data(){return H.body;}

  /// Returns a const reference to the internal HTTP::Parser body element
  const std::string &Downloader::const_data() const{return H.body;}

  /// Returns the status text of the HTTP response (the Parser stores it in its method field
  /// when parsing a response).
  std::string &Downloader::getStatusText(){return H.method;}

  /// Returns the status code of the HTTP response (stored in the Parser's url field).
  uint32_t Downloader::getStatusCode(){return atoi(H.url.c_str());}

  /// Returns true if the HTTP response status code is 200 (OK).
  bool Downloader::isOk(){return (getStatusCode() == 200);}

  /// Returns the given header from the response, or an empty string if it does not exist.
  std::string Downloader::getHeader(const std::string &headerName){
    return H.GetHeader(headerName);
  }

  /// Simply turns link into an HTTP::URL and calls get(const HTTP::URL&)
  bool Downloader::get(const std::string &link){
    HTTP::URL uri(link);
    return get(uri);
  }

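  // Example usage (a sketch; the URL is a placeholder):
  //
  //   HTTP::Downloader d;
  //   if (d.get("http://example.com/index.m3u8") && d.isOk()){
  //     INFO_MSG("Got %zu bytes", d.data().size());
  //   }
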
  /// Sets an extra (or overridden) header to be sent with outgoing requests.
  void Downloader::setHeader(const std::string &name, const std::string &val){
    extraHeaders[name] = val;
  }

  /// Clears all extra/override headers for outgoing requests.
  void Downloader::clearHeaders(){extraHeaders.clear();}

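  // Example (a sketch; header name and value are placeholders): extra headers set here are
  // sent with every request until clearHeaders() is called.
  //
  //   HTTP::Downloader d;
  //   d.setHeader("X-Api-Key", "0123456789");
  //   d.get("http://example.com/api/status"); // request includes X-Api-Key
  //   d.clearHeaders();                       // later requests omit it again
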
  /// Returns a reference to the internal HTTP::Parser class instance.
  Parser &Downloader::getHTTP(){return H;}

  /// Returns a reference to the internal Socket::Connection class instance.
  Socket::Connection &Downloader::getSocket(){return S;}

  Downloader::~Downloader(){S.close();}

  /// Sends a request for the given URL, does no waiting.
  void Downloader::doRequest(const HTTP::URL &link, const std::string &method,
                             const std::string &body){
    if (!canRequest(link)){return;}
    bool needSSL = (link.protocol == "https");
    H.Clean();
    // Reconnect if needed
    if (!proxied || needSSL){
      if (!getSocket() || link.host != connectedHost || link.getPort() != connectedPort ||
          needSSL != ssl){
        getSocket().close();
        connectedHost = link.host;
        connectedPort = link.getPort();
#ifdef SSL
        if (needSSL){
          S.open(connectedHost, connectedPort, true, true);
        }else{
          S.open(connectedHost, connectedPort, true);
        }
#else
        S.open(connectedHost, connectedPort, true);
#endif
      }
    }else{
      if (!getSocket() || proxyUrl.host != connectedHost || proxyUrl.getPort() != connectedPort ||
          needSSL != ssl){
        getSocket().close();
        connectedHost = proxyUrl.host;
        connectedPort = proxyUrl.getPort();
        S.open(connectedHost, connectedPort, true);
      }
    }
    ssl = needSSL;
    if (!getSocket()){
      return; // socket is closed
    }
    if (proxied && !ssl){
      // Plain-text proxy: send the full URL in the request line
      H.url = link.getProxyUrl();
      if (link.port.size()){
        H.SetHeader("Host", link.host + ":" + link.port);
      }else{
        H.SetHeader("Host", link.host);
      }
    }else{
      // Direct connection: send only the encoded path and query string
      H.url = "/" + Encodings::URL::encode(link.path, "/:=@[]");
      if (link.args.size()){H.url += "?" + link.args;}
      if (link.port.size()){
        H.SetHeader("Host", link.host + ":" + link.port);
      }else{
        H.SetHeader("Host", link.host);
      }
    }
    if (method.size()){H.method = method;}
    H.SetHeader("User-Agent", "MistServer " PACKAGE_VERSION);
    H.SetHeader("X-Version", PACKAGE_VERSION);
    H.SetHeader("Accept", "*/*");
    if (authStr.size() && (link.user.size() || link.pass.size())){
      H.auth(link.user, link.pass, authStr);
    }
    if (proxyAuthStr.size() && (proxyUrl.user.size() || proxyUrl.pass.size())){
      H.auth(proxyUrl.user, proxyUrl.pass, proxyAuthStr, "Proxy-Authorization");
    }
    if (extraHeaders.size()){
      for (std::map<std::string, std::string>::iterator it = extraHeaders.begin();
           it != extraHeaders.end(); ++it){
        H.SetHeader(it->first, it->second);
      }
    }
    H.SendRequest(getSocket(), body);
    H.Clean();
  }

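  // A sketch of driving doRequest() by hand (get()/post() below do this for you). It assumes
  // doRequest() is publicly callable and uses a placeholder URL; error and timeout handling
  // are omitted for brevity.
  //
  //   HTTP::Downloader d;
  //   d.doRequest(HTTP::URL("http://example.com/stream.m3u8"));
  //   while (d.getSocket() && !d.getHTTP().Read(d.getSocket())){
  //     if (!d.getSocket().spool()){Util::sleep(100);}
  //   }
  //   // On success, d.data() now holds the response body.
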
  /// Downloads the given URL into 'H', returns true on success.
  /// Makes at most retryCount attempts (default 5), and gives up on an attempt after dataTimeout
  /// seconds (default 5) without receiving any data.
  bool Downloader::get(const HTTP::URL &link, uint8_t maxRecursiveDepth){
    if (!canRequest(link)){return false;}
    size_t loop = retryCount + 1; // retryCount attempts in total
    while (--loop){// loop while we are unsuccessful
      MEDIUM_MSG("Retrieving %s (%zu/%" PRIu32 ")", link.getUrl().c_str(), retryCount - loop + 1,
                 retryCount);
      doRequest(link);
      uint64_t reqTime = Util::bootSecs();
      while (getSocket() && Util::bootSecs() < reqTime + dataTimeout){
        // No data? Wait a quarter of a second and check again.
        if (!getSocket().spool()){
          if (progressCallback != 0){
            if (!progressCallback()){
              WARN_MSG("Download aborted by callback");
              return false;
            }
          }
          Util::sleep(250);
          continue;
        }
        // Data! Check if we can parse it...
        if (H.Read(getSocket())){
          if (shouldContinue()){
            if (maxRecursiveDepth == 0){
              FAIL_MSG("Maximum recursion depth reached");
              return false;
            }
            if (!canContinue(link)){return false;}
            if (getStatusCode() >= 300 && getStatusCode() < 400){
              return get(link.link(getHeader("Location")), --maxRecursiveDepth);
            }else{
              return get(link, --maxRecursiveDepth);
            }
          }
          return true; // Success!
        }
        // Partial data received: reset the data timeout and report progress.
        if (reqTime != Util::bootSecs()){
          if (progressCallback != 0){
            if (!progressCallback()){
              WARN_MSG("Download aborted by callback");
              return false;
            }
          }
          reqTime = Util::bootSecs();
        }
      }
      if (getSocket()){
        FAIL_MSG("Timeout while retrieving %s (%zu/%" PRIu32 ")", link.getUrl().c_str(),
                 retryCount - loop + 1, retryCount);
        getSocket().close();
      }else{
        if (retryCount - loop + 1 > 2){
          INFO_MSG("Lost connection while retrieving %s (%zu/%" PRIu32 ")", link.getUrl().c_str(),
                   retryCount - loop + 1, retryCount);
        }else{
          MEDIUM_MSG("Lost connection while retrieving %s (%zu/%" PRIu32 ")", link.getUrl().c_str(),
                     retryCount - loop + 1, retryCount);
        }
      }
      Util::sleep(500); // wait a bit before retrying
    }
    FAIL_MSG("Could not retrieve %s", link.getUrl().c_str());
    return false;
  }

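  // A sketch of wiring up the abort callback polled in the loop above. It assumes
  // progressCallback may be assigned directly (as this class does in its own constructor);
  // wantAbort and keepDownloading are hypothetical names.
  //
  //   static bool wantAbort = false;
  //   static bool keepDownloading(){return !wantAbort;}
  //
  //   HTTP::Downloader d;
  //   d.progressCallback = keepDownloading; // get()/post() abort when this returns false
  //   d.get(HTTP::URL("http://example.com/big_file.ts"));
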
  /// Sends the given payload to the given URL in a POST request.
  /// When sync is false, returns true as soon as the request has been sent, without waiting for
  /// a response. Otherwise behaves like get(): at most retryCount attempts, each giving up after
  /// dataTimeout seconds without data.
  bool Downloader::post(const HTTP::URL &link, const std::string &payload, bool sync,
                        uint8_t maxRecursiveDepth){
    if (!canRequest(link)){return false;}
    size_t loop = retryCount + 1; // retryCount attempts in total, matching get()
    while (--loop){// loop while we are unsuccessful
      MEDIUM_MSG("Posting to %s (%zu/%" PRIu32 ")", link.getUrl().c_str(), retryCount - loop + 1,
                 retryCount);
      doRequest(link, "POST", payload);
      // Not synced? Ignore the response and immediately return true.
      if (!sync){return true;}
      uint64_t reqTime = Util::bootSecs();
      while (getSocket() && Util::bootSecs() < reqTime + dataTimeout){
        // No data? Wait a quarter of a second and check again.
        if (!getSocket().spool()){
          if (progressCallback != 0){
            if (!progressCallback()){
              WARN_MSG("Download aborted by callback");
              return false;
            }
          }
          Util::sleep(250);
          continue;
        }
        // Data! Check if we can parse it...
        if (H.Read(getSocket())){
          if (shouldContinue()){
            if (maxRecursiveDepth == 0){
              FAIL_MSG("Maximum recursion depth reached");
              return false;
            }
            if (!canContinue(link)){return false;}
            if (getStatusCode() >= 300 && getStatusCode() < 400){
              return post(link.link(getHeader("Location")), payload, sync, --maxRecursiveDepth);
            }else{
              return post(link, payload, sync, --maxRecursiveDepth);
            }
          }
          return true; // Success!
        }
        // Partial data received: reset the data timeout and report progress.
        if (reqTime != Util::bootSecs()){
          if (progressCallback != 0){
            if (!progressCallback()){
              WARN_MSG("Download aborted by callback");
              return false;
            }
          }
          reqTime = Util::bootSecs();
        }
      }
      if (getSocket()){
        FAIL_MSG("Timeout while posting to %s", link.getUrl().c_str());
        return false;
      }
      Util::sleep(500); // wait a bit before retrying
    }
    FAIL_MSG("Could not post to %s", link.getUrl().c_str());
    return false;
  }

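  // Example usage of post() (a sketch; the endpoint and JSON payload are placeholders):
  //
  //   HTTP::Downloader d;
  //   d.setHeader("Content-Type", "application/json");
  //   // Fire-and-forget: with sync set to false, this returns right after sending.
  //   d.post(HTTP::URL("http://example.com/api/event"), "{\"event\":\"start\"}", false);
  //   // Synchronous: wait for the response and inspect it.
  //   if (d.post(HTTP::URL("http://example.com/api/event"), "{\"event\":\"stop\"}", true) && d.isOk()){
  //     INFO_MSG("Event accepted: %s", d.data().c_str());
  //   }
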
  /// Returns true if the given URL can be requested: it needs a host and an http(s) protocol,
  /// and https is only allowed when SSL support was compiled in.
  bool Downloader::canRequest(const HTTP::URL &link){
    if (!link.host.size()){return false;}
    if (link.protocol != "http" && link.protocol != "https"){
      FAIL_MSG("Protocol not supported: %s", link.protocol.c_str());
      return false;
    }
#ifndef SSL
    if (link.protocol == "https"){
      FAIL_MSG("HTTPS not supported: SSL support was not compiled in");
      return false;
    }
#endif
    return true;
  }

  /// Returns true when the last response requires a follow-up request: authentication
  /// (401/407) or a redirect (3xx). Also stores any Set-Cookie value so it is sent back on
  /// subsequent requests.
  bool Downloader::shouldContinue(){
    if (H.hasHeader("Set-Cookie")){
      std::string cookie = H.GetHeader("Set-Cookie");
      setHeader("Cookie", cookie.substr(0, cookie.find(';')));
    }
    uint32_t sCode = getStatusCode();
    return (sCode == 401 || sCode == 407 || (sCode >= 300 && sCode < 400));
  }

  /// Returns true if the follow-up request indicated by shouldContinue() can actually be made:
  /// credentials are available for 401/407 responses, or a Location header is present for
  /// redirects. Stores the (proxy) authentication challenge for the retry.
  bool Downloader::canContinue(const HTTP::URL &link){
    if (getStatusCode() == 401){
      // retry with authentication
      if (H.hasHeader("WWW-Authenticate")){authStr = H.GetHeader("WWW-Authenticate");}
      if (H.hasHeader("Www-Authenticate")){authStr = H.GetHeader("Www-Authenticate");}
      if (!authStr.size()){
        FAIL_MSG("Authentication required but no WWW-Authenticate header present");
        return false;
      }
      if (!link.user.size() && !link.pass.size()){
        FAIL_MSG("Authentication required but no credentials included in URL");
        return false;
      }
      INFO_MSG("Authenticating...");
      return true;
    }
    if (getStatusCode() == 407){
      // retry with proxy authentication
      if (H.hasHeader("Proxy-Authenticate")){proxyAuthStr = H.GetHeader("Proxy-Authenticate");}
      if (!proxyAuthStr.size()){
        FAIL_MSG("Proxy authentication required but no Proxy-Authenticate header present");
        return false;
      }
      if (!proxyUrl.user.size() && !proxyUrl.pass.size()){
        FAIL_MSG("Proxy authentication required but no credentials included in proxy URL");
        return false;
      }
      INFO_MSG("Authenticating proxy...");
      return true;
    }
    if (getStatusCode() >= 300 && getStatusCode() < 400){
      // follow redirect
      std::string location = getHeader("Location");
      if (!location.size()){return false;}
      INFO_MSG("Following redirect to %s", location.c_str());
      return true;
    }
    return false;
  }

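  // Example of the authentication retry path (a sketch; host and credentials are placeholders):
  // when a 401 arrives, the WWW-Authenticate challenge is stored above and get() retries once
  // with the credentials embedded in the URL.
  //
  //   HTTP::Downloader d;
  //   d.get(HTTP::URL("http://user:secret@example.com/protected/playlist.m3u8"));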
}// namespace HTTP