summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAttila Molnar <attilamolnar@hush.com>2016-08-08 14:31:49 +0200
committerAttila Molnar <attilamolnar@hush.com>2016-08-08 14:31:49 +0200
commiteef55acb1dbb2ae6c0202fec54e12506c064f892 (patch)
treed7e4e3e313053294192715d9f7bca1b426b8d232
parent3a11a742ba35155e1b2e14dc4ef1a4f7f659ea13 (diff)
Add StreamSocket::GetModHook() for obtaining the IOHook belonging to a given module
Use it to simplify logic in all modules using or providing IOHooks
-rw-r--r--include/inspsocket.h6
-rw-r--r--src/inspsocket.cpp10
-rw-r--r--src/modules/extra/m_ssl_gnutls.cpp23
-rw-r--r--src/modules/extra/m_ssl_mbedtls.cpp12
-rw-r--r--src/modules/extra/m_ssl_openssl.cpp12
-rw-r--r--src/modules/m_httpd.cpp2
-rw-r--r--src/modules/m_spanningtree/main.cpp4
7 files changed, 38 insertions, 31 deletions
diff --git a/include/inspsocket.h b/include/inspsocket.h
index 53eca2e91..72fb03d58 100644
--- a/include/inspsocket.h
+++ b/include/inspsocket.h
@@ -284,6 +284,12 @@ class CoreExport StreamSocket : public EventHandler
virtual void Close();
/** This ensures that close is called prior to destructor */
virtual CullResult cull();
+
+ /** Get the IOHook of a module attached to this socket
+ * @param mod Module whose IOHook to return
+ * @return IOHook belonging to the module or NULL if the module haven't attached an IOHook to this socket
+ */
+ IOHook* GetModHook(Module* mod) const;
};
/**
* BufferedSocket is an extendable socket class which modules
diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp
index dcc455482..0b0507f7c 100644
--- a/src/inspsocket.cpp
+++ b/src/inspsocket.cpp
@@ -434,3 +434,13 @@ void StreamSocket::CheckError(BufferedSocketError errcode)
OnError(errcode);
}
}
+
+IOHook* StreamSocket::GetModHook(Module* mod) const
+{
+ if (iohook)
+ {
+ if (iohook->prov->creator == mod)
+ return iohook;
+ }
+ return NULL;
+}
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp
index 44a49d895..dfd3b47dd 100644
--- a/src/modules/extra/m_ssl_gnutls.cpp
+++ b/src/modules/extra/m_ssl_gnutls.cpp
@@ -101,6 +101,8 @@ typedef gnutls_connection_end_t inspircd_gnutls_session_init_flags_t;
#define INSPIRCD_GNUTLS_HAS_CORK
#endif
+static Module* thismod;
+
class RandGen : public HandlerBase2<void, char*, size_t>
{
public:
@@ -917,7 +919,7 @@ info_done_dealloc:
{
StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
#ifdef _WIN32
- GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
#endif
if (sock->GetEventMask() & FD_READ_WILL_BLOCK)
@@ -954,7 +956,7 @@ info_done_dealloc:
{
StreamSocket* sock = reinterpret_cast<StreamSocket*>(transportptr);
#ifdef _WIN32
- GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
#endif
if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -989,7 +991,7 @@ info_done_dealloc:
{
StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap);
#ifdef _WIN32
- GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook());
+ GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod));
#endif
if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK)
@@ -1172,7 +1174,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
st->key_type = GNUTLS_PRIVKEY_X509;
#endif
StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
- GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials();
+ GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials();
st->ncerts = cred.certs.size();
st->cert.x509 = cred.certs.raw();
@@ -1282,6 +1284,7 @@ class ModuleSSLGnuTLS : public Module
#ifndef GNUTLS_HAS_RND
gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0);
#endif
+ thismod = this;
}
void init() CXX11_OVERRIDE
@@ -1317,7 +1320,7 @@ class ModuleSSLGnuTLS : public Module
{
LocalUser* user = IS_LOCAL(static_cast<User*>(item));
- if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+ if ((user) && (user->eh.GetModHook(this)))
{
// User is using SSL, they're a local user, and they're using one of *our* SSL ports.
// Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -1333,13 +1336,9 @@ class ModuleSSLGnuTLS : public Module
ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
{
- if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
- {
- GnuTLSIOHook* iohook = static_cast<GnuTLSIOHook*>(user->eh.GetIOHook());
- if (!iohook->IsHandshakeDone())
- return MOD_RES_DENY;
- }
-
+ const GnuTLSIOHook* const iohook = static_cast<GnuTLSIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
return MOD_RES_PASSTHRU;
}
};
diff --git a/src/modules/extra/m_ssl_mbedtls.cpp b/src/modules/extra/m_ssl_mbedtls.cpp
index 7efcce72d..845d02aa3 100644
--- a/src/modules/extra/m_ssl_mbedtls.cpp
+++ b/src/modules/extra/m_ssl_mbedtls.cpp
@@ -894,7 +894,7 @@ class ModuleSSLmbedTLS : public Module
return;
LocalUser* user = IS_LOCAL(static_cast<User*>(item));
- if ((user) && (user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
+ if ((user) && (user->eh.GetModHook(this)))
{
// User is using SSL, they're a local user, and they're using our IOHook.
// Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -904,13 +904,9 @@ class ModuleSSLmbedTLS : public Module
ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
{
- if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
- {
- mbedTLSIOHook* iohook = static_cast<mbedTLSIOHook*>(user->eh.GetIOHook());
- if (!iohook->IsHandshakeDone())
- return MOD_RES_DENY;
- }
-
+ const mbedTLSIOHook* const iohook = static_cast<mbedTLSIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
return MOD_RES_PASSTHRU;
}
diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp
index 5587f323a..4ad556438 100644
--- a/src/modules/extra/m_ssl_openssl.cpp
+++ b/src/modules/extra/m_ssl_openssl.cpp
@@ -909,7 +909,7 @@ class ModuleSSLOpenSSL : public Module
{
LocalUser* user = IS_LOCAL((User*)item);
- if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this)
+ if ((user) && (user->eh.GetModHook(this)))
{
// User is using SSL, they're a local user, and they're using one of *our* SSL ports.
// Potentially there could be multiple SSL modules loaded at once on different ports.
@@ -920,13 +920,9 @@ class ModuleSSLOpenSSL : public Module
ModResult OnCheckReady(LocalUser* user) CXX11_OVERRIDE
{
- if ((user->eh.GetIOHook()) && (user->eh.GetIOHook()->prov->creator == this))
- {
- OpenSSLIOHook* iohook = static_cast<OpenSSLIOHook*>(user->eh.GetIOHook());
- if (!iohook->IsHandshakeDone())
- return MOD_RES_DENY;
- }
-
+ const OpenSSLIOHook* const iohook = static_cast<OpenSSLIOHook*>(user->eh.GetModHook(this));
+ if ((iohook) && (!iohook->IsHandshakeDone()))
+ return MOD_RES_DENY;
return MOD_RES_PASSTHRU;
}
diff --git a/src/modules/m_httpd.cpp b/src/modules/m_httpd.cpp
index 760647d47..0b6b2e32b 100644
--- a/src/modules/m_httpd.cpp
+++ b/src/modules/m_httpd.cpp
@@ -413,7 +413,7 @@ class ModuleHttpServer : public Module
{
HttpServerSocket* sock = *i;
++i;
- if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+ if (sock->GetModHook(mod))
{
sock->cull();
delete sock;
diff --git a/src/modules/m_spanningtree/main.cpp b/src/modules/m_spanningtree/main.cpp
index 0b9bb65df..81543b0da 100644
--- a/src/modules/m_spanningtree/main.cpp
+++ b/src/modules/m_spanningtree/main.cpp
@@ -635,7 +635,7 @@ restart:
for (TreeServer::ChildServers::const_iterator i = list.begin(); i != list.end(); ++i)
{
TreeSocket* sock = (*i)->GetSocket();
- if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+ if (sock->GetModHook(mod))
{
sock->SendError("SSL module unloaded");
sock->Close();
@@ -647,7 +647,7 @@ restart:
for (SpanningTreeUtilities::TimeoutList::const_iterator i = Utils->timeoutlist.begin(); i != Utils->timeoutlist.end(); ++i)
{
TreeSocket* sock = i->first;
- if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod)
+ if (sock->GetModHook(mod))
sock->Close();
}
}