diff --git a/lib/shared_memory.cpp b/lib/shared_memory.cpp index 242f54e6..7a42abbc 100644 --- a/lib/shared_memory.cpp +++ b/lib/shared_memory.cpp @@ -846,6 +846,45 @@ namespace IPC { } } + ///Returns a pointer to the data for the given index. + ///Returns null on error or if index is empty. + char * sharedServer::getIndex(unsigned int requestId){ + char * empty = 0; + if (!hasCounter) { + empty = (char *)malloc(payLen * sizeof(char)); + memset(empty, 0, payLen); + } + semGuard tmpGuard(&mySemaphore); + unsigned int id = 0; + for (std::set::iterator it = myPages.begin(); it != myPages.end(); it++) { + if (!it->mapped || !it->len) { + DEBUG_MSG(DLVL_FAIL, "Something went terribly wrong?"); + return 0; + } + unsigned int offset = 0; + while (offset + payLen + (hasCounter ? 1 : 0) <= it->len) { + if (id == requestId){ + if (hasCounter) { + if (it->mapped[offset] != 0) { + return it->mapped + offset + 1; + }else{ + return 0; + } + } else { + if (memcmp(empty, it->mapped + offset, payLen)) { + return it->mapped + offset; + }else{ + return 0; + } + } + } + offset += payLen + (hasCounter ? 1 : 0); + id ++; + } + } + return 0; + } + ///\brief Parse each of the possible payload pieces, and runs a callback on it if in use. void sharedServer::parseEach(void (*callback)(char * data, size_t len, unsigned int id)) { char * empty = 0; diff --git a/lib/shared_memory.h b/lib/shared_memory.h index 1f350575..bc8906f2 100644 --- a/lib/shared_memory.h +++ b/lib/shared_memory.h @@ -181,6 +181,7 @@ namespace IPC { void init(std::string name, int len, bool withCounter = false); ~sharedServer(); void parseEach(void (*callback)(char * data, size_t len, unsigned int id)); + char * getIndex(unsigned int id); operator bool() const; ///\brief The amount of connected clients unsigned int amount; diff --git a/src/controller/controller_api.cpp b/src/controller/controller_api.cpp index 7ece9a50..12ed9ac6 100644 --- a/src/controller/controller_api.cpp +++ b/src/controller/controller_api.cpp @@ -570,6 +570,16 @@ int Controller::handleAPIConnection(Socket::Connection & conn){ Controller::fillActive(Request["stats_streams"], Response["stats_streams"]); } + if (Request.isMember("invalidate_sessions")){ + if (Request["totals"].isArray()){ + for (unsigned int i = 0; i < Request["invalidate_sessions"].size(); ++i){ + Controller::sessions_invalidate(Request["invalidate_sessions"][i].asStringRef()); + } + }else{ + Controller::sessions_invalidate(Request["invalidate_sessions"].asStringRef()); + } + } + if (Request.isMember("push_start")){ std::string stream; diff --git a/src/controller/controller_statistics.cpp b/src/controller/controller_statistics.cpp index 8ad7bdec..fa76613d 100644 --- a/src/controller/controller_statistics.cpp +++ b/src/controller/controller_statistics.cpp @@ -135,6 +135,36 @@ void Controller::streamStopped(std::string stream){ INFO_MSG("Stream %s became inactive", stream.c_str()); } +/// \todo Make this prettier. +IPC::sharedServer * statPointer = 0; + +///Invalidates all current sessions for the given streamname +void Controller::sessions_invalidate(const std::string & streamname){ + if (!statPointer){ + FAIL_MSG("In shutdown procedure - cannot invalidate sessions."); + return; + } + unsigned int invalidated = 0; + unsigned int sessCount = 0; + tthread::lock_guard guard(statsMutex); + for (std::map::iterator it = sessions.begin(); it != sessions.end(); it++){ + if (it->first.streamName == streamname){ + sessCount++; + it->second.sync = 1; + if (it->second.curConns.size()){ + for (std::map::iterator jt = it->second.curConns.begin(); jt != it->second.curConns.end(); ++jt){ + char * data = statPointer->getIndex(jt->first); + if (data){ + IPC::statExchange tmpEx(data); + tmpEx.setSync(2); + invalidated++; + } + } + } + } + } + INFO_MSG("Invalidated %u connections in %u sessions for stream %s", invalidated, sessCount, streamname.c_str()); +} /// This function runs as a thread and roughly once per second retrieves /// statistics from all connected clients, as well as wipes @@ -142,6 +172,7 @@ void Controller::streamStopped(std::string stream){ void Controller::SharedMemStats(void * config){ DEBUG_MSG(DLVL_HIGH, "Starting stats thread"); IPC::sharedServer statServer(SHM_STATISTICS, STAT_EX_SIZE, true); + statPointer = &statServer; std::set inactiveStreams; while(((Util::Config*)config)->is_active){ { @@ -180,6 +211,7 @@ void Controller::SharedMemStats(void * config){ } Util::wait(1000); } + statPointer = 0; DEBUG_MSG(DLVL_HIGH, "Stopping stats thread"); if (Controller::killOnExit){ DEBUG_MSG(DLVL_WARN, "Killing all connected clients to force full shutdown"); @@ -193,16 +225,12 @@ void Controller::SharedMemStats(void * config){ /// Updates the given active connection with new stats data. void Controller::statSession::update(unsigned long index, IPC::statExchange & data){ - //update the sync byte: 0 = requesting fill, 1 = needs checking, > 1 = state known (100=denied, 10=accepted) + //update the sync byte: 0 = requesting fill, 2 = requesting refill, 1 = needs checking, > 1 = state known (100=denied, 10=accepted) if (!data.getSync()){ - std::string myHost; - { - sessIndex tmpidx(data); - myHost = tmpidx.host; - } - MEDIUM_MSG("Setting sync to %u for %s, %s, %s, %lu", sync, data.streamName().c_str(), data.connector().c_str(), myHost.c_str(), data.crc() & 0xFFFFFFFFu); + sessIndex tmpidx(data); + std::string myHost = tmpidx.host; //if we have a maximum connection count per IP, enforce it - if (maxConnsPerIP){ + if (maxConnsPerIP && !data.getSync()){ unsigned int currConns = 1; long long shortly = Util::epoch(); for (std::map::iterator it = sessions.begin(); it != sessions.end(); it++){ @@ -212,15 +240,23 @@ void Controller::statSession::update(unsigned long index, IPC::statExchange & da if (currConns > maxConnsPerIP){ WARN_MSG("Disconnecting session from %s: exceeds max connection count of %u", myHost.c_str(), maxConnsPerIP); data.setSync(100); - }else{ + } + } + if (data.getSync() != 100){ + //only set the sync if this is the first connection in the list + //we also catch the case that there are no connections, which is an error-state + if (!sessions[tmpidx].curConns.size() || sessions[tmpidx].curConns.begin()->first == index){ + MEDIUM_MSG("Requesting sync to %u for %s, %s, %s, %lu", sync, data.streamName().c_str(), data.connector().c_str(), myHost.c_str(), data.crc() & 0xFFFFFFFFu); + data.setSync(sync); + } + //and, always set the sync if it is > 2 + if (sync > 2){ + MEDIUM_MSG("Setting sync to %u for %s, %s, %s, %lu", sync, data.streamName().c_str(), data.connector().c_str(), myHost.c_str(), data.crc() & 0xFFFFFFFFu); data.setSync(sync); } - }else{ - //no maximum, just set the sync byte to its current value - data.setSync(sync); } }else{ - if (sync < 2){ + if (sync < 2 && data.getSync() > 2){ sync = data.getSync(); } } diff --git a/src/controller/controller_statistics.h b/src/controller/controller_statistics.h index c29def5f..a9dddda5 100644 --- a/src/controller/controller_statistics.h +++ b/src/controller/controller_statistics.h @@ -78,10 +78,10 @@ namespace Controller { unsigned long long wipedUp; unsigned long long wipedDown; std::deque oldConns; - std::map curConns; - char sync; sessType sessionType; public: + char sync; + std::map curConns; sessType getSessType(); statSession(); void wipeOld(unsigned long long); @@ -116,6 +116,7 @@ namespace Controller { void fillActive(JSON::Value & req, JSON::Value & rep, bool onlyNow = false); void fillTotals(JSON::Value & req, JSON::Value & rep); void SharedMemStats(void * config); + void sessions_invalidate(const std::string & streamname); bool hasViewers(std::string streamName); #define PROMETHEUS_TEXT 0 diff --git a/src/output/output.cpp b/src/output/output.cpp index bf35bf53..17a52fba 100644 --- a/src/output/output.cpp +++ b/src/output/output.cpp @@ -146,38 +146,59 @@ namespace Mist { myConn.close(); } } - if(Triggers::shouldTrigger("USER_NEW", streamName)){ - //sync byte 0 = no sync yet, wait for sync from controller... - IPC::statExchange tmpEx(statsPage.getData()); - unsigned int i = 0; - tmpEx.setSync(0); - while (!tmpEx.getSync() && i++ < 30){ - Util::wait(100); - stats(); - tmpEx = IPC::statExchange(statsPage.getData()); - } - HIGH_MSG("USER_NEW sync achieved: %u", (unsigned int)tmpEx.getSync()); - //1 = check requested (connection is new) - if (tmpEx.getSync() == 1){ - std::string payload = streamName+"\n" + getConnectedHost() +"\n" + JSON::Value((long long)crc).asString() + "\n"+capa["name"].asStringRef()+"\n"+reqUrl; - if (!Triggers::doTrigger("USER_NEW", payload, streamName)){ - MEDIUM_MSG("Closing connection because denied by USER_NEW trigger"); - myConn.close(); - tmpEx.setSync(100);//100 = denied - }else{ - tmpEx.setSync(10);//10 = accepted - } - } - //100 = denied - if (tmpEx.getSync() == 100){ - myConn.close(); - MEDIUM_MSG("Closing connection because denied by USER_NEW sync byte"); - } - //anything else = accepted - } + doSync(true); /*LTS-END*/ } + /// If called with force set to true and a USER_NEW trigger enabled, forces a sync immediately. + /// Otherwise, does nothing unless the sync byte is set to 2, in which case it forces a sync as well. + /// May be called recursively because it calls stats() which calls this function. + /// If this happens, the extra calls to the function return instantly. + void Output::doSync(bool force){ + static bool recursing = false; + if (recursing){return;} + recursing = true; + IPC::statExchange tmpEx(statsPage.getData()); + if (tmpEx.getSync() == 2 || force){ + if(Triggers::shouldTrigger("USER_NEW", streamName)){ + //sync byte 0 = no sync yet, wait for sync from controller... + unsigned int i = 0; + tmpEx.setSync(0); + //wait max 10 seconds for sync + while ((!tmpEx.getSync() || tmpEx.getSync() == 2) && i++ < 100){ + Util::wait(100); + stats(); + tmpEx = IPC::statExchange(statsPage.getData()); + } + HIGH_MSG("USER_NEW sync achieved: %u", (unsigned int)tmpEx.getSync()); + //1 = check requested (connection is new) + if (tmpEx.getSync() == 1){ + std::string payload = streamName+"\n" + getConnectedHost() +"\n" + JSON::Value((long long)crc).asString() + "\n"+capa["name"].asStringRef()+"\n"+reqUrl; + if (!Triggers::doTrigger("USER_NEW", payload, streamName)){ + MEDIUM_MSG("Closing connection because denied by USER_NEW trigger"); + myConn.close(); + tmpEx.setSync(100);//100 = denied + }else{ + tmpEx.setSync(10);//10 = accepted + } + } + //100 = denied + if (tmpEx.getSync() == 100){ + myConn.close(); + MEDIUM_MSG("Closing connection because denied by USER_NEW sync byte"); + } + if (tmpEx.getSync() == 0 || tmpEx.getSync() == 2){ + myConn.close(); + FAIL_MSG("Closing connection because sync byte timeout!"); + } + //anything else = accepted + }else{ + tmpEx.setSync(10);//auto-accept if no trigger + } + } + recursing = false; + } + std::string Output::getConnectedHost(){ return myConn.getHost(); } @@ -1211,6 +1232,7 @@ namespace Mist { statsPage.keepAlive(); } } + doSync(); int tNum = 0; if (!nProxy.userClient.getData()){ char userPageName[NAME_BUFFER_SIZE]; diff --git a/src/output/output.h b/src/output/output.h index 52251d7a..69c33357 100644 --- a/src/output/output.h +++ b/src/output/output.h @@ -93,6 +93,7 @@ namespace Mist { bool onList(std::string ip, std::string list); std::string getCountry(std::string ip); /*LTS-END*/ + void doSync(bool force = false); std::map currKeyOpen;