diff options
-rw-r--r-- | include/iohook.h | 12 | ||||
-rw-r--r-- | include/modules/ssl.h | 63 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 23 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_openssl.cpp | 23 | ||||
-rw-r--r-- | src/modules/m_sasl.cpp | 6 | ||||
-rw-r--r-- | src/modules/m_spanningtree/hmac.cpp | 11 | ||||
-rw-r--r-- | src/modules/m_sslinfo.cpp | 13 |
7 files changed, 93 insertions, 58 deletions
diff --git a/include/iohook.h b/include/iohook.h index 87403681d..7c3a0faee 100644 --- a/include/iohook.h +++ b/include/iohook.h @@ -24,8 +24,16 @@ class StreamSocket; class IOHook : public ServiceProvider { public: - IOHook(Module* mod, const std::string& Name) - : ServiceProvider(mod, Name, SERVICE_IOHOOK) { } + enum Type + { + IOH_UNKNOWN, + IOH_SSL + }; + + const Type type; + + IOHook(Module* mod, const std::string& Name, Type hooktype = IOH_UNKNOWN) + : ServiceProvider(mod, Name, SERVICE_IOHOOK), type(hooktype) { } /** Called immediately after any connection is accepted. This is intended for raw socket * processing (e.g. modules which wrap the tcp connection within another library) and provides diff --git a/include/modules/ssl.h b/include/modules/ssl.h index a45121537..9830b1ca6 100644 --- a/include/modules/ssl.h +++ b/include/modules/ssl.h @@ -132,20 +132,67 @@ class ssl_cert : public refcountbase } }; -/** Get certificate from a socket (only useful with an SSL module) */ -struct SocketCertificateRequest : public Request +class SSLIOHook : public IOHook { - StreamSocket* const sock; - ssl_cert* cert; + public: + SSLIOHook(Module* mod, const std::string& Name) + : IOHook(mod, Name, IOHook::IOH_SSL) + { + } + + /** + * Get the client certificate from a socket + * @param sock The socket to get the certificate from, must be using this IOHook + * @return The SSL client certificate information + */ + virtual ssl_cert* GetCertificate(StreamSocket* sock) = 0; - SocketCertificateRequest(StreamSocket* ss, Module* Me) - : Request(Me, (ss->GetIOHook() ? (Module*)ss->GetIOHook()->creator : NULL), "GET_SSL_CERT"), sock(ss), cert(NULL) + /** + * Get the fingerprint of a client certificate from a socket + * @param sock The socket to get the certificate fingerprint from, must be using this IOHook + * @return The fingerprint of the SSL client certificate sent by the peer, + * empty if no cert was sent + */ + std::string GetFingerprint(StreamSocket* sock) { - Send(); + ssl_cert* cert = GetCertificate(sock); + if (cert) + return cert->GetFingerprint(); + return ""; } +}; - std::string GetFingerprint() +/** Helper functions for obtaining SSL client certificates and key fingerprints + * from StreamSockets + */ +class SSLClientCert +{ + public: + /** + * Get the client certificate from a socket + * @param sock The socket to get the certificate from, the socket does not have to use SSL + * @return The SSL client certificate information, NULL if the peer is not using SSL + */ + static ssl_cert* GetCertificate(StreamSocket* sock) + { + IOHook* iohook = sock->GetIOHook(); + if ((!iohook) || (iohook->type != IOHook::IOH_SSL)) + return NULL; + + SSLIOHook* ssliohook = static_cast<SSLIOHook*>(iohook); + return ssliohook->GetCertificate(sock); + } + + /** + * Get the fingerprint of a client certificate from a socket + * @param sock The socket to get the certificate fingerprint from, the + * socket does not have to use SSL + * @return The key fingerprint from the SSL certificate sent by the peer, + * empty if no cert was sent or the peer is not using SSL + */ + static std::string GetFingerprint(StreamSocket* sock) { + ssl_cert* cert = SSLClientCert::GetCertificate(sock); if (cert) return cert->GetFingerprint(); return ""; diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index e051b34e7..3c82a5beb 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -100,7 +100,7 @@ public: issl_session() : socket(NULL), sess(NULL) {} }; -class GnuTLSIOHook : public IOHook +class GnuTLSIOHook : public SSLIOHook { private: void InitSession(StreamSocket* user, bool me_server) @@ -359,7 +359,7 @@ info_done_dealloc: int dh_bits; GnuTLSIOHook(Module* parent) - : IOHook(parent, "ssl/gnutls") + : SSLIOHook(parent, "ssl/gnutls") { sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; } @@ -501,6 +501,13 @@ info_done_dealloc: return 0; } + ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE + { + int fd = sock->GetFd(); + issl_session* session = &sessions[fd]; + return session->cert; + } + void TellCiphersAndFingerprint(LocalUser* user) { const gnutls_session_t& sess = sessions[user->eh.GetFd()].sess; @@ -895,18 +902,6 @@ class ModuleSSLGnuTLS : public Module } } - void OnRequest(Request& request) CXX11_OVERRIDE - { - if (strcmp("GET_SSL_CERT", request.id) == 0) - { - SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request); - int fd = req.sock->GetFd(); - issl_session* session = &iohook.sessions[fd]; - - req.cert = session->cert; - } - } - void OnUserConnect(LocalUser* user) CXX11_OVERRIDE { if (user->eh.GetIOHook() == &iohook) diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index 0c7362e6e..53c0ab875 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -101,7 +101,7 @@ static int OnVerify(int preverify_ok, X509_STORE_CTX *ctx) return 1; } -class OpenSSLIOHook : public IOHook +class OpenSSLIOHook : public SSLIOHook { private: bool Handshake(StreamSocket* user, issl_session* session) @@ -229,7 +229,7 @@ class OpenSSLIOHook : public IOHook bool use_sha; OpenSSLIOHook(Module* mod) - : IOHook(mod, "ssl/openssl") + : SSLIOHook(mod, "ssl/openssl") { sessions = new issl_session[ServerInstance->SE->GetMaxFds()]; } @@ -440,6 +440,13 @@ class OpenSSLIOHook : public IOHook return 0; } + ssl_cert* GetCertificate(StreamSocket* sock) CXX11_OVERRIDE + { + int fd = sock->GetFd(); + issl_session* session = &sessions[fd]; + return session->cert; + } + void TellCiphersAndFingerprint(LocalUser* user) { issl_session& s = sessions[user->eh.GetFd()]; @@ -653,18 +660,6 @@ class ModuleSSLOpenSSL : public Module { return Version("Provides SSL support for clients", VF_VENDOR); } - - void OnRequest(Request& request) CXX11_OVERRIDE - { - if (strcmp("GET_SSL_CERT", request.id) == 0) - { - SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request); - int fd = req.sock->GetFd(); - issl_session* session = &iohook.sessions[fd]; - - req.cert = session->cert; - } - } }; static int error_callback(const char *str, size_t len, void *u) diff --git a/src/modules/m_sasl.cpp b/src/modules/m_sasl.cpp index 322a726ce..45915ab4d 100644 --- a/src/modules/m_sasl.cpp +++ b/src/modules/m_sasl.cpp @@ -63,10 +63,10 @@ class SaslAuthenticator params.push_back("S"); params.push_back(method); - if (method == "EXTERNAL" && IS_LOCAL(user_)) + LocalUser* localuser = IS_LOCAL(user); + if (method == "EXTERNAL" && localuser) { - SocketCertificateRequest req(&((LocalUser*)user_)->eh, ServerInstance->Modules->Find("m_sasl.so")); - std::string fp = req.GetFingerprint(); + std::string fp = SSLClientCert::GetFingerprint(&localuser->eh); if (fp.size()) params.push_back(fp); diff --git a/src/modules/m_spanningtree/hmac.cpp b/src/modules/m_spanningtree/hmac.cpp index ad632dbc7..0b96f9b26 100644 --- a/src/modules/m_spanningtree/hmac.cpp +++ b/src/modules/m_spanningtree/hmac.cpp @@ -69,16 +69,6 @@ bool TreeSocket::ComparePass(const Link& link, const std::string &theirs) capab->auth_fingerprint = !link.Fingerprint.empty(); capab->auth_challenge = !capab->ourchallenge.empty() && !capab->theirchallenge.empty(); - std::string fp; - if (GetIOHook()) - { - SocketCertificateRequest req(this, Utils->Creator); - if (req.cert) - { - fp = req.cert->GetFingerprint(); - } - } - if (capab->auth_challenge) { std::string our_hmac = MakePass(link.RecvPass, capab->ourchallenge); @@ -94,6 +84,7 @@ bool TreeSocket::ComparePass(const Link& link, const std::string &theirs) return false; } + std::string fp = SSLClientCert::GetFingerprint(this); if (capab->auth_fingerprint) { /* Require fingerprint to exist and match */ diff --git a/src/modules/m_sslinfo.cpp b/src/modules/m_sslinfo.cpp index 8cdaa1cde..5516af7ef 100644 --- a/src/modules/m_sslinfo.cpp +++ b/src/modules/m_sslinfo.cpp @@ -191,10 +191,9 @@ class ModuleSSLInfo : public Module void OnUserConnect(LocalUser* user) CXX11_OVERRIDE { - SocketCertificateRequest req(&user->eh, this); - if (!req.cert) - return; - cmd.CertExt.set(user, req.cert); + ssl_cert* cert = SSLClientCert::GetCertificate(&user->eh); + if (cert) + cmd.CertExt.set(user, cert); } void OnPostConnect(User* user) CXX11_OVERRIDE @@ -214,15 +213,15 @@ class ModuleSSLInfo : public Module ModResult OnSetConnectClass(LocalUser* user, ConnectClass* myclass) CXX11_OVERRIDE { - SocketCertificateRequest req(&user->eh, this); + ssl_cert* cert = SSLClientCert::GetCertificate(&user->eh); bool ok = true; if (myclass->config->getString("requiressl") == "trusted") { - ok = (req.cert && req.cert->IsCAVerified()); + ok = (cert && cert->IsCAVerified()); } else if (myclass->config->getBool("requiressl")) { - ok = (req.cert != NULL); + ok = (cert != NULL); } if (!ok) |