diff options
author | Attila Molnar <attilamolnar@hush.com> | 2017-01-10 20:21:57 +0100 |
---|---|---|
committer | Peter Powell <petpow@saberuk.com> | 2017-11-13 16:38:30 +0000 |
commit | 451e687f681ccab5c02a8de1a7d59b324efbfe08 (patch) | |
tree | a748afd720e4e4c1ae2d6a8717022e4d18f43d3b /src/modules/extra | |
parent | 0fd2d50fcf1bcff107d6d185aad5d8e9245d4141 (diff) |
Unite SSL service providers and SSL profile classes
Diffstat (limited to 'src/modules/extra')
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 145 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_mbedtls.cpp | 175 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_openssl.cpp | 50 |
3 files changed, 215 insertions, 155 deletions
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index 50c847ee4..534c3abbc 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -566,7 +566,7 @@ namespace GnuTLS int ret() const { return retval; } }; - class Profile : public refcountbase + class Profile { /** Name of this profile */ @@ -596,22 +596,6 @@ namespace GnuTLS */ const bool requestclientcert; - Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr, - std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr, - const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL, - unsigned int recsize, bool Requestclientcert) - : name(profilename) - , x509cred(certstr, keystr) - , min_dh_bits(mindh) - , hash(hashstr) - , priority(priostr) - , outrecsize(recsize) - , requestclientcert(Requestclientcert) - { - x509cred.SetDH(DH); - x509cred.SetCA(CA, CRL); - } - static std::string ReadFile(const std::string& filename) { FileReader reader(filename); @@ -647,42 +631,66 @@ namespace GnuTLS } public: - static reference<Profile> Create(const std::string& profilename, ConfigTag* tag) + struct Config { - std::string certstr = ReadFile(tag->getString("certfile", "cert.pem")); - std::string keystr = ReadFile(tag->getString("keyfile", "key.pem")); + std::string name; - std::auto_ptr<DHParams> dh = DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem"))); - - std::string priostr = GetPrioStr(profilename, tag); - unsigned int mindh = tag->getInt("mindhbits", 1024); - std::string hashstr = tag->getString("hash", "md5"); - - // Load trusted CA and revocation list, if set std::auto_ptr<X509CertList> ca; std::auto_ptr<X509CRL> crl; - std::string filename = tag->getString("cafile"); - if (!filename.empty()) - { - ca.reset(new X509CertList(ReadFile(filename))); - filename = tag->getString("crlfile"); + std::string certstr; + std::string keystr; + std::auto_ptr<DHParams> dh; + + std::string priostr; + unsigned int mindh; + std::string hashstr; + + unsigned int outrecsize; + bool requestclientcert; + + Config(const std::string& profilename, ConfigTag* tag) + : name(profilename) + , certstr(ReadFile(tag->getString("certfile", "cert.pem"))) + , keystr(ReadFile(tag->getString("keyfile", "key.pem"))) + , dh(DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem")))) + , priostr(GetPrioStr(profilename, tag)) + , mindh(tag->getInt("mindhbits", 1024)) + , hashstr(tag->getString("hash", "md5")) + , requestclientcert(tag->getBool("requestclientcert", true)) + { + // Load trusted CA and revocation list, if set + std::string filename = tag->getString("cafile"); if (!filename.empty()) - crl.reset(new X509CRL(ReadFile(filename))); - } + { + ca.reset(new X509CertList(ReadFile(filename))); + + filename = tag->getString("crlfile"); + if (!filename.empty()) + crl.reset(new X509CRL(ReadFile(filename))); + } #ifdef INSPIRCD_GNUTLS_HAS_CORK - // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked - unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512); + // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked + outrecsize = tag->getInt("outrecsize", 2048, 512); #else - unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384); + outrecsize = tag->getInt("outrecsize", 2048, 512, 16384); #endif + } + }; - const bool requestclientcert = tag->getBool("requestclientcert", true); - - return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert); + Profile(Config& config) + : name(config.name) + , x509cred(config.certstr, config.keystr) + , min_dh_bits(config.mindh) + , hash(config.hashstr) + , priority(config.priostr) + , outrecsize(config.outrecsize) + , requestclientcert(config.requestclientcert) + { + x509cred.SetDH(config.dh); + x509cred.SetCA(config.ca, config.crl); } - /** Set up the given session with the settings in this profile */ void SetupSession(gnutls_session_t sess) @@ -708,7 +716,6 @@ class GnuTLSIOHook : public SSLIOHook private: gnutls_session_t sess; issl_status status; - reference<GnuTLS::Profile> profile; #ifdef INSPIRCD_GNUTLS_HAS_CORK size_t gbuffersize; #endif @@ -855,7 +862,7 @@ class GnuTLSIOHook : public SSLIOHook issuer.clear(); } - if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0) + if ((ret = gnutls_x509_crt_get_fingerprint(cert, GetProfile().GetHash(), digest, &digest_size)) < 0) { certinfo->error = gnutls_strerror(ret); } @@ -1043,11 +1050,10 @@ info_done_dealloc: #endif // INSPIRCD_GNUTLS_HAS_VECTOR_PUSH public: - GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags, const reference<GnuTLS::Profile>& sslprofile) + GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags) : SSLIOHook(hookprov) , sess(NULL) , status(ISSL_NONE) - , profile(sslprofile) #ifdef INSPIRCD_GNUTLS_HAS_CORK , gbuffersize(0) #endif @@ -1060,7 +1066,7 @@ info_done_dealloc: gnutls_transport_set_push_function(sess, gnutls_push_wrapper); #endif gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper); - profile->SetupSession(sess); + GetProfile().SetupSession(sess); sock->AddIOHook(this); Handshake(sock); @@ -1132,7 +1138,7 @@ info_done_dealloc: // GnuTLS buffer is empty but sendq is not, begin sending data from the sendq gnutls_record_cork(this->sess); - while ((!sendq.empty()) && (gbuffersize < profile->GetOutgoingRecordSize())) + while ((!sendq.empty()) && (gbuffersize < GetProfile().GetOutgoingRecordSize())) { const StreamSocket::SendQueue::Element& elem = sendq.front(); gbuffersize += elem.length(); @@ -1150,7 +1156,7 @@ info_done_dealloc: while (!sendq.empty()) { - FlattenSendQueue(sendq, profile->GetOutgoingRecordSize()); + FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize()); const StreamSocket::SendQueue::Element& buffer = sendq.front(); ret = HandleWriteRet(user, gnutls_record_send(this->sess, buffer.data(), buffer.length())); @@ -1201,7 +1207,7 @@ info_done_dealloc: return true; } - GnuTLS::Profile* GetProfile() { return profile; } + GnuTLS::Profile& GetProfile(); bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); } }; @@ -1214,7 +1220,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d st->key_type = GNUTLS_PRIVKEY_X509; #endif StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess)); - GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials(); + GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile().GetX509Credentials(); st->ncerts = cred.certs.size(); st->cert.x509 = cred.certs.raw(); @@ -1224,14 +1230,14 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d return 0; } -class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider +class GnuTLSIOHookProvider : public IOHookProvider { - reference<GnuTLS::Profile> profile; + GnuTLS::Profile profile; public: - GnuTLSIOHookProvider(Module* mod, reference<GnuTLS::Profile>& prof) - : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) - , profile(prof) + GnuTLSIOHookProvider(Module* mod, GnuTLS::Profile::Config& config) + : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL) + , profile(config) { ServerInstance->Modules->AddService(*this); } @@ -1243,15 +1249,23 @@ class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE { - new GnuTLSIOHook(this, sock, GNUTLS_SERVER, profile); + new GnuTLSIOHook(this, sock, GNUTLS_SERVER); } void OnConnect(StreamSocket* sock) CXX11_OVERRIDE { - new GnuTLSIOHook(this, sock, GNUTLS_CLIENT, profile); + new GnuTLSIOHook(this, sock, GNUTLS_CLIENT); } + + GnuTLS::Profile& GetProfile() { return profile; } }; +GnuTLS::Profile& GnuTLSIOHook::GetProfile() +{ + IOHookProvider* hookprov = prov; + return static_cast<GnuTLSIOHookProvider*>(hookprov)->GetProfile(); +} + class ModuleSSLGnuTLS : public Module { typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList; @@ -1278,8 +1292,8 @@ class ModuleSSLGnuTLS : public Module try { - reference<GnuTLS::Profile> profile(GnuTLS::Profile::Create(defname, tag)); - newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); + GnuTLS::Profile::Config profileconfig(defname, tag); + newprofiles.push_back(new GnuTLSIOHookProvider(this, profileconfig)); } catch (CoreException& ex) { @@ -1300,21 +1314,28 @@ class ModuleSSLGnuTLS : public Module continue; } - reference<GnuTLS::Profile> profile; + reference<GnuTLSIOHookProvider> prov; try { - profile = GnuTLS::Profile::Create(name, tag); + GnuTLS::Profile::Config profileconfig(name, tag); + prov = new GnuTLSIOHookProvider(this, profileconfig); } catch (CoreException& ex) { throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(new GnuTLSIOHookProvider(this, profile)); + newprofiles.push_back(prov); } // New profiles are ok, begin using them // Old profiles are deleted when their refcount drops to zero + for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i) + { + GnuTLSIOHookProvider& prov = **i; + ServerInstance->Modules.DelService(prov); + } + profiles.swap(newprofiles); } diff --git a/src/modules/extra/m_ssl_mbedtls.cpp b/src/modules/extra/m_ssl_mbedtls.cpp index 4e0032fdc..8c15342f2 100644 --- a/src/modules/extra/m_ssl_mbedtls.cpp +++ b/src/modules/extra/m_ssl_mbedtls.cpp @@ -345,7 +345,7 @@ namespace mbedTLS } }; - class Profile : public refcountbase + class Profile { /** Name of this profile */ @@ -378,29 +378,71 @@ namespace mbedTLS */ const unsigned int outrecsize; - Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr, - const std::string& dhstr, unsigned int mindh, const std::string& hashstr, - const std::string& ciphersuitestr, const std::string& curvestr, - const std::string& castr, const std::string& crlstr, - unsigned int recsize, - CTRDRBG& ctrdrbg, - int minver, int maxver, - bool requestclientcert - ) - : name(profilename) - , x509cred(certstr, keystr) - , ciphersuites(ciphersuitestr) - , curves(curvestr) - , serverctx(ctrdrbg, MBEDTLS_SSL_IS_SERVER) - , clientctx(ctrdrbg, MBEDTLS_SSL_IS_CLIENT) - , cacerts(castr, true) - , crl(crlstr) - , hash(hashstr) - , outrecsize(recsize) + public: + struct Config + { + const std::string name; + + CTRDRBG& ctrdrbg; + + const std::string certstr; + const std::string keystr; + const std::string dhstr; + + const std::string ciphersuitestr; + const std::string curvestr; + const unsigned int mindh; + const std::string hashstr; + + std::string crlstr; + std::string castr; + + const int minver; + const int maxver; + const unsigned int outrecsize; + const bool requestclientcert; + + Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg) + : name(profilename) + , ctrdrbg(ctr_drbg) + , certstr(ReadFile(tag->getString("certfile", "cert.pem"))) + , keystr(ReadFile(tag->getString("keyfile", "key.pem"))) + , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem"))) + , ciphersuitestr(tag->getString("ciphersuites")) + , curvestr(tag->getString("curves")) + , mindh(tag->getInt("mindhbits", 2048)) + , hashstr(tag->getString("hash", "sha256")) + , castr(tag->getString("cafile")) + , minver(tag->getInt("minver")) + , maxver(tag->getInt("maxver")) + , outrecsize(tag->getInt("outrecsize", 2048, 512, 16384)) + , requestclientcert(tag->getBool("requestclientcert", true)) + { + if (!castr.empty()) + { + castr = ReadFile(castr); + crlstr = tag->getString("crlfile"); + if (!crlstr.empty()) + crlstr = ReadFile(crlstr); + } + } + }; + + Profile(Config& config) + : name(config.name) + , x509cred(config.certstr, config.keystr) + , ciphersuites(config.ciphersuitestr) + , curves(config.curvestr) + , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER) + , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT) + , cacerts(config.castr, true) + , crl(config.crlstr) + , hash(config.hashstr) + , outrecsize(config.outrecsize) { serverctx.SetX509CertAndKey(x509cred); clientctx.SetX509CertAndKey(x509cred); - clientctx.SetMinDHBits(mindh); + clientctx.SetMinDHBits(config.mindh); if (!ciphersuites.empty()) { @@ -414,19 +456,19 @@ namespace mbedTLS clientctx.SetCurves(curves); } - serverctx.SetVersion(minver, maxver); - clientctx.SetVersion(minver, maxver); + serverctx.SetVersion(config.minver, config.maxver); + clientctx.SetVersion(config.minver, config.maxver); - if (!dhstr.empty()) + if (!config.dhstr.empty()) { - dhparams.set(dhstr); + dhparams.set(config.dhstr); serverctx.SetDHParams(dhparams); } clientctx.SetOptionalVerifyCert(); clientctx.SetCA(cacerts, crl); // The default for servers is to not request a client certificate from the peer - if (requestclientcert) + if (config.requestclientcert) { serverctx.SetOptionalVerifyCert(); serverctx.SetCA(cacerts, crl); @@ -442,35 +484,6 @@ namespace mbedTLS return ret; } - public: - static reference<Profile> Create(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg) - { - const std::string certstr = ReadFile(tag->getString("certfile", "cert.pem")); - const std::string keystr = ReadFile(tag->getString("keyfile", "key.pem")); - const std::string dhstr = ReadFile(tag->getString("dhfile", "dhparams.pem")); - - const std::string ciphersuitestr = tag->getString("ciphersuites"); - const std::string curvestr = tag->getString("curves"); - unsigned int mindh = tag->getInt("mindhbits", 2048); - std::string hashstr = tag->getString("hash", "sha256"); - - std::string crlstr; - std::string castr = tag->getString("cafile"); - if (!castr.empty()) - { - castr = ReadFile(castr); - crlstr = tag->getString("crlfile"); - if (!crlstr.empty()) - crlstr = ReadFile(crlstr); - } - - int minver = tag->getInt("minver"); - int maxver = tag->getInt("maxver"); - unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384); - const bool requestclientcert = tag->getBool("requestclientcert", true); - return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert); - } - /** Set up the given session with the settings in this profile */ void SetupClientSession(mbedtls_ssl_context* sess) @@ -501,7 +514,6 @@ class mbedTLSIOHook : public SSLIOHook mbedtls_ssl_context sess; Status status; - reference<mbedTLS::Profile> profile; void CloseSession() { @@ -575,7 +587,7 @@ class mbedTLSIOHook : public SSLIOHook } // If there is a certificate we can always generate a fingerprint - certificate->fingerprint = profile->GetHash().hash(cert->raw.p, cert->raw.len); + certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len); // At this point mbedTLS verified the cert already, we just need to check the results const uint32_t flags = mbedtls_ssl_get_verify_result(&sess); @@ -649,16 +661,15 @@ class mbedTLSIOHook : public SSLIOHook } public: - mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver, mbedTLS::Profile* sslprofile) + mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver) : SSLIOHook(hookprov) , status(ISSL_NONE) - , profile(sslprofile) { mbedtls_ssl_init(&sess); if (isserver) - profile->SetupServerSession(&sess); + GetProfile().SetupServerSession(&sess); else - profile->SetupClientSession(&sess); + GetProfile().SetupClientSession(&sess); mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL); @@ -725,7 +736,7 @@ class mbedTLSIOHook : public SSLIOHook // Session is ready for transferring application data while (!sendq.empty()) { - FlattenSendQueue(sendq, profile->GetOutgoingRecordSize()); + FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize()); const StreamSocket::SendQueue::Element& buffer = sendq.front(); int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length()); if (ret == (int)buffer.length()) @@ -788,17 +799,18 @@ class mbedTLSIOHook : public SSLIOHook return false; } + mbedTLS::Profile& GetProfile(); bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); } }; -class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider +class mbedTLSIOHookProvider : public IOHookProvider { - reference<mbedTLS::Profile> profile; + mbedTLS::Profile profile; public: - mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile* prof) - : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) - , profile(prof) + mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config) + : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL) + , profile(config) { ServerInstance->Modules->AddService(*this); } @@ -810,15 +822,23 @@ class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE { - new mbedTLSIOHook(this, sock, true, profile); + new mbedTLSIOHook(this, sock, true); } void OnConnect(StreamSocket* sock) CXX11_OVERRIDE { - new mbedTLSIOHook(this, sock, false, profile); + new mbedTLSIOHook(this, sock, false); } + + mbedTLS::Profile& GetProfile() { return profile; } }; +mbedTLS::Profile& mbedTLSIOHook::GetProfile() +{ + IOHookProvider* hookprov = prov; + return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile(); +} + class ModuleSSLmbedTLS : public Module { typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList; @@ -844,8 +864,8 @@ class ModuleSSLmbedTLS : public Module try { - reference<mbedTLS::Profile> profile(mbedTLS::Profile::Create(defname, tag, ctr_drbg)); - newprofiles.push_back(new mbedTLSIOHookProvider(this, profile)); + mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg); + newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig)); } catch (CoreException& ex) { @@ -866,21 +886,28 @@ class ModuleSSLmbedTLS : public Module continue; } - reference<mbedTLS::Profile> profile; + reference<mbedTLSIOHookProvider> prov; try { - profile = mbedTLS::Profile::Create(name, tag, ctr_drbg); + mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg); + prov = new mbedTLSIOHookProvider(this, profileconfig); } catch (CoreException& ex) { throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(new mbedTLSIOHookProvider(this, profile)); + newprofiles.push_back(prov); } // New profiles are ok, begin using them // Old profiles are deleted when their refcount drops to zero + for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i) + { + mbedTLSIOHookProvider& prov = **i; + ServerInstance->Modules.DelService(prov); + } + profiles.swap(newprofiles); } diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index 9b7e608a2..ae5e213b7 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -240,7 +240,7 @@ namespace OpenSSL } }; - class Profile : public refcountbase + class Profile { /** Name of this profile */ @@ -459,7 +459,6 @@ class OpenSSLIOHook : public SSLIOHook SSL* sess; issl_status status; bool data_to_write; - reference<OpenSSL::Profile> profile; // Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed int Handshake(StreamSocket* user) @@ -559,7 +558,7 @@ class OpenSSLIOHook : public SSLIOHook if (certinfo->issuer.find_first_of("\r\n") != std::string::npos) certinfo->issuer.clear(); - if (!X509_digest(cert, profile->GetDigest(), md, &n)) + if (!X509_digest(cert, GetProfile().GetDigest(), md, &n)) { certinfo->error = "Out of memory generating fingerprint"; } @@ -580,7 +579,7 @@ class OpenSSLIOHook : public SSLIOHook { if ((where & SSL_CB_HANDSHAKE_START) && (status == ISSL_OPEN)) { - if (profile->AllowRenegotiation()) + if (GetProfile().AllowRenegotiation()) return; // The other side is trying to renegotiate, kill the connection and change status @@ -622,12 +621,11 @@ class OpenSSLIOHook : public SSLIOHook friend void StaticSSLInfoCallback(const SSL* ssl, int where, int rc); public: - OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session, const reference<OpenSSL::Profile>& sslprofile) + OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session) : SSLIOHook(hookprov) , sess(session) , status(ISSL_NONE) , data_to_write(false) - , profile(sslprofile) { // Create BIO instance and store a pointer to the socket in it which will be used by the read and write functions #ifdef INSPIRCD_OPENSSL_OPAQUE_BIO @@ -721,7 +719,7 @@ class OpenSSLIOHook : public SSLIOHook while (!sendq.empty()) { ERR_clear_error(); - FlattenSendQueue(sendq, profile->GetOutgoingRecordSize()); + FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize()); const StreamSocket::SendQueue::Element& buffer = sendq.front(); int ret = SSL_write(sess, buffer.data(), buffer.size()); @@ -790,6 +788,7 @@ class OpenSSLIOHook : public SSLIOHook } bool IsHandshakeDone() const { return (status == ISSL_OPEN); } + OpenSSL::Profile& GetProfile(); }; static void StaticSSLInfoCallback(const SSL* ssl, int where, int rc) @@ -844,14 +843,14 @@ static int OpenSSL::BIOMethod::read(BIO* bio, char* buffer, int size) return ret; } -class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider +class OpenSSLIOHookProvider : public IOHookProvider { - reference<OpenSSL::Profile> profile; + OpenSSL::Profile profile; public: - OpenSSLIOHookProvider(Module* mod, reference<OpenSSL::Profile>& prof) - : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL) - , profile(prof) + OpenSSLIOHookProvider(Module* mod, const std::string& profilename, ConfigTag* tag) + : IOHookProvider(mod, "ssl/" + profilename, IOHookProvider::IOH_SSL) + , profile(profilename, tag) { ServerInstance->Modules->AddService(*this); } @@ -863,15 +862,23 @@ class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE { - new OpenSSLIOHook(this, sock, profile->CreateServerSession(), profile); + new OpenSSLIOHook(this, sock, profile.CreateServerSession()); } void OnConnect(StreamSocket* sock) CXX11_OVERRIDE { - new OpenSSLIOHook(this, sock, profile->CreateClientSession(), profile); + new OpenSSLIOHook(this, sock, profile.CreateClientSession()); } + + OpenSSL::Profile& GetProfile() { return profile; } }; +OpenSSL::Profile& OpenSSLIOHook::GetProfile() +{ + IOHookProvider* hookprov = prov; + return static_cast<OpenSSLIOHookProvider*>(hookprov)->GetProfile(); +} + class ModuleSSLOpenSSL : public Module { typedef std::vector<reference<OpenSSLIOHookProvider> > ProfileList; @@ -891,8 +898,7 @@ class ModuleSSLOpenSSL : public Module try { - reference<OpenSSL::Profile> profile(new OpenSSL::Profile(defname, tag)); - newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); + newprofiles.push_back(new OpenSSLIOHookProvider(this, defname, tag)); } catch (OpenSSL::Exception& ex) { @@ -913,17 +919,23 @@ class ModuleSSLOpenSSL : public Module continue; } - reference<OpenSSL::Profile> profile; + reference<OpenSSLIOHookProvider> prov; try { - profile = new OpenSSL::Profile(name, tag); + prov = new OpenSSLIOHookProvider(this, name, tag); } catch (CoreException& ex) { throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason()); } - newprofiles.push_back(new OpenSSLIOHookProvider(this, profile)); + newprofiles.push_back(prov); + } + + for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i) + { + OpenSSLIOHookProvider& prov = **i; + ServerInstance->Modules.DelService(prov); } profiles.swap(newprofiles); |