From 99f79a4e5c3abbe91a03216824e7659051872054 Mon Sep 17 00:00:00 2001 From: Attila Molnar Date: Tue, 24 Sep 2013 20:40:20 +0200 Subject: Split IOHook into IOHook and IOHookProvider Create one IOHook instance for each hooked socket which contains all the hook specific data and read/write/close functions, removing the need for the "issl_session" array in SSL modules. Register instances of the IOHookProvider class in the core and use them to create specialized IOHook instances (OnConnect/OnAccept). Remove the OnHookIO hook, add a dynamic reference to ListenSocket that points to the hook provider (if any) to use for incoming connections on that socket. For outgoing connections modules still have to find the IOHookProvider they want to use themselves but instead of calling AddIOHook(hookprov), now they have to call IOHookProvider::OnConnect() after the connection has been established. --- src/inspsocket.cpp | 5 +- src/listensocket.cpp | 16 ++ src/modules.cpp | 1 - src/modules/extra/m_ssl_gnutls.cpp | 273 ++++++++++++--------------- src/modules/extra/m_ssl_openssl.cpp | 285 +++++++++++------------------ src/modules/m_httpd.cpp | 5 +- src/modules/m_spanningtree/main.cpp | 4 +- src/modules/m_spanningtree/treesocket1.cpp | 27 +-- src/modules/m_starttls.cpp | 9 +- src/socket.cpp | 2 + src/usermanager.cpp | 17 +- 11 files changed, 271 insertions(+), 373 deletions(-) (limited to 'src') diff --git a/src/inspsocket.cpp b/src/inspsocket.cpp index 8822f69f8..ea09a8b1d 100644 --- a/src/inspsocket.cpp +++ b/src/inspsocket.cpp @@ -134,6 +134,7 @@ void StreamSocket::Close() ServerInstance->Logs->Log("SOCKET", LOG_DEFAULT, "%s threw an exception: %s", modexcept.GetSource().c_str(), modexcept.GetReason().c_str()); } + delete iohook; DelIOHook(); } ServerInstance->SE->Shutdown(this, 2); @@ -467,9 +468,7 @@ void BufferedSocket::DoWrite() { state = I_CONNECTED; this->OnConnected(); - if (GetIOHook()) - GetIOHook()->OnStreamSocketConnect(this); - else + if (!GetIOHook()) ServerInstance->SE->ChangeEventMask(this, FD_WANT_FAST_READ | FD_WANT_EDGE_WRITE); } this->StreamSocket::DoWrite(); diff --git a/src/listensocket.cpp b/src/listensocket.cpp index 108466ae3..01bc36cc5 100644 --- a/src/listensocket.cpp +++ b/src/listensocket.cpp @@ -28,6 +28,7 @@ ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_to) : bind_tag(tag) + , iohookprov(NULL, std::string()) { irc::sockets::satoap(bind_to, bind_addr, bind_port); bind_desc = bind_to.str(); @@ -85,6 +86,8 @@ ListenSocket::ListenSocket(ConfigTag* tag, const irc::sockets::sockaddrs& bind_t { ServerInstance->SE->NonBlocking(this->fd); ServerInstance->SE->AddFd(this, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); + + this->ResetIOHookProvider(); } } @@ -214,3 +217,16 @@ void ListenSocket::HandleEvent(EventType e, int err) break; } } + +bool ListenSocket::ResetIOHookProvider() +{ + std::string provname = bind_tag->getString("ssl"); + if (!provname.empty()) + provname.insert(0, "ssl/"); + + // Set the new provider name, dynref handles the rest + iohookprov.SetProvider(provname); + + // Return true if no provider was set, or one was set and it was also found + return (provname.empty() || iohookprov); +} diff --git a/src/modules.cpp b/src/modules.cpp index 23aceb3e1..c70a99d77 100644 --- a/src/modules.cpp +++ b/src/modules.cpp @@ -154,7 +154,6 @@ void Module::OnText(User*, void*, int, const std::string&, char, CUList&) { De void Module::OnRunTestSuite() { DetachEvent(I_OnRunTestSuite); } void Module::OnNamesListItem(User*, Membership*, std::string&, std::string&) { DetachEvent(I_OnNamesListItem); } ModResult Module::OnNumeric(User*, unsigned int, const std::string&) { DetachEvent(I_OnNumeric); return MOD_RES_PASSTHRU; } -void Module::OnHookIO(StreamSocket*, ListenSocket*) { DetachEvent(I_OnHookIO); } ModResult Module::OnAcceptConnection(int, ListenSocket*, irc::sockets::sockaddrs*, irc::sockets::sockaddrs*) { DetachEvent(I_OnAcceptConnection); return MOD_RES_PASSTHRU; } void Module::OnSendWhoLine(User*, const std::vector&, User*, std::string&) { DetachEvent(I_OnSendWhoLine); } void Module::OnSetUserIP(LocalUser*) { DetachEvent(I_OnSetUserIP); } diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 7c19925dd..2add962fd 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -543,58 +543,41 @@ namespace GnuTLS }; } -/** Represents an SSL user's extra data - */ -class issl_session +class GnuTLSIOHook : public SSLIOHook { -public: - StreamSocket* socket; + private: gnutls_session_t sess; issl_status status; - reference cert; reference profile; - issl_session() : socket(NULL), sess(NULL) {} -}; - -class GnuTLSIOHook : public SSLIOHook -{ - private: void InitSession(StreamSocket* user, bool me_server) { - issl_session* session = &sessions[user->GetFd()]; - - gnutls_init(&session->sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); - session->socket = user; + gnutls_init(&sess, me_server ? GNUTLS_SERVER : GNUTLS_CLIENT); - session->profile->SetupSession(session->sess); - gnutls_transport_set_ptr(session->sess, reinterpret_cast(session)); - gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); - gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); + profile->SetupSession(sess); + gnutls_transport_set_ptr(sess, reinterpret_cast(user)); + gnutls_transport_set_push_function(sess, gnutls_push_wrapper); + gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper); if (me_server) - gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. - - Handshake(session, user); + gnutls_certificate_server_set_request(sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. } - void CloseSession(issl_session* session) + void CloseSession() { - if (session->sess) + if (this->sess) { - gnutls_bye(session->sess, GNUTLS_SHUT_WR); - gnutls_deinit(session->sess); + gnutls_bye(this->sess, GNUTLS_SHUT_WR); + gnutls_deinit(this->sess); } - session->socket = NULL; - session->sess = NULL; - session->cert = NULL; - session->profile = NULL; - session->status = ISSL_NONE; + sess = NULL; + certificate = NULL; + status = ISSL_NONE; } - bool Handshake(issl_session* session, StreamSocket* user) + bool Handshake(StreamSocket* user) { - int ret = gnutls_handshake(session->sess); + int ret = gnutls_handshake(this->sess); if (ret < 0) { @@ -602,24 +585,24 @@ class GnuTLSIOHook : public SSLIOHook { // Handshake needs resuming later, read() or write() would have blocked. - if(gnutls_record_get_direction(session->sess) == 0) + if (gnutls_record_get_direction(this->sess) == 0) { // gnutls_handshake() wants to read() again. - session->status = ISSL_HANDSHAKING_READ; + this->status = ISSL_HANDSHAKING_READ; ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); } else { // gnutls_handshake() wants to write() again. - session->status = ISSL_HANDSHAKING_WRITE; + this->status = ISSL_HANDSHAKING_WRITE; ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); } } else { user->SetError("Handshake Failed - " + std::string(gnutls_strerror(ret))); - CloseSession(session); - session->status = ISSL_CLOSING; + CloseSession(); + this->status = ISSL_CLOSING; } return false; @@ -627,9 +610,9 @@ class GnuTLSIOHook : public SSLIOHook else { // Change the seesion state - session->status = ISSL_HANDSHAKEN; + this->status = ISSL_HANDSHAKEN; - VerifyCertificate(session,user); + VerifyCertificate(); // Finish writing, if any left ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); @@ -638,12 +621,9 @@ class GnuTLSIOHook : public SSLIOHook } } - void VerifyCertificate(issl_session* session, StreamSocket* user) + void VerifyCertificate() { - if (!session->sess || !user) - return; - - unsigned int status; + unsigned int certstatus; const gnutls_datum_t* cert_list; int ret; unsigned int cert_list_size; @@ -653,12 +633,12 @@ class GnuTLSIOHook : public SSLIOHook size_t digest_size = sizeof(digest); size_t name_size = sizeof(str); ssl_cert* certinfo = new ssl_cert; - session->cert = certinfo; + this->certificate = certinfo; /* This verification function uses the trusted CAs in the credentials * structure. So you must have installed one or more CA certificates. */ - ret = gnutls_certificate_verify_peers2(session->sess, &status); + ret = gnutls_certificate_verify_peers2(this->sess, &certstatus); if (ret < 0) { @@ -666,16 +646,16 @@ class GnuTLSIOHook : public SSLIOHook return; } - certinfo->invalid = (status & GNUTLS_CERT_INVALID); - certinfo->unknownsigner = (status & GNUTLS_CERT_SIGNER_NOT_FOUND); - certinfo->revoked = (status & GNUTLS_CERT_REVOKED); - certinfo->trusted = !(status & GNUTLS_CERT_SIGNER_NOT_CA); + certinfo->invalid = (certstatus & GNUTLS_CERT_INVALID); + certinfo->unknownsigner = (certstatus & GNUTLS_CERT_SIGNER_NOT_FOUND); + certinfo->revoked = (certstatus & GNUTLS_CERT_REVOKED); + certinfo->trusted = !(certstatus & GNUTLS_CERT_SIGNER_NOT_CA); /* Up to here the process is the same for X.509 certificates and * OpenPGP keys. From now on X.509 certificates are assumed. This can * be easily extended to work with openpgp keys as well. */ - if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) + if (gnutls_certificate_type_get(this->sess) != GNUTLS_CRT_X509) { certinfo->error = "No X509 keys sent"; return; @@ -689,7 +669,7 @@ class GnuTLSIOHook : public SSLIOHook } cert_list_size = 0; - cert_list = gnutls_certificate_get_peers(session->sess, &cert_list_size); + cert_list = gnutls_certificate_get_peers(this->sess, &cert_list_size); if (cert_list == NULL) { certinfo->error = "No certificate was found"; @@ -713,7 +693,7 @@ class GnuTLSIOHook : public SSLIOHook gnutls_x509_crt_get_issuer_dn(cert, str, &name_size); certinfo->issuer = str; - if ((ret = gnutls_x509_crt_get_fingerprint(cert, session->profile->GetHash(), digest, &digest_size)) < 0) + if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0) { certinfo->error = gnutls_strerror(ret); } @@ -740,8 +720,12 @@ info_done_dealloc: static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t session_wrap, void* buffer, size_t size) { - issl_session* session = reinterpret_cast(session_wrap); - if (session->socket->GetEventMask() & FD_READ_WILL_BLOCK) + StreamSocket* sock = reinterpret_cast(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast(sock->GetIOHook()); +#endif + + if (sock->GetEventMask() & FD_READ_WILL_BLOCK) { #ifdef _WIN32 gnutls_transport_set_errno(session->sess, EAGAIN); @@ -751,7 +735,7 @@ info_done_dealloc: return -1; } - int rv = ServerInstance->SE->Recv(session->socket, reinterpret_cast(buffer), size, 0); + int rv = ServerInstance->SE->Recv(sock, reinterpret_cast(buffer), size, 0); #ifdef _WIN32 if (rv < 0) @@ -766,14 +750,18 @@ info_done_dealloc: #endif if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_READ_WILL_BLOCK); + ServerInstance->SE->ChangeEventMask(sock, FD_READ_WILL_BLOCK); return rv; } static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t session_wrap, const void* buffer, size_t size) { - issl_session* session = reinterpret_cast(session_wrap); - if (session->socket->GetEventMask() & FD_WRITE_WILL_BLOCK) + StreamSocket* sock = reinterpret_cast(session_wrap); +#ifdef _WIN32 + GnuTLSIOHook* session = static_cast(sock->GetIOHook()); +#endif + + if (sock->GetEventMask() & FD_WRITE_WILL_BLOCK) { #ifdef _WIN32 gnutls_transport_set_errno(session->sess, EAGAIN); @@ -783,7 +771,7 @@ info_done_dealloc: return -1; } - int rv = ServerInstance->SE->Send(session->socket, reinterpret_cast(buffer), size, 0); + int rv = ServerInstance->SE->Send(sock, reinterpret_cast(buffer), size, 0); #ifdef _WIN32 if (rv < 0) @@ -798,75 +786,55 @@ info_done_dealloc: #endif if (rv < (int)size) - ServerInstance->SE->ChangeEventMask(session->socket, FD_WRITE_WILL_BLOCK); + ServerInstance->SE->ChangeEventMask(sock, FD_WRITE_WILL_BLOCK); return rv; } public: - issl_session* sessions; - - GnuTLSIOHook(Module* parent) - : SSLIOHook(parent, "ssl/gnutls") - { - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; - } - - ~GnuTLSIOHook() - { - delete[] sessions; - } - - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool outbound, const reference& sslprofile) + : SSLIOHook(hookprov) + , sess(NULL) + , status(ISSL_NONE) + , profile(sslprofile) { - issl_session* session = &sessions[user->GetFd()]; - - /* For STARTTLS: Don't try and init a session on a socket that already has a session */ - if (session->sess) - return; - - InitSession(user, true); - } - - void OnStreamSocketConnect(StreamSocket* user) CXX11_OVERRIDE - { - InitSession(user, false); + InitSession(sock, outbound); + sock->AddIOHook(this); + Handshake(sock); } void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE { - CloseSession(&sessions[user->GetFd()]); + CloseSession(); } int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE { - issl_session* session = &sessions[user->GetFd()]; - - if (!session->sess) + if (!this->sess) { - CloseSession(session); + CloseSession(); user->SetError("No SSL session"); return -1; } - if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) + if (this->status == ISSL_HANDSHAKING_READ || this->status == ISSL_HANDSHAKING_WRITE) { // The handshake isn't finished, try to finish it. - if(!Handshake(session, user)) + if (!Handshake(user)) { - if (session->status != ISSL_CLOSING) + if (this->status != ISSL_CLOSING) return 0; return -1; } } - // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. + // If we resumed the handshake then this->status will be ISSL_HANDSHAKEN. - if (session->status == ISSL_HANDSHAKEN) + if (this->status == ISSL_HANDSHAKEN) { char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - int ret = gnutls_record_recv(session->sess, buffer, bufsiz); + int ret = gnutls_record_recv(this->sess, buffer, bufsiz); if (ret > 0) { recvq.append(buffer, ret); @@ -879,17 +847,17 @@ info_done_dealloc: else if (ret == 0) { user->SetError("Connection closed"); - CloseSession(session); + CloseSession(); return -1; } else { user->SetError(gnutls_strerror(ret)); - CloseSession(session); + CloseSession(); return -1; } } - else if (session->status == ISSL_CLOSING) + else if (this->status == ISSL_CLOSING) return -1; return 0; @@ -897,29 +865,27 @@ info_done_dealloc: int OnStreamSocketWrite(StreamSocket* user, std::string& sendq) CXX11_OVERRIDE { - issl_session* session = &sessions[user->GetFd()]; - - if (!session->sess) + if (!this->sess) { - CloseSession(session); + CloseSession(); user->SetError("No SSL session"); return -1; } - if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ) + if (this->status == ISSL_HANDSHAKING_WRITE || this->status == ISSL_HANDSHAKING_READ) { // The handshake isn't finished, try to finish it. - Handshake(session, user); - if (session->status != ISSL_CLOSING) + Handshake(user); + if (this->status != ISSL_CLOSING) return 0; return -1; } int ret = 0; - if (session->status == ISSL_HANDSHAKEN) + if (this->status == ISSL_HANDSHAKEN) { - ret = gnutls_record_send(session->sess, sendq.data(), sendq.length()); + ret = gnutls_record_send(this->sess, sendq.data(), sendq.length()); if (ret == (int)sendq.length()) { @@ -940,7 +906,7 @@ info_done_dealloc: else // (ret < 0) { user->SetError(gnutls_strerror(ret)); - CloseSession(session); + CloseSession(); return -1; } } @@ -948,16 +914,8 @@ info_done_dealloc: return 0; } - ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE - { - int fd = sock->GetFd(); - issl_session* session = &sessions[fd]; - return session->cert; - } - void TellCiphersAndFingerprint(LocalUser* user) { - const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess; if (sess) { std::string text = "*** You are connected using SSL cipher '"; @@ -966,13 +924,14 @@ info_done_dealloc: text.append("-").append(UnknownIfNULL(gnutls_cipher_get_name(gnutls_cipher_get(sess)))).append("-"); text.append(UnknownIfNULL(gnutls_mac_get_name(gnutls_mac_get(sess)))).append("'"); - ssl_cert* cert = sessions[user->eh.GetFd()].cert; - if (!cert->fingerprint.empty()) - text += " and your SSL fingerprint is " + cert->fingerprint; + if (!certificate->fingerprint.empty()) + text += " and your SSL fingerprint is " + certificate->fingerprint; user->WriteNotice(text); } } + + GnuTLS::Profile* GetProfile() { return profile; } }; int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_datum_t* req_ca_rdn, int nreqs, const gnutls_pk_algorithm_t* sign_algos, int sign_algos_length, cert_cb_last_param_type* st) @@ -983,8 +942,8 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d st->cert_type = GNUTLS_CRT_X509; st->key_type = GNUTLS_PRIVKEY_X509; #endif - issl_session* session = reinterpret_cast(gnutls_transport_get_ptr(sess)); - GnuTLS::X509Credentials& cred = session->profile->GetX509Credentials(); + StreamSocket* sock = reinterpret_cast(gnutls_transport_get_ptr(sess)); + GnuTLS::X509Credentials& cred = static_cast(sock->GetIOHook())->GetProfile()->GetX509Credentials(); st->ncerts = cred.certs.size(); st->cert.x509 = cred.certs.raw(); @@ -994,15 +953,41 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d return 0; } +class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider +{ + reference profile; + + public: + GnuTLSIOHookProvider(Module* mod, reference& prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); + } + + ~GnuTLSIOHookProvider() + { + ServerInstance->Modules->DelService(*this); + } + + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, true, profile); + } + + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new GnuTLSIOHook(this, sock, false, profile); + } +}; + class ModuleSSLGnuTLS : public Module { - typedef std::vector > ProfileList; + typedef std::vector > ProfileList; // First member of the class, gets constructed first and destructed last GnuTLS::Init libinit; - GnuTLSIOHook iohook; - std::string sslports; RandGen randhandler; @@ -1026,7 +1011,7 @@ class ModuleSSLGnuTLS : public Module try { reference profile(GnuTLS::Profile::Create(defname, tag)); - newprofiles.push_back(profile); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); } catch (CoreException& ex) { @@ -1057,7 +1042,7 @@ class ModuleSSLGnuTLS : public Module throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(profile); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); } // New profiles are ok, begin using them @@ -1066,7 +1051,7 @@ class ModuleSSLGnuTLS : public Module } public: - ModuleSSLGnuTLS() : iohook(this) + ModuleSSLGnuTLS() { #ifndef GNUTLS_HAS_RND gcry_control (GCRYCTL_INITIALIZATION_FINISHED, 0); @@ -1144,7 +1129,7 @@ class ModuleSSLGnuTLS : public Module { LocalUser* user = IS_LOCAL(static_cast(item)); - if (user && user->eh.GetIOHook() == &iohook) + if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. @@ -1164,27 +1149,11 @@ class ModuleSSLGnuTLS : public Module tokens["SSL"] = sslports; } - void OnHookIO(StreamSocket* user, ListenSocket* lsb) CXX11_OVERRIDE - { - if (!user->GetIOHook()) - { - std::string profilename = lsb->bind_tag->getString("ssl"); - for (ProfileList::const_iterator i = profiles.begin(); i != profiles.end(); ++i) - { - if ((*i)->GetName() == profilename) - { - iohook.sessions[user->GetFd()].profile = *i; - user->AddIOHook(&iohook); - break; - } - } - } - } - void OnUserConnect(LocalUser* user) CXX11_OVERRIDE { - if (user->eh.GetIOHook() == &iohook) - iohook.TellCiphersAndFingerprint(user); + IOHook* hook = user->eh.GetIOHook(); + if (hook && hook->prov->creator == this) + static_cast(hook)->TellCiphersAndFingerprint(user); } }; diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index 11f4a365e..962350e1c 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -235,26 +235,6 @@ namespace OpenSSL }; } -/** Represents an SSL user's extra data - */ -class issl_session -{ -public: - SSL* sess; - issl_status status; - reference cert; - reference profile; - - bool outbound; - bool data_to_write; - - issl_session() - { - outbound = false; - data_to_write = false; - } -}; - static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) { /* XXX: This will allow self signed certificates. @@ -272,34 +252,40 @@ static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) class OpenSSLIOHook : public SSLIOHook { private: - bool Handshake(StreamSocket* user, issl_session* session) + SSL* sess; + issl_status status; + const bool outbound; + bool data_to_write; + reference profile; + + bool Handshake(StreamSocket* user) { int ret; - if (session->outbound) - ret = SSL_connect(session->sess); + if (outbound) + ret = SSL_connect(sess); else - ret = SSL_accept(session->sess); + ret = SSL_accept(sess); if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_READ) { ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); - session->status = ISSL_HANDSHAKING; + this->status = ISSL_HANDSHAKING; return true; } else if (err == SSL_ERROR_WANT_WRITE) { ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_SINGLE_WRITE); - session->status = ISSL_HANDSHAKING; + this->status = ISSL_HANDSHAKING; return true; } else { - CloseSession(session); + CloseSession(); } return false; @@ -307,9 +293,9 @@ class OpenSSLIOHook : public SSLIOHook else if (ret > 0) { // Handshake complete. - VerifyCertificate(session, user); + VerifyCertificate(); - session->status = ISSL_OPEN; + status = ISSL_OPEN; ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE | FD_ADD_TRIAL_WRITE); @@ -317,38 +303,35 @@ class OpenSSLIOHook : public SSLIOHook } else if (ret == 0) { - CloseSession(session); + CloseSession(); return true; } return true; } - void CloseSession(issl_session* session) + void CloseSession() { - if (session->sess) + if (sess) { - SSL_shutdown(session->sess); - SSL_free(session->sess); + SSL_shutdown(sess); + SSL_free(sess); } - - session->sess = NULL; - session->status = ISSL_NONE; + sess = NULL; + certificate = NULL; + status = ISSL_NONE; errno = EIO; } - void VerifyCertificate(issl_session* session, StreamSocket* user) + void VerifyCertificate() { - if (!session->sess || !user) - return; - X509* cert; ssl_cert* certinfo = new ssl_cert; - session->cert = certinfo; + this->certificate = certinfo; unsigned int n; unsigned char md[EVP_MAX_MD_SIZE]; - cert = SSL_get_peer_certificate((SSL*)session->sess); + cert = SSL_get_peer_certificate(sess); if (!cert) { @@ -356,7 +339,7 @@ class OpenSSLIOHook : public SSLIOHook return; } - certinfo->invalid = (SSL_get_verify_result(session->sess) != X509_V_OK); + certinfo->invalid = (SSL_get_verify_result(sess) != X509_V_OK); if (!SelfSigned) { @@ -372,7 +355,7 @@ class OpenSSLIOHook : public SSLIOHook certinfo->dn = X509_NAME_oneline(X509_get_subject_name(cert),0,0); certinfo->issuer = X509_NAME_oneline(X509_get_issuer_name(cert),0,0); - if (!X509_digest(cert, session->profile->GetDigest(), md, &n)) + if (!X509_digest(cert, profile->GetDigest(), md, &n)) { certinfo->error = "Out of memory generating fingerprint"; } @@ -390,129 +373,73 @@ class OpenSSLIOHook : public SSLIOHook } public: - issl_session* sessions; - - OpenSSLIOHook(Module* mod) - : SSLIOHook(mod, "ssl/openssl") + OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool is_outbound, SSL* session, const reference& sslprofile) + : SSLIOHook(hookprov) + , sess(session) + , status(ISSL_NONE) + , outbound(is_outbound) + , data_to_write(false) + , profile(sslprofile) { - sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; - } - - ~OpenSSLIOHook() - { - delete[] sessions; - } - - void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE - { - int fd = user->GetFd(); - - issl_session* session = &sessions[fd]; - - session->sess = session->profile->CreateServerSession(); - session->status = ISSL_NONE; - session->outbound = false; - session->cert = NULL; - - if (session->sess == NULL) + if (sess == NULL) return; + if (SSL_set_fd(sess, sock->GetFd()) == 0) + throw ModuleException("Can't set fd with SSL_set_fd: " + ConvToStr(sock->GetFd())); - if (SSL_set_fd(session->sess, fd) == 0) - { - ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Can't set fd with SSL_set_fd: %d", fd); - return; - } - - Handshake(user, session); - } - - void OnStreamSocketConnect(StreamSocket* user) CXX11_OVERRIDE - { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() -1)) - return; - - issl_session* session = &sessions[fd]; - - session->sess = session->profile->CreateClientSession(); - session->status = ISSL_NONE; - session->outbound = true; - - if (session->sess == NULL) - return; - - if (SSL_set_fd(session->sess, fd) == 0) - { - ServerInstance->Logs->Log(MODNAME, LOG_DEBUG, "BUG: Can't set fd with SSL_set_fd: %d", fd); - return; - } - - Handshake(user, session); + sock->AddIOHook(this); + Handshake(sock); } void OnStreamSocketClose(StreamSocket* user) CXX11_OVERRIDE { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return; - - CloseSession(&sessions[fd]); + CloseSession(); } int OnStreamSocketRead(StreamSocket* user, std::string& recvq) CXX11_OVERRIDE { - int fd = user->GetFd(); - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return -1; - - issl_session* session = &sessions[fd]; - - if (!session->sess) + if (!sess) { - CloseSession(session); + CloseSession(); return -1; } - if (session->status == ISSL_HANDSHAKING) + if (status == ISSL_HANDSHAKING) { // The handshake isn't finished and it wants to read, try to finish it. - if (!Handshake(user, session)) + if (!Handshake(user)) { // Couldn't resume handshake. - if (session->status == ISSL_NONE) + if (status == ISSL_NONE) return -1; return 0; } } - // If we resumed the handshake then session->status will be ISSL_OPEN + // If we resumed the handshake then this->status will be ISSL_OPEN - if (session->status == ISSL_OPEN) + if (status == ISSL_OPEN) { char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - int ret = SSL_read(session->sess, buffer, bufsiz); + int ret = SSL_read(sess, buffer, bufsiz); if (ret > 0) { recvq.append(buffer, ret); - if (session->data_to_write) + if (data_to_write) ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_SINGLE_WRITE); return 1; } else if (ret == 0) { // Client closed connection. - CloseSession(session); + CloseSession(); user->SetError("Connection closed"); return -1; } else if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_READ) { @@ -526,7 +453,7 @@ class OpenSSLIOHook : public SSLIOHook } else { - CloseSession(session); + CloseSession(); return -1; } } @@ -537,35 +464,31 @@ class OpenSSLIOHook : public SSLIOHook int OnStreamSocketWrite(StreamSocket* user, std::string& buffer) CXX11_OVERRIDE { - int fd = user->GetFd(); - - issl_session* session = &sessions[fd]; - - if (!session->sess) + if (!sess) { - CloseSession(session); + CloseSession(); return -1; } - session->data_to_write = true; + data_to_write = true; - if (session->status == ISSL_HANDSHAKING) + if (status == ISSL_HANDSHAKING) { - if (!Handshake(user, session)) + if (!Handshake(user)) { // Couldn't resume handshake. - if (session->status == ISSL_NONE) + if (status == ISSL_NONE) return -1; return 0; } } - if (session->status == ISSL_OPEN) + if (status == ISSL_OPEN) { - int ret = SSL_write(session->sess, buffer.data(), buffer.size()); + int ret = SSL_write(sess, buffer.data(), buffer.size()); if (ret == (int)buffer.length()) { - session->data_to_write = false; + data_to_write = false; ServerInstance->SE->ChangeEventMask(user, FD_WANT_POLL_READ | FD_WANT_NO_WRITE); return 1; } @@ -577,12 +500,12 @@ class OpenSSLIOHook : public SSLIOHook } else if (ret == 0) { - CloseSession(session); + CloseSession(); return -1; } else if (ret < 0) { - int err = SSL_get_error(session->sess, ret); + int err = SSL_get_error(sess, ret); if (err == SSL_ERROR_WANT_WRITE) { @@ -596,7 +519,7 @@ class OpenSSLIOHook : public SSLIOHook } else { - CloseSession(session); + CloseSession(); return -1; } } @@ -604,20 +527,12 @@ class OpenSSLIOHook : public SSLIOHook return 0; } - ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE - { - int fd = sock->GetFd(); - issl_session* session = &sessions[fd]; - return session->cert; - } - void TellCiphersAndFingerprint(LocalUser* user) { - issl_session& s = sessions[user->eh.GetFd()]; - if (s.sess) + if (sess) { - std::string text = "*** You are connected using SSL cipher '" + std::string(SSL_get_cipher(s.sess)) + "'"; - const std::string& fingerprint = s.cert->fingerprint; + std::string text = "*** You are connected using SSL cipher '" + std::string(SSL_get_cipher(sess)) + "'"; + const std::string& fingerprint = certificate->fingerprint; if (!fingerprint.empty()) text += " and your SSL fingerprint is " + fingerprint; @@ -626,12 +541,39 @@ class OpenSSLIOHook : public SSLIOHook } }; +class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider +{ + reference profile; + + public: + OpenSSLIOHookProvider(Module* mod, reference& prof) + : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) + , profile(prof) + { + ServerInstance->Modules->AddService(*this); + } + + ~OpenSSLIOHookProvider() + { + ServerInstance->Modules->DelService(*this); + } + + void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE + { + new OpenSSLIOHook(this, sock, false, profile->CreateServerSession(), profile); + } + + void OnConnect(StreamSocket* sock) CXX11_OVERRIDE + { + new OpenSSLIOHook(this, sock, true, profile->CreateClientSession(), profile); + } +}; + class ModuleSSLOpenSSL : public Module { - typedef std::vector > ProfileList; + typedef std::vector > ProfileList; std::string sslports; - OpenSSLIOHook iohook; ProfileList profiles; void ReadProfiles() @@ -648,7 +590,7 @@ class ModuleSSLOpenSSL : public Module try { reference profile(new OpenSSL::Profile(defname, tag)); - newprofiles.push_back(profile); + newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); } catch (OpenSSL::Exception& ex) { @@ -679,14 +621,14 @@ class ModuleSSLOpenSSL : public Module throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(profile); + newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); } profiles.swap(newprofiles); } public: - ModuleSSLOpenSSL() : iohook(this) + ModuleSSLOpenSSL() { // Initialize OpenSSL SSL_library_init(); @@ -698,24 +640,6 @@ class ModuleSSLOpenSSL : public Module ReadProfiles(); } - void OnHookIO(StreamSocket* user, ListenSocket* lsb) CXX11_OVERRIDE - { - if (user->GetIOHook()) - return; - - ConfigTag* tag = lsb->bind_tag; - std::string profilename = tag->getString("ssl"); - for (ProfileList::const_iterator i = profiles.begin(); i != profiles.end(); ++i) - { - if ((*i)->GetName() == profilename) - { - iohook.sessions[user->GetFd()].profile = *i; - user->AddIOHook(&iohook); - break; - } - } - } - void ReadConfig(ConfigStatus& status) CXX11_OVERRIDE { sslports.clear(); @@ -778,8 +702,9 @@ class ModuleSSLOpenSSL : public Module void OnUserConnect(LocalUser* user) CXX11_OVERRIDE { - if (user->eh.GetIOHook() == &iohook) - iohook.TellCiphersAndFingerprint(user); + IOHook* hook = user->eh.GetIOHook(); + if (hook && hook->prov->creator == this) + static_cast(hook)->TellCiphersAndFingerprint(user); } void OnCleanup(int target_type, void* item) CXX11_OVERRIDE @@ -788,7 +713,7 @@ class ModuleSSLOpenSSL : public Module { LocalUser* user = IS_LOCAL((User*)item); - if (user && user->eh.GetIOHook() == &iohook) + if (user && user->eh.GetIOHook() && user->eh.GetIOHook()->prov->creator == this) { // User is using SSL, they're a local user, and they're using one of *our* SSL ports. // Potentially there could be multiple SSL modules loaded at once on different ports. diff --git a/src/modules/m_httpd.cpp b/src/modules/m_httpd.cpp index 735551dff..d0291b8cc 100644 --- a/src/modules/m_httpd.cpp +++ b/src/modules/m_httpd.cpp @@ -65,9 +65,8 @@ class HttpServerSocket : public BufferedSocket { InternalState = HTTP_SERVE_WAIT_REQUEST; - FOREACH_MOD(OnHookIO, (this, via)); - if (GetIOHook()) - GetIOHook()->OnStreamSocketAccept(this, client, server); + if (via->iohookprov) + via->iohookprov->OnAccept(this, client, server); } ~HttpServerSocket() diff --git a/src/modules/m_spanningtree/main.cpp b/src/modules/m_spanningtree/main.cpp index 671e10269..1782f7e2a 100644 --- a/src/modules/m_spanningtree/main.cpp +++ b/src/modules/m_spanningtree/main.cpp @@ -677,7 +677,7 @@ void ModuleSpanningTree::OnUnloadModule(Module* mod) for (TreeServer::ChildServers::const_iterator i = list.begin(); i != list.end(); ++i) { TreeSocket* sock = (*i)->GetSocket(); - if (sock && sock->GetIOHook() && sock->GetIOHook()->creator == mod) + if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod) { sock->SendError("SSL module unloaded"); sock->Close(); @@ -687,7 +687,7 @@ void ModuleSpanningTree::OnUnloadModule(Module* mod) for (SpanningTreeUtilities::TimeoutList::const_iterator i = Utils->timeoutlist.begin(); i != Utils->timeoutlist.end(); ++i) { TreeSocket* sock = i->first; - if (sock->GetIOHook() && sock->GetIOHook()->creator == mod) + if (sock->GetIOHook() && sock->GetIOHook()->prov->creator == mod) sock->Close(); } } diff --git a/src/modules/m_spanningtree/treesocket1.cpp b/src/modules/m_spanningtree/treesocket1.cpp index fa8a94f72..9c262f1ea 100644 --- a/src/modules/m_spanningtree/treesocket1.cpp +++ b/src/modules/m_spanningtree/treesocket1.cpp @@ -44,16 +44,7 @@ TreeSocket::TreeSocket(Link* link, Autoconnect* myac, const std::string& ipaddr) capab->link = link; capab->ac = myac; capab->capab_phase = 0; - if (!link->Hook.empty()) - { - ServiceProvider* prov = ServerInstance->Modules->FindService(SERVICE_IOHOOK, link->Hook); - if (!prov) - { - SetError("Could not find hook '" + link->Hook + "' for connection to " + linkID); - return; - } - AddIOHook(static_cast(prov)); - } + DoConnect(ipaddr, link->Port, link->Timeout, link->Bind); Utils->timeoutlist[this] = std::pair(linkID, link->Timeout); SendCapabilities(1); @@ -71,9 +62,8 @@ TreeSocket::TreeSocket(int newfd, ListenSocket* via, irc::sockets::sockaddrs* cl capab = new CapabData; capab->capab_phase = 0; - FOREACH_MOD(OnHookIO, (this, via)); - if (GetIOHook()) - GetIOHook()->OnStreamSocketAccept(this, client, server); + if (via->iohookprov) + via->iohookprov->OnAccept(this, client, server); SendCapabilities(1); Utils->timeoutlist[this] = std::pair(linkID, 30); @@ -116,6 +106,17 @@ void TreeSocket::OnConnected() { if (this->LinkState == CONNECTING) { + if (!capab->link->Hook.empty()) + { + ServiceProvider* prov = ServerInstance->Modules->FindService(SERVICE_IOHOOK, capab->link->Hook); + if (!prov) + { + SetError("Could not find hook '" + capab->link->Hook + "' for connection to " + linkID); + return; + } + static_cast(prov)->OnConnect(this); + } + ServerInstance->SNO->WriteGlobalSno('l', "Connection to \2%s\2[%s] started.", linkID.c_str(), (capab->link->HiddenFromStats ? "" : capab->link->IPAddr.c_str())); this->SendCapabilities(1); diff --git a/src/modules/m_starttls.cpp b/src/modules/m_starttls.cpp index 09c9b4f0f..d591eed55 100644 --- a/src/modules/m_starttls.cpp +++ b/src/modules/m_starttls.cpp @@ -30,10 +30,10 @@ enum class CommandStartTLS : public SplitCommand { - dynamic_reference_nocheck& ssl; + dynamic_reference_nocheck& ssl; public: - CommandStartTLS(Module* mod, dynamic_reference_nocheck& s) + CommandStartTLS(Module* mod, dynamic_reference_nocheck& s) : SplitCommand(mod, "STARTTLS") , ssl(s) { @@ -71,8 +71,7 @@ class CommandStartTLS : public SplitCommand */ user->eh.DoWrite(); - user->eh.AddIOHook(*ssl); - ssl->OnStreamSocketAccept(&user->eh, NULL, NULL); + ssl->OnAccept(&user->eh, NULL, NULL); return CMD_SUCCESS; } @@ -82,7 +81,7 @@ class ModuleStartTLS : public Module { CommandStartTLS starttls; GenericCap tls; - dynamic_reference_nocheck ssl; + dynamic_reference_nocheck ssl; public: ModuleStartTLS() diff --git a/src/socket.cpp b/src/socket.cpp index 4ebed1ccd..c65cd5b27 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -106,6 +106,8 @@ int InspIRCd::BindPorts(FailedPortList &failed_ports) if ((**n).bind_desc == bind_readable) { (*n)->bind_tag = tag; // Replace tag, we know addr and port match, but other info (type, ssl) may not + (*n)->ResetIOHookProvider(); + skip = true; old_ports.erase(n); break; diff --git a/src/usermanager.cpp b/src/usermanager.cpp index 745934fd4..29d1f7370 100644 --- a/src/usermanager.cpp +++ b/src/usermanager.cpp @@ -62,20 +62,9 @@ void UserManager::AddUser(int socket, ListenSocket* via, irc::sockets::sockaddrs } UserIOHandler* eh = &New->eh; - /* Give each of the modules an attempt to hook the user for I/O */ - FOREACH_MOD(OnHookIO, (eh, via)); - - if (eh->GetIOHook()) - { - try - { - eh->GetIOHook()->OnStreamSocketAccept(eh, client, server); - } - catch (CoreException& modexcept) - { - ServerInstance->Logs->Log("SOCKET", LOG_DEBUG, "%s threw an exception: %s", modexcept.GetSource().c_str(), modexcept.GetReason().c_str()); - } - } + // If this listener has an IO hook provider set then tell it about the connection + if (via->iohookprov) + via->iohookprov->OnAccept(eh, client, server); ServerInstance->Logs->Log("USERS", LOG_DEBUG, "New user fd: %d", socket); -- cgit v1.2.3