summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/inspsocket.cpp171
-rw-r--r--src/listensocket.cpp22
-rw-r--r--src/modules/extra/m_ssl_gnutls.cpp43
-rw-r--r--src/modules/extra/m_ssl_mbedtls.cpp35
-rw-r--r--src/modules/extra/m_ssl_openssl.cpp26
-rw-r--r--src/modules/m_httpd.cpp6
-rw-r--r--src/modules/m_sha1.cpp199
-rw-r--r--src/modules/m_spanningtree/main.cpp4
-rw-r--r--src/modules/m_spanningtree/treesocket1.cpp9
-rw-r--r--src/modules/m_websocket.cpp405
-rw-r--r--src/usermanager.cpp8
11 files changed, 835 insertions, 93 deletions
diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp
index 89c3a71a9..9bfc6a73e 100644
--- a/src/inspsocket.cpp
+++ b/src/inspsocket.cpp
@@ -25,6 +25,14 @@
#include "inspircd.h"
#include "iohook.h"
+static IOHook* GetNextHook(IOHook* hook)
+{
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ if (iohm)
+ return iohm->GetNextHook();
+ return NULL;
+}
+
BufferedSocket::BufferedSocket()
{
Timeout = NULL;
@@ -112,11 +120,15 @@ void StreamSocket::Close()
{
// final chance, dump as much of the sendq as we can
DoWrite();
- if (GetIOHook())
+
+ IOHook* hook = GetIOHook();
+ DelIOHook();
+ while (hook)
{
- GetIOHook()->OnStreamSocketClose(this);
- delete iohook;
- DelIOHook();
+ hook->OnStreamSocketClose(this);
+ IOHook* const nexthook = GetNextHook(hook);
+ delete hook;
+ hook = nexthook;
}
SocketEngine::Shutdown(this, 2);
SocketEngine::Close(this);
@@ -139,51 +151,74 @@ bool StreamSocket::GetNextLine(std::string& line, char delim)
return true;
}
-void StreamSocket::DoRead()
+int StreamSocket::HookChainRead(IOHook* hook, std::string& rq)
{
- if (GetIOHook())
+ if (!hook)
+ return ReadToRecvQ(rq);
+
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ if (iohm)
{
- int rv = GetIOHook()->OnStreamSocketRead(this, recvq);
- if (rv > 0)
- OnDataReady();
- if (rv < 0)
- SetError("Read Error"); // will not overwrite a better error message
+ // Call the next hook to put data into the recvq of the current hook
+ const int ret = HookChainRead(iohm->GetNextHook(), iohm->GetRecvQ());
+ if (ret <= 0)
+ return ret;
}
- else
+ return hook->OnStreamSocketRead(this, rq);
+}
+
+void StreamSocket::DoRead()
+{
+ const std::string::size_type prevrecvqsize = recvq.size();
+
+ const int result = HookChainRead(GetIOHook(), recvq);
+ if (result < 0)
{
+ SetError("Read Error"); // will not overwrite a better error message
+ return;
+ }
+
+ if (recvq.size() > prevrecvqsize)
+ OnDataReady();
+}
+
+int StreamSocket::ReadToRecvQ(std::string& rq)
+{
char* ReadBuffer = ServerInstance->GetReadBuffer();
int n = SocketEngine::Recv(this, ReadBuffer, ServerInstance->Config->NetBufferSize, 0);
if (n == ServerInstance->Config->NetBufferSize)
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ);
- recvq.append(ReadBuffer, n);
- OnDataReady();
+ rq.append(ReadBuffer, n);
}
else if (n > 0)
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ);
- recvq.append(ReadBuffer, n);
- OnDataReady();
+ rq.append(ReadBuffer, n);
}
else if (n == 0)
{
error = "Connection closed";
SocketEngine::ChangeEventMask(this, FD_WANT_NO_READ | FD_WANT_NO_WRITE);
+ return -1;
}
else if (SocketEngine::IgnoreError())
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_READ_WILL_BLOCK);
+ return 0;
}
else if (errno == EINTR)
{
SocketEngine::ChangeEventMask(this, FD_WANT_FAST_READ | FD_ADD_TRIAL_READ);
+ return 0;
}
else
{
error = SocketEngine::LastError();
SocketEngine::ChangeEventMask(this, FD_WANT_NO_READ | FD_WANT_NO_WRITE);
+ return -1;
}
- }
+ return n;
}
/* Don't try to prepare huge blobs of data to send to a blocked socket */
@@ -191,7 +226,7 @@ static const int MYIOV_MAX = IOV_MAX < 128 ? IOV_MAX : 128;
void StreamSocket::DoWrite()
{
- if (sendq.empty())
+ if (getSendQSize() == 0)
return;
if (!error.empty() || fd < 0)
{
@@ -199,26 +234,48 @@ void StreamSocket::DoWrite()
return;
}
- if (GetIOHook())
+ SendQueue* psendq = &sendq;
+ IOHook* hook = GetIOHook();
+ while (hook)
{
- int rv = GetIOHook()->OnStreamSocketWrite(this);
- if (rv < 0)
- SetError("Write Error"); // will not overwrite a better error message
+ int rv = hook->OnStreamSocketWrite(this, *psendq);
+ psendq = NULL;
// rv == 0 means the socket has blocked. Stop trying to send data.
// IOHook has requested unblock notification from the socketengine.
+ if (rv == 0)
+ break;
+
+ if (rv < 0)
+ {
+ SetError("Write Error"); // will not overwrite a better error message
+ break;
+ }
+
+ IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(hook);
+ hook = NULL;
+ if (iohm)
+ {
+ psendq = &iohm->GetSendQ();
+ hook = iohm->GetNextHook();
+ }
}
- else
- {
+
+ if (psendq)
+ FlushSendQ(*psendq);
+}
+
+void StreamSocket::FlushSendQ(SendQueue& sq)
+{
// don't even try if we are known to be blocking
if (GetEventMask() & FD_WRITE_WILL_BLOCK)
return;
// start out optimistic - we won't need to write any more
int eventChange = FD_WANT_EDGE_WRITE;
- while (error.empty() && !sendq.empty() && eventChange == FD_WANT_EDGE_WRITE)
+ while (error.empty() && !sq.empty() && eventChange == FD_WANT_EDGE_WRITE)
{
// Prepare a writev() call to write all buffers efficiently
- int bufcount = sendq.size();
+ int bufcount = sq.size();
// cap the number of buffers at MYIOV_MAX
if (bufcount > MYIOV_MAX)
@@ -231,7 +288,7 @@ void StreamSocket::DoWrite()
{
SocketEngine::IOVector iovecs[MYIOV_MAX];
size_t j = 0;
- for (SendQueue::const_iterator i = sendq.begin(), end = i+bufcount; i != end; ++i, j++)
+ for (SendQueue::const_iterator i = sq.begin(), end = i+bufcount; i != end; ++i, j++)
{
const SendQueue::Element& elem = *i;
iovecs[j].iov_base = const_cast<char*>(elem.data());
@@ -241,11 +298,11 @@ void StreamSocket::DoWrite()
rv = SocketEngine::WriteV(this, iovecs, bufcount);
}
- if (rv == (int)sendq.bytes())
+ if (rv == (int)sq.bytes())
{
// it's our lucky day, everything got written out. Fast cleanup.
// This won't ever happen if the number of buffers got capped.
- sendq.clear();
+ sq.clear();
}
else if (rv > 0)
{
@@ -255,19 +312,19 @@ void StreamSocket::DoWrite()
// it's going to block now
eventChange = FD_WANT_FAST_WRITE | FD_WRITE_WILL_BLOCK;
}
- while (rv > 0 && !sendq.empty())
+ while (rv > 0 && !sq.empty())
{
- const SendQueue::Element& front = sendq.front();
+ const SendQueue::Element& front = sq.front();
if (front.length() <= (size_t)rv)
{
// this string got fully written out
rv -= front.length();
- sendq.pop_front();
+ sq.pop_front();
}
else
{
// stopped in the middle of this string
- sendq.erase_front(rv);
+ sq.erase_front(rv);
rv = 0;
}
}
@@ -299,7 +356,6 @@ void StreamSocket::DoWrite()
{
SocketEngine::ChangeEventMask(this, eventChange);
}
- }
}
void StreamSocket::WriteData(const std::string &data)
@@ -434,3 +490,50 @@ void StreamSocket::CheckError(BufferedSocketError errcode)
OnError(errcode);
}
}
+
+IOHook* StreamSocket::GetModHook(Module* mod) const
+{
+ for (IOHook* curr = GetIOHook(); curr; curr = GetNextHook(curr))
+ {
+ if (curr->prov->creator == mod)
+ return curr;
+ }
+ return NULL;
+}
+
+void StreamSocket::AddIOHook(IOHook* newhook)
+{
+ IOHook* curr = GetIOHook();
+ if (!curr)
+ {
+ iohook = newhook;
+ return;
+ }
+
+ IOHookMiddle* lasthook;
+ while (curr)
+ {
+ lasthook = IOHookMiddle::ToMiddleHook(curr);
+ if (!lasthook)
+ return;
+ curr = lasthook->GetNextHook();
+ }
+
+ lasthook->SetNextHook(newhook);
+}
+
+size_t StreamSocket::getSendQSize() const
+{
+ size_t ret = sendq.bytes();
+ IOHook* curr = GetIOHook();
+ while (curr)
+ {
+ const IOHookMiddle* const iohm = IOHookMiddle::ToMiddleHook(curr);
+ if (!iohm)
+ break;
+
+ ret += iohm->GetSendQ().bytes();
+ curr = iohm->GetNextHook();
+ }
+ return ret;
+}
diff --git a/src/listensocket.cpp b/src/listensocket.cpp
index fa43e6827..fb9f2a0ef 100644
--- a/src/listensocket.cpp
+++ b/src/listensocket.cpp
@@ -19,6 +19,7 @@
#include "inspircd.h"
+#include "iohook.h"
#ifndef _WIN32
#include <netinet/tcp.h>
@@ -26,7 +27,6 @@
ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_to)
: bind_tag(tag)
- , iohookprov(NULL, std::string())
{
irc::sockets::satoap(bind_to, bind_addr, bind_port);
bind_desc = bind_to.str();
@@ -178,15 +178,23 @@ void ListenSocket::OnEventHandlerRead()
}
}
-bool ListenSocket::ResetIOHookProvider()
+void ListenSocket::ResetIOHookProvider()
{
+ iohookprovs[0].SetProvider(bind_tag->getString("hook"));
+
+ // Check that all non-last hooks support being in the middle
+ for (IOHookProvList::iterator i = iohookprovs.begin(); i != iohookprovs.end()-1; ++i)
+ {
+ IOHookProvRef& curr = *i;
+ // Ignore if cannot be in the middle
+ if ((curr) && (!curr->IsMiddle()))
+ curr.SetProvider(std::string());
+ }
+
std::string provname = bind_tag->getString("ssl");
if (!provname.empty())
provname.insert(0, "ssl/");
- // Set the new provider name, dynref handles the rest
- iohookprov.SetProvider(provname);
-
- // Return true if no provider was set, or one was set and it was also found
- return (provname.empty() || iohookprov);
+ // SSL should be the last
+ iohookprovs.back().SetProvider(provname);
}
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp
index a1c989163..bda4e6a48 100644
--- a/src/modules/extra/m_ssl_gnutls.cpp
+++ b/src/modules/extra/m_ssl_gnutls.cpp
@@ -101,6 +101,8 @@ typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t;
#define INSPIRCD_GNUTLS_HAS_CORK
#endif
+static Module* thismod;
+
class RandGen : public HandlerBase2<void, char*, size_t>
{
public:
@@ -581,16 +583,21 @@ namespace GnuTLS
*/
const unsigned int outrecsize;
+ /** True to request a client certificate as a server
+ */
+ const bool requestclientcert;
+
Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr,
const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL,
- unsigned int recsize)
+ unsigned int recsize, bool Requestclientcert)
: name(profilename)
, x509cred(certstr, keystr)
, min_dh_bits(mindh)
, hash(hashstr)
, priority(priostr)
, outrecsize(recsize)
+ , requestclientcert(Requestclientcert)
{
x509cred.SetDH(DH);
x509cred.SetCA(CA, CRL);
@@ -661,7 +668,10 @@ namespace GnuTLS
#else
unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
#endif
- return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize);
+
+ const bool requestclientcert = tag->getBool("requestclientcert", true);
+
+ return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert);
}
/** Set up the given session with the settings in this profile
@@ -672,8 +682,9 @@ namespace GnuTLS
x509cred.SetupSession(sess);
gnutls_dh_set_prime_bits(sess, min_dh_bits);
- // Request client certificate if we are a server, no-op if we're a client
- gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
+ // Request client certificate if enabled and we are a server, no-op if we're a client
+ if (requestclientcert)
+ gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST);
}
const std::string& GetName() const { return name; }
@@ -917,7 +928,7 @@ info_done_dealloc:
{
StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
#ifdef _WIN32
- GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
#endif
if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
@@ -954,7 +965,7 @@ info_done_dealloc:
{
StreamSocket* sock = reinterpret_cast<StreamSocket*>(transportptr);
#ifdef _WIN32
- GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
#endif
if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -989,7 +1000,7 @@ info_done_dealloc:
{
StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
#ifdef _WIN32
- GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
#endif
if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -1086,7 +1097,7 @@ info_done_dealloc:
}
}
- int OnStreamSocketWrite(StreamSocket* user) CXX11_OVERRIDE
+ int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
{
// Finish handshake if needed
int prepret = PrepareIO(user);
@@ -1094,7 +1105,6 @@ info_done_dealloc:
return prepret;
// Session is ready for transferring application data
- StreamSocket::SendQueue& sendq = user->GetSendQ();
#ifdef INSPIRCD_GNUTLS_HAS_CORK
while (true)
@@ -1173,7 +1183,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
st->key_type = GNUTLS_PRIVKEY_X509;
#endif
StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
- GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials();
+ GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials();
st->ncerts = cred.certs.size();
st->cert.x509 = cred.certs.raw();
@@ -1283,6 +1293,7 @@ class ModuleSSLGnuTLS : public Module
#ifndef GNUTLS_HAS_RND
gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
#endif
+ thismod = this;
}
void init() CXX11_OVERRIDE
@@ -1318,7 +1329,7 @@ class ModuleSSLGnuTLS : public Module
{
LocalUser* user = IS_LOCAL(static_cast<User*>(item));
- if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+ if ((user) && (user->eh.GetModHook(this)))
{
// User is using SSL, they're a local user, and they're using one of *our* SSL ports.
// Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -1334,13 +1345,9 @@ class ModuleSSLGnuTLS : public Module
ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
{
- if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
- {
- GnuTLSIOHook* iohook = static_cast<GnuTLSIOHook*>(user->eh.GetIOHook());
- if (!iohook->IsHandshakeDone())
- return MOD_RES_DENY;
- }
-
+ const GnuTLSIOHook* const iohook = static_cast<GnuTLSIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
return MOD_RES_PASSTHRU;
}
};
diff --git a/src/modules/extra/m_ssl_mbedtls.cpp b/src/modules/extra/m_ssl_mbedtls.cpp
index 8578b8196..a465d06ee 100644
--- a/src/modules/extra/m_ssl_mbedtls.cpp
+++ b/src/modules/extra/m_ssl_mbedtls.cpp
@@ -257,7 +257,6 @@ namespace mbedTLS
mbedtls_debug_set_threshold(INT_MAX);
mbedtls_ssl_conf_dbg(&conf, DebugLogFunc, NULL);
#endif
- mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
// TODO: check ret of mbedtls_ssl_config_defaults
mbedtls_ssl_config_defaults(&conf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
@@ -308,6 +307,11 @@ namespace mbedTLS
mbedtls_ssl_conf_ca_chain(&conf, certs.get(), crl.get());
}
+ void SetOptionalVerifyCert()
+ {
+ mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+ }
+
const mbedtls_ssl_config* GetConf() const { return &conf; }
};
@@ -376,7 +380,8 @@ namespace mbedTLS
const std::string& castr, const std::string& crlstr,
unsigned int recsize,
CTRDRBG& ctrdrbg,
- int minver, int maxver
+ int minver, int maxver,
+ bool requestclientcert
)
: name(profilename)
, x509cred(certstr, keystr)
@@ -414,7 +419,13 @@ namespace mbedTLS
serverctx.SetDHParams(dhparams);
}
- serverctx.SetCA(cacerts, crl);
+ clientctx.SetOptionalVerifyCert();
+ // The default for servers is to not request a client certificate from the peer
+ if (requestclientcert)
+ {
+ serverctx.SetOptionalVerifyCert();
+ serverctx.SetCA(cacerts, crl);
+ }
}
static std::string ReadFile(const std::string& filename)
@@ -451,7 +462,8 @@ namespace mbedTLS
int minver = tag->getInt("minver");
int maxver = tag->getInt("maxver");
unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
- return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver);
+ const bool requestclientcert = tag->getBool("requestclientcert", true);
+ return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert);
}
/** Set up the given session with the settings in this profile
@@ -698,7 +710,7 @@ class mbedTLSIOHook : public SSLIOHook
}
}
- int OnStreamSocketWrite(StreamSocket* sock) CXX11_OVERRIDE
+ int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
{
// Finish handshake if needed
int prepret = PrepareIO(sock);
@@ -706,7 +718,6 @@ class mbedTLSIOHook : public SSLIOHook
return prepret;
// Session is ready for transferring application data
- StreamSocket::SendQueue& sendq = sock->GetSendQ();
while (!sendq.empty())
{
FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
@@ -895,7 +906,7 @@ class ModuleSSLmbedTLS : public Module
return;
LocalUser* user = IS_LOCAL(static_cast<User*>(item));
- if ((user) && (user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
+ if ((user) && (user->eh.GetModHook(this)))
{
// User is using SSL, they're a local user, and they're using our IOHook.
// Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -905,13 +916,9 @@ class ModuleSSLmbedTLS : public Module
ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
{
- if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
- {
- mbedTLSIOHook* iohook = static_cast<mbedTLSIOHook*>(user->eh.GetIOHook());
- if (!iohook->IsHandshakeDone())
- return MOD_RES_DENY;
- }
-
+ const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
return MOD_RES_PASSTHRU;
}
diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp
index 80c9d9395..4df0d8962 100644
--- a/src/modules/extra/m_ssl_openssl.cpp
+++ b/src/modules/extra/m_ssl_openssl.cpp
@@ -132,7 +132,7 @@ namespace OpenSSL
mode |= SSL_MODE_RELEASE_BUFFERS;
#endif
SSL_CTX_set_mode(ctx, mode);
- SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+ SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
SSL_CTX_set_info_callback(ctx, StaticSSLInfoCallback);
}
@@ -206,6 +206,11 @@ namespace OpenSSL
return SSL_CTX_clear_options(ctx, clearoptions);
}
+ void SetVerifyCert()
+ {
+ SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, OnVerify);
+ }
+
SSL* CreateServerSession()
{
SSL* sess = SSL_new(ctx);
@@ -345,6 +350,10 @@ namespace OpenSSL
ERR_print_errors_cb(error_callback, this);
ServerInstance->Logs->Log(MODNAME, LOG_DEFAULT, "Can't read CA list from %s. This is only a problem if you want to verify client certificates, otherwise it's safe to ignore this message. Error: %s", filename.c_str(), lasterr.c_str());
}
+
+ clictx.SetVerifyCert();
+ if (tag->getBool("requestclientcert", true))
+ ctx.SetVerifyCert();
}
const std::string& GetName() const { return name; }
@@ -656,7 +665,7 @@ class OpenSSLIOHook : public SSLIOHook
}
}
- int OnStreamSocketWrite(StreamSocket* user) CXX11_OVERRIDE
+ int OnStreamSocketWrite(StreamSocket* user, StreamSocket::SendQueue& sendq) CXX11_OVERRIDE
{
// Finish handshake if needed
int prepret = PrepareIO(user);
@@ -666,7 +675,6 @@ class OpenSSLIOHook : public SSLIOHook
data_to_write = true;
// Session is ready for transferring application data
- StreamSocket::SendQueue& sendq = user->GetSendQ();
while (!sendq.empty())
{
ERR_clear_error();
@@ -910,7 +918,7 @@ class ModuleSSLOpenSSL : public Module
{
LocalUser* user = IS_LOCAL((User*)item);
- if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+ if ((user) && (user->eh.GetModHook(this)))
{
// User is using SSL, they're a local user, and they're using one of *our* SSL ports.
// Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -921,13 +929,9 @@ class ModuleSSLOpenSSL : public Module
ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
{
- if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
- {
- OpenSSLIOHook* iohook = static_cast<OpenSSLIOHook*>(user->eh.GetIOHook());
- if (!iohook->IsHandshakeDone())
- return MOD_RES_DENY;
- }
-
+ const OpenSSLIOHook* const iohook = static_cast<OpenSSLIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
return MOD_RES_PASSTHRU;
}
diff --git a/src/modules/m_httpd.cpp b/src/modules/m_httpd.cpp
index 760647d47..64bef70d1 100644
--- a/src/modules/m_httpd.cpp
+++ b/src/modules/m_httpd.cpp
@@ -78,8 +78,8 @@ class HttpServerSocket : public BufferedSocket, public Timer, public insp::intru
{
ServerInstance->Timers.AddTimer(this);
- if (via->iohookprov)
- via->iohookprov->OnAccept(this, client, server);
+ if ((!via->iohookprovs.empty()) && (via->iohookprovs.back()))
+ via->iohookprovs.back()->OnAccept(this, client, server);
}
~HttpServerSocket()
@@ -413,7 +413,7 @@ class ModuleHttpServer : public Module
{
HttpServerSocket* sock = *i;
++i;
- if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+ if (sock->GetModHook(mod))
{
sock->cull();
delete sock;
diff --git a/src/modules/m_sha1.cpp b/src/modules/m_sha1.cpp
new file mode 100644
index 000000000..5926e4926
--- /dev/null
+++ b/src/modules/m_sha1.cpp
@@ -0,0 +1,199 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ * Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
+ *
+ * This file is part of InspIRCd. InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/*
+SHA-1 in C
+By Steve Reid <steve@edmweb.com>
+100% Public Domain
+*/
+
+#include "inspircd.h"
+#include "modules/hash.h"
+
+union CHAR64LONG16
+{
+ unsigned char c[64];
+ uint32_t l[16];
+};
+
+inline static uint32_t rol(uint32_t value, uint32_t bits) { return (value << bits) | (value >> (32 - bits)); }
+
+// blk0() and blk() perform the initial expand.
+// I got the idea of expanding during the round function from SSLeay
+static bool big_endian;
+inline static uint32_t blk0(CHAR64LONG16& block, uint32_t i)
+{
+ if (big_endian)
+ return block.l[i];
+ else
+ return block.l[i] = (rol(block.l[i], 24) & 0xFF00FF00) | (rol(block.l[i], 8) & 0x00FF00FF);
+}
+inline static uint32_t blk(CHAR64LONG16 &block, uint32_t i) { return block.l[i & 15] = rol(block.l[(i + 13) & 15] ^ block.l[(i + 8) & 15] ^ block.l[(i + 2) & 15] ^ block.l[i & 15],1); }
+
+// (R0+R1), R2, R3, R4 are the different operations used in SHA1
+inline static void R0(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += ((w & (x ^ y)) ^ y) + blk0(block, i) + 0x5A827999 + rol(v, 5); w = rol(w, 30); }
+inline static void R1(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += ((w & (x ^ y)) ^ y) + blk(block, i) + 0x5A827999 + rol(v, 5); w = rol(w, 30); }
+inline static void R2(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += (w ^ x ^ y) + blk(block, i) + 0x6ED9EBA1 + rol(v, 5); w = rol(w, 30); }
+inline static void R3(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += (((w | x) & y) | (w & x)) + blk(block, i) + 0x8F1BBCDC + rol(v, 5); w = rol(w, 30); }
+inline static void R4(CHAR64LONG16& block, uint32_t v, uint32_t &w, uint32_t x, uint32_t y, uint32_t &z, uint32_t i) { z += (w ^ x ^ y) + blk(block, i) + 0xCA62C1D6 + rol(v, 5); w = rol(w, 30); }
+
+static const uint32_t sha1_iv[5] =
+{
+ 0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
+};
+
+class SHA1Context
+{
+ uint32_t state[5];
+ uint32_t count[2];
+ unsigned char buffer[64];
+ unsigned char digest[20];
+
+ void Transform(const unsigned char buf[64])
+ {
+ uint32_t a, b, c, d, e;
+
+ CHAR64LONG16 block;
+ memcpy(block.c, buf, 64);
+
+ // Copy state[] to working vars
+ a = this->state[0];
+ b = this->state[1];
+ c = this->state[2];
+ d = this->state[3];
+ e = this->state[4];
+
+ // 4 rounds of 20 operations each. Loop unrolled.
+ R0(block, a, b, c, d, e, 0); R0(block, e, a, b, c, d, 1); R0(block, d, e, a, b, c, 2); R0(block, c, d, e, a, b, 3);
+ R0(block, b, c, d, e, a, 4); R0(block, a, b, c, d, e, 5); R0(block, e, a, b, c, d, 6); R0(block, d, e, a, b, c, 7);
+ R0(block, c, d, e, a, b, 8); R0(block, b, c, d, e, a, 9); R0(block, a, b, c, d, e, 10); R0(block, e, a, b, c, d, 11);
+ R0(block, d, e, a, b, c, 12); R0(block, c, d, e, a, b, 13); R0(block, b, c, d, e, a, 14); R0(block, a, b, c, d, e, 15);
+ R1(block, e, a, b, c, d, 16); R1(block, d, e, a, b, c, 17); R1(block, c, d, e, a, b, 18); R1(block, b, c, d, e, a, 19);
+ R2(block, a, b, c, d, e, 20); R2(block, e, a, b, c, d, 21); R2(block, d, e, a, b, c, 22); R2(block, c, d, e, a, b, 23);
+ R2(block, b, c, d, e, a, 24); R2(block, a, b, c, d, e, 25); R2(block, e, a, b, c, d, 26); R2(block, d, e, a, b, c, 27);
+ R2(block, c, d, e, a, b, 28); R2(block, b, c, d, e, a, 29); R2(block, a, b, c, d, e, 30); R2(block, e, a, b, c, d, 31);
+ R2(block, d, e, a, b, c, 32); R2(block, c, d, e, a, b, 33); R2(block, b, c, d, e, a, 34); R2(block, a, b, c, d, e, 35);
+ R2(block, e, a, b, c, d, 36); R2(block, d, e, a, b, c, 37); R2(block, c, d, e, a, b, 38); R2(block, b, c, d, e, a, 39);
+ R3(block, a, b, c, d, e, 40); R3(block, e, a, b, c, d, 41); R3(block, d, e, a, b, c, 42); R3(block, c, d, e, a, b, 43);
+ R3(block, b, c, d, e, a, 44); R3(block, a, b, c, d, e, 45); R3(block, e, a, b, c, d, 46); R3(block, d, e, a, b, c, 47);
+ R3(block, c, d, e, a, b, 48); R3(block, b, c, d, e, a, 49); R3(block, a, b, c, d, e, 50); R3(block, e, a, b, c, d, 51);
+ R3(block, d, e, a, b, c, 52); R3(block, c, d, e, a, b, 53); R3(block, b, c, d, e, a, 54); R3(block, a, b, c, d, e, 55);
+ R3(block, e, a, b, c, d, 56); R3(block, d, e, a, b, c, 57); R3(block, c, d, e, a, b, 58); R3(block, b, c, d, e, a, 59);
+ R4(block, a, b, c, d, e, 60); R4(block, e, a, b, c, d, 61); R4(block, d, e, a, b, c, 62); R4(block, c, d, e, a, b, 63);
+ R4(block, b, c, d, e, a, 64); R4(block, a, b, c, d, e, 65); R4(block, e, a, b, c, d, 66); R4(block, d, e, a, b, c, 67);
+ R4(block, c, d, e, a, b, 68); R4(block, b, c, d, e, a, 69); R4(block, a, b, c, d, e, 70); R4(block, e, a, b, c, d, 71);
+ R4(block, d, e, a, b, c, 72); R4(block, c, d, e, a, b, 73); R4(block, b, c, d, e, a, 74); R4(block, a, b, c, d, e, 75);
+ R4(block, e, a, b, c, d, 76); R4(block, d, e, a, b, c, 77); R4(block, c, d, e, a, b, 78); R4(block, b, c, d, e, a, 79);
+ // Add the working vars back into state[]
+ this->state[0] += a;
+ this->state[1] += b;
+ this->state[2] += c;
+ this->state[3] += d;
+ this->state[4] += e;
+ }
+
+ public:
+ SHA1Context()
+ {
+ for (int i = 0; i < 5; ++i)
+ this->state[i] = sha1_iv[i];
+
+ this->count[0] = this->count[1] = 0;
+ memset(this->buffer, 0, sizeof(this->buffer));
+ memset(this->digest, 0, sizeof(this->digest));
+ }
+
+ void Update(const unsigned char* data, size_t len)
+ {
+ uint32_t i, j;
+
+ j = (this->count[0] >> 3) & 63;
+ if ((this->count[0] += len << 3) < (len << 3))
+ ++this->count[1];
+ this->count[1] += len >> 29;
+ if (j + len > 63)
+ {
+ memcpy(&this->buffer[j], data, (i = 64 - j));
+ this->Transform(this->buffer);
+ for (; i + 63 < len; i += 64)
+ this->Transform(&data[i]);
+ j = 0;
+ }
+ else
+ i = 0;
+ memcpy(&this->buffer[j], &data[i], len - i);
+ }
+
+ void Finalize()
+ {
+ uint32_t i;
+ unsigned char finalcount[8];
+
+ for (i = 0; i < 8; ++i)
+ finalcount[i] = static_cast<unsigned char>((this->count[i >= 4 ? 0 : 1] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */
+ this->Update(reinterpret_cast<const unsigned char *>("\200"), 1);
+ while ((this->count[0] & 504) != 448)
+ this->Update(reinterpret_cast<const unsigned char *>("\0"), 1);
+ this->Update(finalcount, 8); // Should cause a SHA1Transform()
+ for (i = 0; i < 20; ++i)
+ this->digest[i] = static_cast<unsigned char>((this->state[i>>2] >> ((3 - (i & 3)) * 8)) & 255);
+
+ this->Transform(this->buffer);
+ }
+
+ std::string GetRaw() const
+ {
+ return std::string((const char*)digest, sizeof(digest));
+ }
+};
+
+class SHA1HashProvider : public HashProvider
+{
+ public:
+ SHA1HashProvider(Module* mod)
+ : HashProvider(mod, "hash/sha1", 20, 64)
+ {
+ }
+
+ std::string GenerateRaw(const std::string& data)
+ {
+ SHA1Context ctx;
+ ctx.Update(reinterpret_cast<const unsigned char*>(data.data()), data.length());
+ ctx.Finalize();
+ return ctx.GetRaw();
+ }
+};
+
+class ModuleSHA1 : public Module
+{
+ SHA1HashProvider sha1;
+
+ public:
+ ModuleSHA1()
+ : sha1(this)
+ {
+ big_endian = (htonl(1337) == 1337);
+ }
+
+ Version GetVersion()
+ {
+ return Version("Implements SHA-1 hashing", VF_VENDOR);
+ }
+};
+
+MODULE_INIT(ModuleSHA1)
diff --git a/src/modules/m_spanningtree/main.cpp b/src/modules/m_spanningtree/main.cpp
index 0b9bb65df..81543b0da 100644
--- a/src/modules/m_spanningtree/main.cpp
+++ b/src/modules/m_spanningtree/main.cpp
@@ -635,7 +635,7 @@ restart:
for (TreeServer::ChildServers::const_iterator i = list.begin(); i != list.end(); ++i)
{
TreeSocket* sock = (*i)->GetSocket();
- if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+ if (sock->GetModHook(mod))
{
sock->SendError("SSL module unloaded");
sock->Close();
@@ -647,7 +647,7 @@ restart:
for (SpanningTreeUtilities::TimeoutList::const_iterator i = Utils->timeoutlist.begin(); i != Utils->timeoutlist.end(); ++i)
{
TreeSocket* sock = i->first;
- if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+ if (sock->GetModHook(mod))
sock->Close();
}
}
diff --git a/src/modules/m_spanningtree/treesocket1.cpp b/src/modules/m_spanningtree/treesocket1.cpp
index 2198d6208..e1642a086 100644
--- a/src/modules/m_spanningtree/treesocket1.cpp
+++ b/src/modules/m_spanningtree/treesocket1.cpp
@@ -60,8 +60,13 @@ TreeSocket::TreeSocket(int newfd, ListenSocket* via, irc::sockets::sockaddrs* cl
capab = new CapabData;
capab->capab_phase = 0;
- if (via->iohookprov)
- via->iohookprov->OnAccept(this, client, server);
+ for (ListenSocket::IOHookProvList::iterator i = via->iohookprovs.begin(); i != via->iohookprovs.end(); ++i)
+ {
+ ListenSocket::IOHookProvRef& iohookprovref = *i;
+ if (iohookprovref)
+ iohookprovref->OnAccept(this, client, server);
+ }
+
SendCapabilities(1);
Utils->timeoutlist[this] = std::pair<std::string, int>(linkID, 30);
diff --git a/src/modules/m_websocket.cpp b/src/modules/m_websocket.cpp
new file mode 100644
index 000000000..399b0b017
--- /dev/null
+++ b/src/modules/m_websocket.cpp
@@ -0,0 +1,405 @@
+/*
+ * InspIRCd -- Internet Relay Chat Daemon
+ *
+ * Copyright (C) 2016 Attila Molnar <attilamolnar@hush.com>
+ *
+ * This file is part of InspIRCd. InspIRCd is free software: you can
+ * redistribute it and/or modify it under the terms of the GNU General Public
+ * License as published by the Free Software Foundation, version 2.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+
+#include "inspircd.h"
+#include "iohook.h"
+#include "modules/hash.h"
+
+static const char MagicGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+static const char whitespace[] = " \t\r\n";
+static dynamic_reference_nocheck<HashProvider>* sha1;
+
+class WebSocketHookProvider : public IOHookProvider
+{
+ public:
+ WebSocketHookProvider(Module* mod)
+ : IOHookProvider(mod, "websocket", IOHookProvider::IOH_UNKNOWN, true)
+ {
+ }
+
+ void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE;
+
+ void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
+ {
+ }
+};
+
+class WebSocketHook : public IOHookMiddle
+{
+ class HTTPHeaderFinder
+ {
+ std::string::size_type bpos;
+ std::string::size_type len;
+
+ public:
+ bool Find(const std::string& req, const char* header, std::string::size_type headerlen, std::string::size_type maxpos)
+ {
+ std::string::size_type keybegin = req.find(header);
+ if ((keybegin == std::string::npos) || (keybegin > maxpos) || (keybegin == 0) || (req[keybegin-1] != '\n'))
+ return false;
+
+ keybegin += headerlen;
+
+ bpos = req.find_first_not_of(whitespace, keybegin, sizeof(whitespace)-1);
+ if ((bpos == std::string::npos) || (bpos > maxpos))
+ return false;
+
+ const std::string::size_type epos = req.find_first_of(whitespace, bpos, sizeof(whitespace)-1);
+ len = epos - bpos;
+
+ return true;
+ }
+
+ std::string ExtractValue(const std::string& req) const
+ {
+ return std::string(req, bpos, len);
+ }
+ };
+
+ enum OpCode
+ {
+ OP_CONTINUATION = 0x00,
+ OP_TEXT = 0x01,
+ OP_BINARY = 0x02,
+ OP_CLOSE = 0x08,
+ OP_PING = 0x09,
+ OP_PONG = 0x0a
+ };
+
+ enum State
+ {
+ STATE_HTTPREQ,
+ STATE_ESTABLISHED
+ };
+
+ static const unsigned char WS_MASKBIT = (1 << 7);
+ static const unsigned char WS_FINBIT = (1 << 7);
+ static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_LARGE = 126;
+ static const unsigned char WS_PAYLOAD_LENGTH_MAGIC_HUGE = 127;
+ static const size_t WS_MAX_PAYLOAD_LENGTH_SMALL = 125;
+ static const size_t WS_MAX_PAYLOAD_LENGTH_LARGE = 65535;
+ static const size_t MAXHEADERSIZE = sizeof(uint64_t) + 2;
+
+ // Clients sending ping or pong frames faster than this are killed
+ static const time_t MINPINGPONGDELAY = 10;
+
+ State state;
+ time_t lastpingpong;
+
+ static size_t FillHeader(unsigned char* outbuf, size_t sendlength, OpCode opcode)
+ {
+ size_t pos = 0;
+ outbuf[pos++] = WS_FINBIT | opcode;
+
+ if (sendlength <= WS_MAX_PAYLOAD_LENGTH_SMALL)
+ {
+ outbuf[pos++] = sendlength;
+ }
+ else if (sendlength <= WS_MAX_PAYLOAD_LENGTH_LARGE)
+ {
+ outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_LARGE;
+ outbuf[pos++] = (sendlength >> 8) & 0xff;
+ outbuf[pos++] = sendlength & 0xff;
+ }
+ else
+ {
+ outbuf[pos++] = WS_PAYLOAD_LENGTH_MAGIC_HUGE;
+ const uint64_t len = sendlength;
+ for (int i = sizeof(uint64_t)-1; i >= 0; i--)
+ outbuf[pos++] = ((len >> i*8) & 0xff);
+ }
+
+ return pos;
+ }
+
+ static StreamSocket::SendQueue::Element PrepareSendQElem(size_t size, OpCode opcode)
+ {
+ unsigned char header[MAXHEADERSIZE];
+ const size_t n = FillHeader(header, size, opcode);
+
+ return StreamSocket::SendQueue::Element(reinterpret_cast<const char*>(header), n);
+ }
+
+ int HandleAppData(StreamSocket* sock, std::string& appdataout, bool allowlarge)
+ {
+ std::string& myrecvq = GetRecvQ();
+ // Need 1 byte opcode, minimum 1 byte len, 4 bytes masking key
+ if (myrecvq.length() < 6)
+ return 0;
+
+ const std::string& cmyrecvq = myrecvq;
+ unsigned char len1 = (unsigned char)cmyrecvq[1];
+ if (!(len1 & WS_MASKBIT))
+ {
+ sock->SetError("WebSocket protocol violation: unmasked client frame");
+ return -1;
+ }
+
+ len1 &= ~WS_MASKBIT;
+
+ // Assume the length is a single byte, if not, update values later
+ unsigned int len = len1;
+ unsigned int payloadstartoffset = 6;
+ const unsigned char* maskkey = reinterpret_cast<const unsigned char*>(&cmyrecvq[2]);
+
+ if (len1 == WS_PAYLOAD_LENGTH_MAGIC_LARGE)
+ {
+ // allowlarge is false for control frames according to the RFC meaning large pings, etc. are not allowed
+ if (!allowlarge)
+ {
+ sock->SetError("WebSocket protocol violation: large control frame");
+ return -1;
+ }
+
+ // Large frame, has 2 bytes len after the magic byte indicating the length
+ // Need 1 byte opcode, 3 bytes len, 4 bytes masking key
+ if (myrecvq.length() < 8)
+ return 0;
+
+ unsigned char len2 = (unsigned char)cmyrecvq[2];
+ unsigned char len3 = (unsigned char)cmyrecvq[3];
+ len = (len2 << 8) | len3;
+
+ if (len <= WS_MAX_PAYLOAD_LENGTH_SMALL)
+ {
+ sock->SetError("WebSocket protocol violation: non-minimal length encoding used");
+ return -1;
+ }
+
+ maskkey += 2;
+ payloadstartoffset += 2;
+ }
+ else if (len1 == WS_PAYLOAD_LENGTH_MAGIC_HUGE)
+ {
+ sock->SetError("WebSocket: Huge frames are not supported");
+ return -1;
+ }
+
+ if (myrecvq.length() < payloadstartoffset + len)
+ return 0;
+
+ unsigned int maskkeypos = 0;
+ const std::string::iterator endit = myrecvq.begin() + payloadstartoffset + len;
+ for (std::string::const_iterator i = myrecvq.begin() + payloadstartoffset; i != endit; ++i)
+ {
+ const unsigned char c = (unsigned char)*i;
+ appdataout.push_back(c ^ maskkey[maskkeypos++]);
+ maskkeypos %= 4;
+ }
+
+ myrecvq.erase(myrecvq.begin(), endit);
+ return 1;
+ }
+
+ int HandlePingPongFrame(StreamSocket* sock, bool isping)
+ {
+ if (lastpingpong + MINPINGPONGDELAY >= ServerInstance->Time())
+ {
+ sock->SetError("WebSocket: Ping/pong flood");
+ return -1;
+ }
+
+ lastpingpong = ServerInstance->Time();
+
+ std::string appdata;
+ const int result = HandleAppData(sock, appdata, false);
+ // If it's a pong stop here regardless of the result so we won't generate a reply
+ if ((result <= 0) || (!isping))
+ return result;
+
+ StreamSocket::SendQueue::Element elem = PrepareSendQElem(appdata.length(), OP_PONG);
+ elem.append(appdata);
+ GetSendQ().push_back(elem);
+
+ SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
+ return 1;
+ }
+
+ int HandleWS(StreamSocket* sock, std::string& destrecvq)
+ {
+ if (GetRecvQ().empty())
+ return 0;
+
+ unsigned char opcode = (unsigned char)GetRecvQ().c_str()[0];
+ opcode &= ~WS_FINBIT;
+
+ switch (opcode)
+ {
+ case OP_CONTINUATION:
+ case OP_TEXT:
+ case OP_BINARY:
+ {
+ return HandleAppData(sock, destrecvq, true);
+ }
+
+ case OP_PING:
+ {
+ return HandlePingPongFrame(sock, true);
+ }
+
+ case OP_PONG:
+ {
+ // A pong frame may be sent unsolicited, so we have to handle it.
+ // It may carry application data which we need to remove from the recvq as well.
+ return HandlePingPongFrame(sock, false);
+ }
+
+ case OP_CLOSE:
+ {
+ sock->SetError("Connection closed");
+ return -1;
+ }
+
+ default:
+ {
+ sock->SetError("WebSocket: Invalid opcode");
+ return -1;
+ }
+ }
+ }
+
+ void FailHandshake(StreamSocket* sock, const char* httpreply, const char* sockerror)
+ {
+ GetSendQ().push_back(StreamSocket::SendQueue::Element(httpreply));
+ sock->DoWrite();
+ sock->SetError(sockerror);
+ }
+
+ int HandleHTTPReq(StreamSocket* sock)
+ {
+ std::string& recvq = GetRecvQ();
+ const std::string::size_type reqend = recvq.find("\r\n\r\n");
+ if (reqend == std::string::npos)
+ return 0;
+
+ HTTPHeaderFinder keyheader;
+ if (!keyheader.Find(recvq, "Sec-WebSocket-Key:", 18, reqend))
+ {
+ FailHandshake(sock, "HTTP/1.1 501 Not Implemented\r\nConnection: close\r\n\r\n", "WebSocket: Received HTTP request which is not a websocket upgrade");
+ return -1;
+ }
+
+ if (!*sha1)
+ {
+ FailHandshake(sock, "HTTP/1.1 503 Service Unavailable\r\nConnection: close\r\n\r\n", "WebSocket: SHA-1 provider missing");
+ return -1;
+ }
+
+ state = STATE_ESTABLISHED;
+
+ std::string key = keyheader.ExtractValue(recvq);
+ key.append(MagicGUID);
+
+ std::string reply = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
+ reply.append(BinToBase64((*sha1)->GenerateRaw(key), NULL, '=')).append("\r\n\r\n");
+ GetSendQ().push_back(StreamSocket::SendQueue::Element(reply));
+
+ SocketEngine::ChangeEventMask(sock, FD_ADD_TRIAL_WRITE);
+
+ recvq.erase(0, reqend + 4);
+
+ return 1;
+ }
+
+ public:
+ WebSocketHook(IOHookProvider* Prov, StreamSocket* sock)
+ : IOHookMiddle(Prov)
+ , state(STATE_HTTPREQ)
+ , lastpingpong(0)
+ {
+ sock->AddIOHook(this);
+ }
+
+ int OnStreamSocketWrite(StreamSocket* sock, StreamSocket::SendQueue& uppersendq) CXX11_OVERRIDE
+ {
+ StreamSocket::SendQueue& mysendq = GetSendQ();
+
+ // Return 1 to allow sending back an error HTTP response
+ if (state != STATE_ESTABLISHED)
+ return (mysendq.empty() ? 0 : 1);
+
+ if (!uppersendq.empty())
+ {
+ StreamSocket::SendQueue::Element elem = PrepareSendQElem(uppersendq.bytes(), OP_BINARY);
+ mysendq.push_back(elem);
+ mysendq.moveall(uppersendq);
+ }
+
+ return 1;
+ }
+
+ int OnStreamSocketRead(StreamSocket* sock, std::string& destrecvq) CXX11_OVERRIDE
+ {
+ if (state == STATE_HTTPREQ)
+ {
+ int httpret = HandleHTTPReq(sock);
+ if (httpret <= 0)
+ return httpret;
+ }
+
+ int wsret;
+ do
+ {
+ wsret = HandleWS(sock, destrecvq);
+ }
+ while ((!GetRecvQ().empty()) && (wsret > 0));
+
+ return wsret;
+ }
+
+ void OnStreamSocketClose(StreamSocket* sock) CXX11_OVERRIDE
+ {
+ }
+};
+
+void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
+{
+ new WebSocketHook(this, sock);
+}
+
+class ModuleWebSocket : public Module
+{
+ dynamic_reference_nocheck<HashProvider> hash;
+ WebSocketHookProvider hookprov;
+
+ public:
+ ModuleWebSocket()
+ : hash(this, "hash/sha1")
+ , hookprov(this)
+ {
+ sha1 = &hash;
+ }
+
+ void OnCleanup(int target_type, void* item) CXX11_OVERRIDE
+ {
+ if (target_type != TYPE_USER)
+ return;
+
+ LocalUser* user = IS_LOCAL(static_cast<User*>(item));
+ if ((user) && (user->eh.GetModHook(this)))
+ ServerInstance->Users.QuitUser(user, "WebSocket module unloading");
+ }
+
+ Version GetVersion() CXX11_OVERRIDE
+ {
+ return Version("Provides RFC 6455 WebSocket support", VF_VENDOR);
+ }
+};
+
+MODULE_INIT(ModuleWebSocket)
diff --git a/src/usermanager.cpp b/src/usermanager.cpp
index fe052fcfc..95deca00a 100644
--- a/src/usermanager.cpp
+++ b/src/usermanager.cpp
@@ -72,8 +72,12 @@ void UserManager::AddUser(int socket, ListenSocket* via, irc::sockets::sockaddrs
UserIOHandler* eh = &New->eh;
// If this listener has an IO hook provider set then tell it about the connection
- if (via->iohookprov)
- via->iohookprov->OnAccept(eh, client, server);
+ for (ListenSocket::IOHookProvList::iterator i = via->iohookprovs.begin(); i != via->iohookprovs.end(); ++i)
+ {
+ ListenSocket::IOHookProvRef& iohookprovref = *i;
+ if (iohookprovref)
+ iohookprovref->OnAccept(eh, client, server);
+ }
ServerInstance->Logs->Log("USERS", LOG_DEBUG, "New user fd: %d", socket);