diff options
author | Attila Molnar <attilamolnar@hush.com> | 2013-09-24 20:40:20 +0200 |
---|---|---|
committer | Attila Molnar <attilamolnar@hush.com> | 2014-01-22 19:10:01 +0100 |
commit | 99f79a4e5c3abbe91a03216824e7659051872054 (patch) | |
tree | 629ed4d4cccb115e95f53c582047bc239d213624 | |
parent | 282138ad0e9ef483ec2a1606376fc5cb6d5f4cbc (diff) |
Split IOHook into IOHook and IOHookProvider
Create one IOHook instance for each hooked socket which contains all the
hook specific data and read/write/close functions, removing the need for
the "issl_session" array in SSL modules.
Register instances of the IOHookProvider class in the core and use them to
create specialized IOHook instances (OnConnect/OnAccept).
Remove the OnHookIO hook, add a dynamic reference to ListenSocket that
points to the hook provider (if any) to use for incoming connections on
that socket.
For outgoing connections modules still have to find the IOHookProvider
they want to use themselves but instead of calling AddIOHook(hookprov),
now they have to call IOHookProvider::OnConnect() after the connection
has been established.
-rw-r--r-- | include/iohook.h | 34 | ||||
-rw-r--r-- | include/modules.h | 8 | ||||
-rw-r--r-- | include/modules/ssl.h | 30 | ||||
-rw-r--r-- | include/socket.h | 13 | ||||
-rw-r--r-- | src/inspsocket.cpp | 5 | ||||
-rw-r--r-- | src/listensocket.cpp | 16 | ||||
-rw-r--r-- | src/modules.cpp | 1 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 273 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_openssl.cpp | 285 | ||||
-rw-r--r-- | src/modules/m_httpd.cpp | 5 | ||||
-rw-r--r-- | src/modules/m_spanningtree/main.cpp | 4 | ||||
-rw-r--r-- | src/modules/m_spanningtree/treesocket1.cpp | 27 | ||||
-rw-r--r-- | src/modules/m_starttls.cpp | 9 | ||||
-rw-r--r-- | src/socket.cpp | 2 | ||||
-rw-r--r-- | src/usermanager.cpp | 17 |
15 files changed, 325 insertions, 404 deletions
diff --git a/include/iohook.h b/include/iohook.h index 7c3a0faee..ce7ca2a1b 100644 --- a/include/iohook.h +++ b/include/iohook.h @@ -21,7 +21,7 @@ class StreamSocket; -class IOHook : public ServiceProvider +class IOHookProvider : public ServiceProvider { public: enum Type @@ -32,19 +32,35 @@ class IOHook : public ServiceProvider const Type type; - IOHook(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN) + IOHookProvider(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN) : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { } - /** Called immediately after any connection is accepted. This is intended for raw socket + /** Called immediately after a connection is accepted. This is intended for raw socket * processing (e.g. modules which wrap the tcp connection within another library) and provides * no information relating to a user record as the connection has not been assigned yet. - * There are no return values from this call as all modules get an opportunity if required to - * process the connection. * @param sock The socket in question * @param client The client IP address and port * @param server The server IP address and port */ - virtual void OnStreamSocketAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) = 0; + virtual void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) = 0; + + /** Called immediately upon connection of an outbound BufferedSocket which has been hooked + * by a module. + * @param sock The socket in question + */ + virtual void OnConnect(StreamSocket* sock) = 0; +}; + +class IOHook : public classbase +{ + public: + /** The IOHookProvider for this hook, contains information about the hook, + * such as the module providing it and the hook type. + */ + IOHookProvider* const prov; + + IOHook(IOHookProvider* provider) + : prov(provider) { } /** * Called when a hooked stream has data to write, or when the socket @@ -62,12 +78,6 @@ class IOHook : public ServiceProvider */ virtual void OnStreamSocketClose(StreamSocket* sock) = 0; - /** Called immediately upon connection of an outbound BufferedSocket which has been hooked - * by a module. - * @param sock The socket in question - */ - virtual void OnStreamSocketConnect(StreamSocket* sock) = 0; - /** * Called when the stream socket has data to read * @param sock The socket that is ready diff --git a/include/modules.h b/include/modules.h index 0be1ea294..7223f6b9d 100644 --- a/include/modules.h +++ b/include/modules.h @@ -264,7 +264,7 @@ enum Implementation I_OnChangeLocalUserGECOS, I_OnUserRegister, I_OnChannelPreDelete, I_OnChannelDelete, I_OnPostOper, I_OnSyncNetwork, I_OnSetAway, I_OnPostCommand, I_OnPostJoin, I_OnWhoisLine, I_OnBuildNeighborList, I_OnGarbageCollect, I_OnSetConnectClass, - I_OnText, I_OnPassCompare, I_OnRunTestSuite, I_OnNamesListItem, I_OnNumeric, I_OnHookIO, + I_OnText, I_OnPassCompare, I_OnRunTestSuite, I_OnNamesListItem, I_OnNumeric, I_OnPreRehash, I_OnModuleRehash, I_OnSendWhoLine, I_OnChangeIdent, I_OnSetUserIP, I_END }; @@ -989,12 +989,6 @@ class CoreExport Module : public classbase, public usecountbase */ virtual void OnPostConnect(User* user); - /** Called to install an I/O hook on an event handler - * @param user The socket to possibly install the I/O hook on - * @param via The port that the user connected on - */ - virtual void OnHookIO(StreamSocket* user, ListenSocket* via); - /** Called when a port accepts a connection * Return MOD_RES_ACCEPT if you have used the file descriptor. * @param fd The file descriptor returned from accept() diff --git a/include/modules/ssl.h b/include/modules/ssl.h index 25076215a..0f58e0b7b 100644 --- a/include/modules/ssl.h +++ b/include/modules/ssl.h @@ -133,28 +133,34 @@ class ssl_cert : public refcountbase class SSLIOHook : public IOHook { + protected: + /** Peer SSL certificate, set by the SSL module + */ + reference<ssl_cert> certificate; + public: - SSLIOHook(Module* mod, const std::string& Name) - : IOHook(mod, Name, IOHook::IOH_SSL) + SSLIOHook(IOHookProvider* hookprov) + : IOHook(hookprov) { } /** - * Get the client certificate from a socket - * @param sock The socket to get the certificate from, must be using this IOHook - * @return The SSL client certificate information + * Get the certificate sent by this peer + * @return The SSL certificate sent by the peer, NULL if no cert was sent */ - virtual ssl_cert* GetCertificate(StreamSocket* sock) = 0; + ssl_cert* GetCertificate() const + { + return certificate; + } /** - * Get the fingerprint of a client certificate from a socket - * @param sock The socket to get the certificate fingerprint from, must be using this IOHook + * Get the fingerprint of the peer's certificate * @return The fingerprint of the SSL client certificate sent by the peer, * empty if no cert was sent */ - std::string GetFingerprint(StreamSocket* sock) + std::string GetFingerprint() const { - ssl_cert* cert = GetCertificate(sock); + ssl_cert* cert = GetCertificate(); if (cert) return cert->GetFingerprint(); return ""; @@ -175,11 +181,11 @@ class SSLClientCert static ssl_cert* GetCertificate(StreamSocket* sock) { IOHook* iohook = sock->GetIOHook(); - if ((!iohook) || (iohook->type != IOHook::IOH_SSL)) + if ((!iohook) || (iohook->prov->type != IOHookProvider::IOH_SSL)) return NULL; SSLIOHook* ssliohook = static_cast<SSLIOHook*>(iohook); - return ssliohook->GetCertificate(sock); + return ssliohook->GetCertificate(); } /** diff --git a/include/socket.h b/include/socket.h index c54517a76..c292b7010 100644 --- a/include/socket.h +++ b/include/socket.h @@ -127,6 +127,7 @@ namespace irc } } +#include "iohook.h" #include "socketengine.h" /** This class handles incoming connections on client ports. * It will create a new User for every valid connection @@ -140,6 +141,12 @@ class CoreExport ListenSocket : public EventHandler int bind_port; /** Human-readable bind description */ std::string bind_desc; + + /** The IOHook provider which handles connections on this socket, + * NULL if there is none. + */ + dynamic_reference_nocheck<IOHookProvider> iohookprov; + /** Create a new listening socket */ ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_to); @@ -153,4 +160,10 @@ class CoreExport ListenSocket : public EventHandler /** Handles sockets internals crap of a connection, convenience wrapper really */ void AcceptInternal(); + + /** Inspects the bind block belonging to this socket to set the name of the IO hook + * provider which this socket will use for incoming connections. + * @return True if the IO hook provider was found or none was given, false otherwise. + */ + bool ResetIOHookProvider(); }; diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp index 8822f69f8..ea09a8b1d 100644 --- a/src/inspsocket.cpp +++ b/src/inspsocket.cpp @@ -134,6 +134,7 @@ void StreamSocket::Close() ServerInstance->Logs->Log("SOCKET", LOG_DEFAULT, "%s threw an exception: %s", modexcept.GetSource().c_str(), modexcept.GetReason().c_str()); } + delete iohook; DelIOHook(); } ServerInstance->SE->Shutdown(this, 2); @@ -467,9 +468,7 @@ void BufferedSocket::DoWrite() { state = I_CONNECTED; this->OnConnected(); - if (GetIOHook()) - GetIOHook()->OnStreamSocketConnect(this); - else + if (!GetIOHook()) ServerInstance->SE->ChangeEventMask(this, FD_WANT_FAST_READ | FD_WANT_EDGE_WRITE); } this->StreamSocket::DoWrite(); diff --git a/src/listensocket.cpp b/src/listensocket.cpp index 108466ae3..01bc36cc5 100644 --- a/src/listensocket.cpp +++ b/src/listensocket.cpp @@ -28,6 +28,7 @@ ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_to) : bind_tag(tag) + , iohookprov(NULL, std::string()) { irc::sockets::satoap(bind_to, bind_addr, bind_port); bind_desc = bind_to.str(); @@ -85,6 +86,8 @@ ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_t { ServerInstance->SE->NonBlocking(this->fd); ServerInstance->SE->AddFd(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + + this->ResetIOHookProvider(); } } @@ -214,3 +217,16 @@ void ListenSocket::HandleEvent(EventType e, int err) break; } } + +bool ListenSocket::ResetIOHookProvider() +{ + std::string provname = bind_tag->getString("ssl"); + if (!provname.empty()) + provname.insert(0, "ssl/"); + + // Set the new provider name, dynref handles the rest + iohookprov.SetProvider(provname); + + // Return true if no provider was set, or one was set and it was also found + return (provname.empty() || iohookprov); +} diff --git a/src/modules.cpp b/src/modules.cpp index 23aceb3e1..c70a99d77 100644 --- a/src/modules.cpp +++ b/src/modules.cpp @@ -154,7 +154,6 @@ void Module::OnText(User*, void*, int, const std::string&, char, CUList&) { De void Module::OnRunTestSuite() { DetachEvent(I_OnRunTestSuite); } void Module::OnNamesListItem(User*, Membership*, std::string&, std::string&) { DetachEvent(I_OnNamesListItem); } ModResult Module::OnNumeric(User*, unsigned int, const std::string&) { DetachEvent(I_OnNumeric); return MOD_RES_PASSTHRU; } -void Module::OnHookIO(StreamSocket*, ListenSocket*) { DetachEvent(I_OnHookIO); } ModResult Module::OnAcceptConnection(int, ListenSocket*, irc::sockets::sockaddrs*, irc::sockets::sockaddrs*) { DetachEvent(I_OnAcceptConnection); return MOD_RES_PASSTHRU; } void Module::OnSendWhoLine(User*, const std::vector<std::string>&, User*, std::string&) { DetachEvent(I_OnSendWhoLine); } void Module::OnSetUserIP(LocalUser*) { DetachEvent(I_OnSetUserIP); } diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 7c19925dd..2add962fd 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -543,58 +543,41 @@ namespace GnuTLS }; } -/** Represents an SSL user's extra data - */ -class issl_session +class GnuTLSIOHook : public SSLIOHook { -public: - StreamSocket* socket; + private: gnutls_session_t sess; issl_status status; - reference<ssl_cert> cert; reference<GnuTLS::Profile> profile; - issl_session() : socket(NULL), sess(NULL) {} -}; - -class GnuTLSIOHook : public SSLIOHook -{ - private: void InitSession(StreamSocket* user, bool me_server) { - issl_session* session = &sessions[user->GetFd()]; - - gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); - session->socket = user; + gnutls_init(&sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); - session->profile->SetupSession(session->sess); - gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(session)); - gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); - gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); + profile->SetupSession(sess); + gnutls_transport_set_ptr(sess, reinterpret_cast<gnutls_transport_ptr_t>(user)); + gnutls_transport_set_push_function(sess, gnutls_push_wrapper); + gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper); if (me_server) - gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. - - Handshake(session, user); + gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. } - void CloseSession(issl_session* session) + void CloseSession() { - if (session->sess) + if (this->sess) { - gnutls_bye(session->sess, GNUTLS_SHUT_WR); - gnutls_deinit(session->sess); + gnutls_bye(this->sess, GNUTLS_SHUT_WR); + gnutls_deinit(this->sess); } - session->socket = NULL; - session->sess = NULL; - session->cert = NULL; - session->profile = NULL; - session->status = ISSL_NONE; + sess = NULL; + certificate = NULL; + status = ISSL_NONE; } - bool Handshake(issl_session* session, StreamSocket* user) + bool Handshake(StreamSocket* user) { - int ret = gnutls_handshake(session->sess); + int ret = gnutls_handshake(this->sess); if (ret < 0) { @@ -602,24 +585,24 @@ class GnuTLSIOHook : public SSLIOHook { // Handshake needs resuming later, read() or write() would have blocked. - if(gnutls_record_get_direction(session->sess) == 0) + if (gnutls_record_get_direction(this->sess) == 0) { // gnutls_handshake() wants to read() again. - session->status = ISSL_HANDSHAKING_READ; + this->status = ISSL_HANDSHAKING_READ; ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); } else { // gnutls_handshake() wants to write() again. - session->status = ISSL_HANDSHAKING_WRITE; + this->status = ISSL_HANDSHAKING_WRITE; ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); } } else { user->SetError("Handshake Failed - " + std::string(gnutls_strerror(ret))); - CloseSession(session); - session->status = ISSL_CLOSING; + CloseSession(); + this->status = ISSL_CLOSING; } return false; @@ -627,9 +610,9 @@ class GnuTLSIOHook : public SSLIOHook else { // Change the seesion state - session->status = ISSL_HANDSHAKEN; + this->status = ISSL_HANDSHAKEN; - VerifyCertificate(session,user); + VerifyCertificate(); // Finish writing, if any left ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); @@ -638,12 +621,9 @@ class GnuTLSIOHook : public SSLIOHook } } - void VerifyCertificate(issl_session* session, StreamSocket* user) + void VerifyCertificate() { - if (!session->sess || !user) - return; - - unsigned int status; + unsigned int certstatus; const gnutls_datum_t* cert_list; int ret; unsigned int cert_list_size; @@ -653,12 +633,12 @@ class GnuTLSIOHook : public SSLIOHook size_t digest_size = sizeof(digest); size_t name_size = sizeof(str); ssl_cert* certinfo = new ssl_cert; - session->cert = certinfo; + this->certificate = certinfo; /* This verification function uses the trusted CAs in the credentials * structure. So you must have installed one or more CA certificates. */ - ret = gnutls_certificate_verify_peers2(session->sess, &status); + ret = gnutls_certificate_verify_peers2(this->sess, &certstatus); if (ret < 0) { @@ -666,16 +646,16 @@ class GnuTLSIOHook : public SSLIOHook return; } - certinfo->invalid = (status & GNUTLS_CERT_INVALID); - certinfo->unknownsigner = (status & GNUTLS_CERT_SIGNER_NOT_FOUND); - certinfo->revoked = (status & GNUTLS_CERT_REVOKED); - certinfo->trusted = !(status & GNUTLS_CERT_SIGNER_NOT_CA); + certinfo->invalid = (certstatus & GNUTLS_CERT_INVALID); + certinfo->unknownsigner = (certstatus & GNUTLS_CERT_SIGNER_NOT_FOUND); + certinfo->revoked = (certstatus & GNUTLS_CERT_REVOKED); + certinfo->trusted = !(certstatus & GNUTLS_CERT_SIGNER_NOT_CA); /* Up to here the process is the same for X.509 certificates and * OpenPGP keys. From now on X.509 certificates are assumed. This can * be easily extended to work with openpgp keys as well. */ - if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) + if (gnutls_certificate_type_get(this->sess) != GNUTLS_CRT_X509) { certinfo->error = "No X509 keys sent"; return; @@ -689,7 +669,7 @@ class GnuTLSIOHook : public SSLIOHook } cert_list_size = 0; - cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size); + cert_list = gnutls_certificate_get_peers(this->sess, &cert_list_size); if (cert_list == NULL) { certinfo->error = "No certificate was found"; @@ -713,7 +693,7 @@ class GnuTLSIOHook : public SSLIOHook gnutls_x509_crt_get_issuer_dn(cert, str, &name_size); certinfo->issuer = str; - if ((ret = gnutls_x509_crt_get_fingerprint(cert, session->profile->GetHash(), digest, &digest_size)) < 0) + if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0) { certinfo->error = gnutls_strerror(ret); } @@ -740,8 +720,12 @@ info_done_dealloc: static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size) { - issl_session* session = reinterpret_cast<issl_session*>(session_wrap); - if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK) + StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook()); +#endif + + if (sock->GetEventMask() & FD_READ_WILL_BLOCK) { #ifdef _WIN32 gnutls_transport_set_errno(session->sess, EAGAIN); @@ -751,7 +735,7 @@ info_done_dealloc: return -1; } - int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast<char *>(buffer), size, 0); + int rv = ServerInstance->SE->Recv(sock, reinterpret_cast<char *>(buffer), size, 0); #ifdef _WIN32 if (rv < 0) @@ -766,14 +750,18 @@ info_done_dealloc: #endif if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK); + ServerInstance->SE->ChangeEventMask(sock, FD_READ_WILL_BLOCK); return rv; } static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size) { - issl_session* session = reinterpret_cast<issl_session*>(session_wrap); - if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK) + StreamSocket* sock = reinterpret_cast<StreamSocket*>(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast<GnuTLSIOHook*>(sock->GetIOHook()); +#endif + + if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) { #ifdef _WIN32 gnutls_transport_set_errno(session->sess, EAGAIN); @@ -783,7 +771,7 @@ info_done_dealloc: return -1; } - int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast<const char *>(buffer), size, 0); + int rv = ServerInstance->SE->Send(sock, reinterpret_cast<const char *>(buffer), size, 0); #ifdef _WIN32 if (rv < 0) @@ -798,75 +786,55 @@ info_done_dealloc: #endif if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK); + ServerInstance->SE->ChangeEventMask(sock, FD_WRITE_WILL_BLOCK); return rv; } public: - issl_session* sessions; - - GnuTLSIOHook(Module* parent) - : SSLIOHook(parent, "ssl/gnutls") - { - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; - } - - ~GnuTLSIOHook() - { - delete[] sessions; - } - - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool outbound, const reference<GnuTLS::Profile>& sslprofile) + : SSLIOHook(hookprov) + , sess(NULL) + , status(ISSL_NONE) + , profile(sslprofile) { - issl_session* session = &sessions[user->GetFd()]; - - /* For STARTTLS: Don't try and init a session on a socket that already has a session */ - if (session->sess) - return; - - InitSession(user, true); - } - - void OnStreamSocketConnect(StreamSocket* user) CXX11_OVERRIDE - { - InitSession(user, false); + InitSession(sock, outbound); + sock->AddIOHook(this); + Handshake(sock); } void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE { - CloseSession(&sessions[user->GetFd()]); + CloseSession(); } int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE { - issl_session* session = &sessions[user->GetFd()]; - - if (!session->sess) + if (!this->sess) { - CloseSession(session); + CloseSession(); user->SetError("No SSL session"); return -1; } - if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) + if (this->status == ISSL_HANDSHAKING_READ || this->status == ISSL_HANDSHAKING_WRITE) { // The handshake isn't finished, try to finish it. - if(!Handshake(session, user)) + if (!Handshake(user)) { - if (session->status != ISSL_CLOSING) + if (this->status != ISSL_CLOSING) return 0; return -1; } } - // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. + // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN. - if (session->status == ISSL_HANDSHAKEN) + if (this->status == ISSL_HANDSHAKEN) { char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - int ret = gnutls_record_recv(session->sess, buffer, bufsiz); + int ret = gnutls_record_recv(this->sess, buffer, bufsiz); if (ret > 0) { recvq.append(buffer, ret); @@ -879,17 +847,17 @@ info_done_dealloc: else if (ret == 0) { user->SetError("Connection closed"); - CloseSession(session); + CloseSession(); return -1; } else { user->SetError(gnutls_strerror(ret)); - CloseSession(session); + CloseSession(); return -1; } } - else if (session->status == ISSL_CLOSING) + else if (this->status == ISSL_CLOSING) return -1; return 0; @@ -897,29 +865,27 @@ info_done_dealloc: int OnStreamSocketWrite(StreamSocket* user, std::string& sendq) CXX11_OVERRIDE { - issl_session* session = &sessions[user->GetFd()]; - - if (!session->sess) + if (!this->sess) { - CloseSession(session); + CloseSession(); user->SetError("No SSL session"); return -1; } - if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ) + if (this->status == ISSL_HANDSHAKING_WRITE || this->status == ISSL_HANDSHAKING_READ) { // The handshake isn't finished, try to finish it. - Handshake(session, user); - if (session->status != ISSL_CLOSING) + Handshake(user); + if (this->status != ISSL_CLOSING) return 0; return -1; } int ret = 0; - if (session->status == ISSL_HANDSHAKEN) + if (this->status == ISSL_HANDSHAKEN) { - ret = gnutls_record_send(session->sess, sendq.data(), sendq.length()); + ret = gnutls_record_send(this->sess, sendq.data(), sendq.length()); if (ret == (int)sendq.length()) { @@ -940,7 +906,7 @@ info_done_dealloc: else // (ret < 0) { user->SetError(gnutls_strerror(ret)); - CloseSession(session); + CloseSession(); return -1; } } @@ -948,16 +914,8 @@ info_done_dealloc: return 0; } - ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE - { - int fd = sock->GetFd(); - issl_session* session = &sessions[fd]; - return session->cert; - } - void TellCiphersAndFingerprint(LocalUser* user) { - const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess; if (sess) { std::string text = "*** You are connected using SSL cipher '"; @@ -966,13 +924,14 @@ info_done_dealloc: text.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-"); text.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess)))).append("'"); - ssl_cert* cert = sessions[user->eh.GetFd()].cert; - if (!cert->fingerprint.empty()) - text += " and your SSL fingerprint is " + cert->fingerprint; + if (!certificate->fingerprint.empty()) + text += " and your SSL fingerprint is " + certificate->fingerprint; user->WriteNotice(text); } } + + GnuTLS::Profile* GetProfile() { return profile; } }; int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st) @@ -983,8 +942,8 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d st->cert_type = GNUTLS_CRT_X509; st->key_type = GNUTLS_PRIVKEY_X509; #endif - issl_session* session = reinterpret_cast<issl_session*>(gnutls_transport_get_ptr(sess)); - GnuTLS::X509Credentials& cred = session->profile->GetX509Credentials(); + StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess)); + GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetIOHook())->GetProfile()->GetX509Credentials(); st->ncerts = cred.certs.size(); st->cert.x509 = cred.certs.raw(); @@ -994,15 +953,41 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d return 0; } +class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider +{ + reference<GnuTLS::Profile> profile; + + public: + GnuTLSIOHookProvider(Module* mod, reference<GnuTLS::Profile>& prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); + } + + ~GnuTLSIOHookProvider() + { + ServerInstance->Modules->DelService(*this); + } + + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, true, profile); + } + + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, false, profile); + } +}; + class ModuleSSLGnuTLS : public Module { - typedef std::vector<reference<GnuTLS::Profile> > ProfileList; + typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList; // First member of the class, gets constructed first and destructed last GnuTLS::Init libinit; - GnuTLSIOHook iohook; - std::string sslports; RandGen randhandler; @@ -1026,7 +1011,7 @@ class ModuleSSLGnuTLS : public Module try { reference<GnuTLS::Profile> profile(GnuTLS::Profile::Create(defname, tag)); - newprofiles.push_back(profile); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); } catch (CoreException& ex) { @@ -1057,7 +1042,7 @@ class ModuleSSLGnuTLS : public Module throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(profile); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); } // New profiles are ok, begin using them @@ -1066,7 +1051,7 @@ class ModuleSSLGnuTLS : public Module } public: - ModuleSSLGnuTLS() : iohook(this) + ModuleSSLGnuTLS() { #ifndef GNUTLS_HAS_RND gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0); @@ -1144,7 +1129,7 @@ class ModuleSSLGnuTLS : public Module { LocalUser* user = IS_LOCAL(static_cast<User*>(item)); - if (user && user->eh.GetIOHook() == &iohook) + if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. @@ -1164,27 +1149,11 @@ class ModuleSSLGnuTLS : public Module tokens["SSL"] = sslports; } - void OnHookIO(StreamSocket* user, ListenSocket* lsb) CXX11_OVERRIDE - { - if (!user->GetIOHook()) - { - std::string profilename = lsb->bind_tag->getString("ssl"); - for (ProfileList::const_iterator i = profiles.begin(); i != profiles.end(); ++i) - { - if ((*i)->GetName() == profilename) - { - iohook.sessions[user->GetFd()].profile = *i; - user->AddIOHook(&iohook); - break; - } - } - } - } - void OnUserConnect(LocalUser* user) CXX11_OVERRIDE { - if (user->eh.GetIOHook() == &iohook) - iohook.TellCiphersAndFingerprint(user); + IOHook* hook = user->eh.GetIOHook(); + if (hook && hook->prov->creator == this) + static_cast<GnuTLSIOHook*>(hook)->TellCiphersAndFingerprint(user); } }; diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index 11f4a365e..962350e1c 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -235,26 +235,6 @@ namespace OpenSSL }; } -/** Represents an SSL user's extra data - */ -class issl_session -{ -public: - SSL* sess; - issl_status status; - reference<ssl_cert> cert; - reference<OpenSSL::Profile> profile; - - bool outbound; - bool data_to_write; - - issl_session() - { - outbound = false; - data_to_write = false; - } -}; - static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) { /* XXX: This will allow self signed certificates. @@ -272,34 +252,40 @@ static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) class OpenSSLIOHook : public SSLIOHook { private: - bool Handshake(StreamSocket* user, issl_session* session) + SSL* sess; + issl_status status; + const bool outbound; + bool data_to_write; + reference<OpenSSL::Profile> profile; + + bool Handshake(StreamSocket* user) { int ret; - if (session->outbound) - ret = SSL_connect(session->sess); + if (outbound) + ret = SSL_connect(sess); else - ret = SSL_accept(session->sess); + ret = SSL_accept(sess); if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_READ) { ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); - session->status = ISSL_HANDSHAKING; + this->status = ISSL_HANDSHAKING; return true; } else if (err == SSL_ERROR_WANT_WRITE) { ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); - session->status = ISSL_HANDSHAKING; + this->status = ISSL_HANDSHAKING; return true; } else { - CloseSession(session); + CloseSession(); } return false; @@ -307,9 +293,9 @@ class OpenSSLIOHook : public SSLIOHook else if (ret > 0) { // Handshake complete. - VerifyCertificate(session, user); + VerifyCertificate(); - session->status = ISSL_OPEN; + status = ISSL_OPEN; ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); @@ -317,38 +303,35 @@ class OpenSSLIOHook : public SSLIOHook } else if (ret == 0) { - CloseSession(session); + CloseSession(); return true; } return true; } - void CloseSession(issl_session* session) + void CloseSession() { - if (session->sess) + if (sess) { - SSL_shutdown(session->sess); - SSL_free(session->sess); + SSL_shutdown(sess); + SSL_free(sess); } - - session->sess = NULL; - session->status = ISSL_NONE; + sess = NULL; + certificate = NULL; + status = ISSL_NONE; errno = EIO; } - void VerifyCertificate(issl_session* session, StreamSocket* user) + void VerifyCertificate() { - if (!session->sess || !user) - return; - X509* cert; ssl_cert* certinfo = new ssl_cert; - session->cert = certinfo; + this->certificate = certinfo; unsigned int n; unsigned char md[EVP_MAX_MD_SIZE]; - cert = SSL_get_peer_certificate((SSL*)session->sess); + cert = SSL_get_peer_certificate(sess); if (!cert) { @@ -356,7 +339,7 @@ class OpenSSLIOHook : public SSLIOHook return; } - certinfo->invalid = (SSL_get_verify_result(session->sess) != X509_V_OK); + certinfo->invalid = (SSL_get_verify_result(sess) != X509_V_OK); if (!SelfSigned) { @@ -372,7 +355,7 @@ class OpenSSLIOHook : public SSLIOHook certinfo->dn = X509_NAME_oneline(X509_get_subject_name(cert),0,0); certinfo->issuer = X509_NAME_oneline(X509_get_issuer_name(cert),0,0); - if (!X509_digest(cert, session->profile->GetDigest(), md, &n)) + if (!X509_digest(cert, profile->GetDigest(), md, &n)) { certinfo->error = "Out of memory generating fingerprint"; } @@ -390,129 +373,73 @@ class OpenSSLIOHook : public SSLIOHook } public: - issl_session* sessions; - - OpenSSLIOHook(Module* mod) - : SSLIOHook(mod, "ssl/openssl") + OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool is_outbound, SSL* session, const reference<OpenSSL::Profile>& sslprofile) + : SSLIOHook(hookprov) + , sess(session) + , status(ISSL_NONE) + , outbound(is_outbound) + , data_to_write(false) + , profile(sslprofile) { - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; - } - - ~OpenSSLIOHook() - { - delete[] sessions; - } - - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE - { - int fd = user->GetFd(); - - issl_session* session = &sessions[fd]; - - session->sess = session->profile->CreateServerSession(); - session->status = ISSL_NONE; - session->outbound = false; - session->cert = NULL; - - if (session->sess == NULL) + if (sess == NULL) return; + if (SSL_set_fd(sess, sock->GetFd()) == 0) + throw ModuleException("Can't set fd with SSL_set_fd: " + ConvToStr(sock->GetFd())); - if (SSL_set_fd(session->sess, fd) == 0) - { - ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Can't set fd with SSL_set_fd: %d", fd); - return; - } - - Handshake(user, session); - } - - void OnStreamSocketConnect(StreamSocket* user) CXX11_OVERRIDE - { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() -1)) - return; - - issl_session* session = &sessions[fd]; - - session->sess = session->profile->CreateClientSession(); - session->status = ISSL_NONE; - session->outbound = true; - - if (session->sess == NULL) - return; - - if (SSL_set_fd(session->sess, fd) == 0) - { - ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Can't set fd with SSL_set_fd: %d", fd); - return; - } - - Handshake(user, session); + sock->AddIOHook(this); + Handshake(sock); } void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return; - - CloseSession(&sessions[fd]); + CloseSession(); } int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return -1; - - issl_session* session = &sessions[fd]; - - if (!session->sess) + if (!sess) { - CloseSession(session); + CloseSession(); return -1; } - if (session->status == ISSL_HANDSHAKING) + if (status == ISSL_HANDSHAKING) { // The handshake isn't finished and it wants to read, try to finish it. - if (!Handshake(user, session)) + if (!Handshake(user)) { // Couldn't resume handshake. - if (session->status == ISSL_NONE) + if (status == ISSL_NONE) return -1; return 0; } } - // If we resumed the handshake then session->status will be ISSL_OPEN + // If we resumed the handshake then this->status will be ISSL_OPEN - if (session->status == ISSL_OPEN) + if (status == ISSL_OPEN) { char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - int ret = SSL_read(session->sess, buffer, bufsiz); + int ret = SSL_read(sess, buffer, bufsiz); if (ret > 0) { recvq.append(buffer, ret); - if (session->data_to_write) + if (data_to_write) ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_SINGLE_WRITE); return 1; } else if (ret == 0) { // Client closed connection. - CloseSession(session); + CloseSession(); user->SetError("Connection closed"); return -1; } else if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_READ) { @@ -526,7 +453,7 @@ class OpenSSLIOHook : public SSLIOHook } else { - CloseSession(session); + CloseSession(); return -1; } } @@ -537,35 +464,31 @@ class OpenSSLIOHook : public SSLIOHook int OnStreamSocketWrite(StreamSocket* user, std::string& buffer) CXX11_OVERRIDE { - int fd = user->GetFd(); - - issl_session* session = &sessions[fd]; - - if (!session->sess) + if (!sess) { - CloseSession(session); + CloseSession(); return -1; } - session->data_to_write = true; + data_to_write = true; - if (session->status == ISSL_HANDSHAKING) + if (status == ISSL_HANDSHAKING) { - if (!Handshake(user, session)) + if (!Handshake(user)) { // Couldn't resume handshake. - if (session->status == ISSL_NONE) + if (status == ISSL_NONE) return -1; return 0; } } - if (session->status == ISSL_OPEN) + if (status == ISSL_OPEN) { - int ret = SSL_write(session->sess, buffer.data(), buffer.size()); + int ret = SSL_write(sess, buffer.data(), buffer.size()); if (ret == (int)buffer.length()) { - session->data_to_write = false; + data_to_write = false; ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); return 1; } @@ -577,12 +500,12 @@ class OpenSSLIOHook : public SSLIOHook } else if (ret == 0) { - CloseSession(session); + CloseSession(); return -1; } else if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_WRITE) { @@ -596,7 +519,7 @@ class OpenSSLIOHook : public SSLIOHook } else { - CloseSession(session); + CloseSession(); return -1; } } @@ -604,20 +527,12 @@ class OpenSSLIOHook : public SSLIOHook return 0; } - ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE - { - int fd = sock->GetFd(); - issl_session* session = &sessions[fd]; - return session->cert; - } - void TellCiphersAndFingerprint(LocalUser* user) { - issl_session& s = sessions[user->eh.GetFd()]; - if (s.sess) + if (sess) { - std::string text = "*** You are connected using SSL cipher '" + std::string(SSL_get_cipher(s.sess)) + "'"; - const std::string& fingerprint = s.cert->fingerprint; + std::string text = "*** You are connected using SSL cipher '" + std::string(SSL_get_cipher(sess)) + "'"; + const std::string& fingerprint = certificate->fingerprint; if (!fingerprint.empty()) text += " and your SSL fingerprint is " + fingerprint; @@ -626,12 +541,39 @@ class OpenSSLIOHook : public SSLIOHook } }; +class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider +{ + reference<OpenSSL::Profile> profile; + + public: + OpenSSLIOHookProvider(Module* mod, reference<OpenSSL::Profile>& prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); + } + + ~OpenSSLIOHookProvider() + { + ServerInstance->Modules->DelService(*this); + } + + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new OpenSSLIOHook(this, sock, false, profile->CreateServerSession(), profile); + } + + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new OpenSSLIOHook(this, sock, true, profile->CreateClientSession(), profile); + } +}; + class ModuleSSLOpenSSL : public Module { - typedef std::vector<reference<OpenSSL::Profile> > ProfileList; + typedef std::vector<reference<OpenSSLIOHookProvider> > ProfileList; std::string sslports; - OpenSSLIOHook iohook; ProfileList profiles; void ReadProfiles() @@ -648,7 +590,7 @@ class ModuleSSLOpenSSL : public Module try { reference<OpenSSL::Profile> profile(new OpenSSL::Profile(defname, tag)); - newprofiles.push_back(profile); + newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); } catch (OpenSSL::Exception& ex) { @@ -679,14 +621,14 @@ class ModuleSSLOpenSSL : public Module throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(profile); + newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); } profiles.swap(newprofiles); } public: - ModuleSSLOpenSSL() : iohook(this) + ModuleSSLOpenSSL() { // Initialize OpenSSL SSL_library_init(); @@ -698,24 +640,6 @@ class ModuleSSLOpenSSL : public Module ReadProfiles(); } - void OnHookIO(StreamSocket* user, ListenSocket* lsb) CXX11_OVERRIDE - { - if (user->GetIOHook()) - return; - - ConfigTag* tag = lsb->bind_tag; - std::string profilename = tag->getString("ssl"); - for (ProfileList::const_iterator i = profiles.begin(); i != profiles.end(); ++i) - { - if ((*i)->GetName() == profilename) - { - iohook.sessions[user->GetFd()].profile = *i; - user->AddIOHook(&iohook); - break; - } - } - } - void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { sslports.clear(); @@ -778,8 +702,9 @@ class ModuleSSLOpenSSL : public Module void OnUserConnect(LocalUser* user) CXX11_OVERRIDE { - if (user->eh.GetIOHook() == &iohook) - iohook.TellCiphersAndFingerprint(user); + IOHook* hook = user->eh.GetIOHook(); + if (hook && hook->prov->creator == this) + static_cast<OpenSSLIOHook*>(hook)->TellCiphersAndFingerprint(user); } void OnCleanup(int target_type, void* item) CXX11_OVERRIDE @@ -788,7 +713,7 @@ class ModuleSSLOpenSSL : public Module { LocalUser* user = IS_LOCAL((User*)item); - if (user && user->eh.GetIOHook() == &iohook) + if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. diff --git a/src/modules/m_httpd.cpp b/src/modules/m_httpd.cpp index 735551dff..d0291b8cc 100644 --- a/src/modules/m_httpd.cpp +++ b/src/modules/m_httpd.cpp @@ -65,9 +65,8 @@ class HttpServerSocket : public BufferedSocket { InternalState = HTTP_SERVE_WAIT_REQUEST; - FOREACH_MOD(OnHookIO, (this, via)); - if (GetIOHook()) - GetIOHook()->OnStreamSocketAccept(this, client, server); + if (via->iohookprov) + via->iohookprov->OnAccept(this, client, server); } ~HttpServerSocket() diff --git a/src/modules/m_spanningtree/main.cpp b/src/modules/m_spanningtree/main.cpp index 671e10269..1782f7e2a 100644 --- a/src/modules/m_spanningtree/main.cpp +++ b/src/modules/m_spanningtree/main.cpp @@ -677,7 +677,7 @@ void ModuleSpanningTree::OnUnloadModule(Module* mod) for (TreeServer::ChildServers::const_iterator i = list.begin(); i != list.end(); ++i) { TreeSocket* sock = (*i)->GetSocket(); - if (sock && sock->GetIOHook() && sock->GetIOHook()->creator == mod) + if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod) { sock->SendError("SSL module unloaded"); sock->Close(); @@ -687,7 +687,7 @@ void ModuleSpanningTree::OnUnloadModule(Module* mod) for (SpanningTreeUtilities::TimeoutList::const_iterator i = Utils->timeoutlist.begin(); i != Utils->timeoutlist.end(); ++i) { TreeSocket* sock = i->first; - if (sock->GetIOHook() && sock->GetIOHook()->creator == mod) + if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod) sock->Close(); } } diff --git a/src/modules/m_spanningtree/treesocket1.cpp b/src/modules/m_spanningtree/treesocket1.cpp index fa8a94f72..9c262f1ea 100644 --- a/src/modules/m_spanningtree/treesocket1.cpp +++ b/src/modules/m_spanningtree/treesocket1.cpp @@ -44,16 +44,7 @@ TreeSocket::TreeSocket(Link* link, Autoconnect* myac, const std::string& ipaddr) capab->link = link; capab->ac = myac; capab->capab_phase = 0; - if (!link->Hook.empty()) - { - ServiceProvider* prov = ServerInstance->Modules->FindService(SERVICE_IOHOOK, link->Hook); - if (!prov) - { - SetError("Could not find hook '" + link->Hook + "' for connection to " + linkID); - return; - } - AddIOHook(static_cast<IOHook*>(prov)); - } + DoConnect(ipaddr, link->Port, link->Timeout, link->Bind); Utils->timeoutlist[this] = std::pair<std::string, int>(linkID, link->Timeout); SendCapabilities(1); @@ -71,9 +62,8 @@ TreeSocket::TreeSocket(int newfd, ListenSocket* via, irc::sockets::sockaddrs* cl capab = new CapabData; capab->capab_phase = 0; - FOREACH_MOD(OnHookIO, (this, via)); - if (GetIOHook()) - GetIOHook()->OnStreamSocketAccept(this, client, server); + if (via->iohookprov) + via->iohookprov->OnAccept(this, client, server); SendCapabilities(1); Utils->timeoutlist[this] = std::pair<std::string, int>(linkID, 30); @@ -116,6 +106,17 @@ void TreeSocket::OnConnected() { if (this->LinkState == CONNECTING) { + if (!capab->link->Hook.empty()) + { + ServiceProvider* prov = ServerInstance->Modules->FindService(SERVICE_IOHOOK, capab->link->Hook); + if (!prov) + { + SetError("Could not find hook '" + capab->link->Hook + "' for connection to " + linkID); + return; + } + static_cast<IOHookProvider*>(prov)->OnConnect(this); + } + ServerInstance->SNO->WriteGlobalSno('l', "Connection to \2%s\2[%s] started.", linkID.c_str(), (capab->link->HiddenFromStats ? "<hidden>" : capab->link->IPAddr.c_str())); this->SendCapabilities(1); diff --git a/src/modules/m_starttls.cpp b/src/modules/m_starttls.cpp index 09c9b4f0f..d591eed55 100644 --- a/src/modules/m_starttls.cpp +++ b/src/modules/m_starttls.cpp @@ -30,10 +30,10 @@ enum class CommandStartTLS : public SplitCommand { - dynamic_reference_nocheck<IOHook>& ssl; + dynamic_reference_nocheck<IOHookProvider>& ssl; public: - CommandStartTLS(Module* mod, dynamic_reference_nocheck<IOHook>& s) + CommandStartTLS(Module* mod, dynamic_reference_nocheck<IOHookProvider>& s) : SplitCommand(mod, "STARTTLS") , ssl(s) { @@ -71,8 +71,7 @@ class CommandStartTLS : public SplitCommand */ user->eh.DoWrite(); - user->eh.AddIOHook(*ssl); - ssl->OnStreamSocketAccept(&user->eh, NULL, NULL); + ssl->OnAccept(&user->eh, NULL, NULL); return CMD_SUCCESS; } @@ -82,7 +81,7 @@ class ModuleStartTLS : public Module { CommandStartTLS starttls; GenericCap tls; - dynamic_reference_nocheck<IOHook> ssl; + dynamic_reference_nocheck<IOHookProvider> ssl; public: ModuleStartTLS() diff --git a/src/socket.cpp b/src/socket.cpp index 4ebed1ccd..c65cd5b27 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -106,6 +106,8 @@ int InspIRCd::BindPorts(FailedPortList &failed_ports) if ((**n).bind_desc == bind_readable) { (*n)->bind_tag = tag; // Replace tag, we know addr and port match, but other info (type, ssl) may not + (*n)->ResetIOHookProvider(); + skip = true; old_ports.erase(n); break; diff --git a/src/usermanager.cpp b/src/usermanager.cpp index 745934fd4..29d1f7370 100644 --- a/src/usermanager.cpp +++ b/src/usermanager.cpp @@ -62,20 +62,9 @@ void UserManager::AddUser(int socket, ListenSocket* via, irc::sockets::sockaddrs } UserIOHandler* eh = &New->eh; - /* Give each of the modules an attempt to hook the user for I/O */ - FOREACH_MOD(OnHookIO, (eh, via)); - - if (eh->GetIOHook()) - { - try - { - eh->GetIOHook()->OnStreamSocketAccept(eh, client, server); - } - catch (CoreException& modexcept) - { - ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "%s threw an exception: %s", modexcept.GetSource().c_str(), modexcept.GetReason().c_str()); - } - } + // If this listener has an IO hook provider set then tell it about the connection + if (via->iohookprov) + via->iohookprov->OnAccept(eh, client, server); ServerInstance->Logs->Log("USERS", LOG_DEBUG, "New user fd: %d", socket); |