summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAttila Molnar <attilamolnar@hush.com>2017-01-10 20:21:57 +0100
committerPeter Powell <petpow@saberuk.com>2017-11-13 16:38:30 +0000
commit451e687f681ccab5c02a8de1a7d59b324efbfe08 (patch)
treea748afd720e4e4c1ae2d6a8717022e4d18f43d3b
parent0fd2d50fcf1bcff107d6d185aad5d8e9245d4141 (diff)
Unite SSL service providers and SSL profile classes
-rw-r--r--include/iohook.h4
-rw-r--r--src/modules/extra/m_ssl_gnutls.cpp145
-rw-r--r--src/modules/extra/m_ssl_mbedtls.cpp175
-rw-r--r--src/modules/extra/m_ssl_openssl.cpp50
-rw-r--r--src/modules/m_websocket.cpp4
5 files changed, 219 insertions, 159 deletions
diff --git a/include/iohook.h b/include/iohook.h
index e99316b99..9ca17d77e 100644
--- a/include/iohook.h
+++ b/include/iohook.h
@@ -21,7 +21,7 @@
class StreamSocket;
-class IOHookProvider : public ServiceProvider
+class IOHookProvider : public refcountbase, public ServiceProvider
{
const bool middlehook;
@@ -69,7 +69,7 @@ class IOHook : public classbase
/** The IOHookProvider for this hook, contains information about the hook,
* such as the module providing it and the hook type.
*/
- IOHookProvider* const prov;
+ reference<IOHookProvider> prov;
/** Constructor
* @param provider IOHookProvider that creates this object
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp
index 50c847ee4..534c3abbc 100644
--- a/src/modules/extra/m_ssl_gnutls.cpp
+++ b/src/modules/extra/m_ssl_gnutls.cpp
@@ -566,7 +566,7 @@ namespace GnuTLS
int ret() const { return retval; }
};
- class Profile : public refcountbase
+ class Profile
{
/** Name of this profile
*/
@@ -596,22 +596,6 @@ namespace GnuTLS
*/
const bool requestclientcert;
- Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
- std::auto_ptr<DHParams>& DH, unsigned int mindh, const std::string& hashstr,
- const std::string& priostr, std::auto_ptr<X509CertList>& CA, std::auto_ptr<X509CRL>& CRL,
- unsigned int recsize, bool Requestclientcert)
- : name(profilename)
- , x509cred(certstr, keystr)
- , min_dh_bits(mindh)
- , hash(hashstr)
- , priority(priostr)
- , outrecsize(recsize)
- , requestclientcert(Requestclientcert)
- {
- x509cred.SetDH(DH);
- x509cred.SetCA(CA, CRL);
- }
-
static std::string ReadFile(const std::string& filename)
{
FileReader reader(filename);
@@ -647,42 +631,66 @@ namespace GnuTLS
}
public:
- static reference<Profile> Create(const std::string& profilename, ConfigTag* tag)
+ struct Config
{
- std::string certstr = ReadFile(tag->getString("certfile", "cert.pem"));
- std::string keystr = ReadFile(tag->getString("keyfile", "key.pem"));
+ std::string name;
- std::auto_ptr<DHParams> dh = DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem")));
-
- std::string priostr = GetPrioStr(profilename, tag);
- unsigned int mindh = tag->getInt("mindhbits", 1024);
- std::string hashstr = tag->getString("hash", "md5");
-
- // Load trusted CA and revocation list, if set
std::auto_ptr<X509CertList> ca;
std::auto_ptr<X509CRL> crl;
- std::string filename = tag->getString("cafile");
- if (!filename.empty())
- {
- ca.reset(new X509CertList(ReadFile(filename)));
- filename = tag->getString("crlfile");
+ std::string certstr;
+ std::string keystr;
+ std::auto_ptr<DHParams> dh;
+
+ std::string priostr;
+ unsigned int mindh;
+ std::string hashstr;
+
+ unsigned int outrecsize;
+ bool requestclientcert;
+
+ Config(const std::string& profilename, ConfigTag* tag)
+ : name(profilename)
+ , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
+ , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
+ , dh(DHParams::Import(ReadFile(tag->getString("dhfile", "dhparams.pem"))))
+ , priostr(GetPrioStr(profilename, tag))
+ , mindh(tag->getInt("mindhbits", 1024))
+ , hashstr(tag->getString("hash", "md5"))
+ , requestclientcert(tag->getBool("requestclientcert", true))
+ {
+ // Load trusted CA and revocation list, if set
+ std::string filename = tag->getString("cafile");
if (!filename.empty())
- crl.reset(new X509CRL(ReadFile(filename)));
- }
+ {
+ ca.reset(new X509CertList(ReadFile(filename)));
+
+ filename = tag->getString("crlfile");
+ if (!filename.empty())
+ crl.reset(new X509CRL(ReadFile(filename)));
+ }
#ifdef INSPIRCD_GNUTLS_HAS_CORK
- // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked
- unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512);
+ // If cork support is available outrecsize represents the (rough) max amount of data we give GnuTLS while corked
+ outrecsize = tag->getInt("outrecsize", 2048, 512);
#else
- unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
+ outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
#endif
+ }
+ };
- const bool requestclientcert = tag->getBool("requestclientcert", true);
-
- return new Profile(profilename, certstr, keystr, dh, mindh, hashstr, priostr, ca, crl, outrecsize, requestclientcert);
+ Profile(Config& config)
+ : name(config.name)
+ , x509cred(config.certstr, config.keystr)
+ , min_dh_bits(config.mindh)
+ , hash(config.hashstr)
+ , priority(config.priostr)
+ , outrecsize(config.outrecsize)
+ , requestclientcert(config.requestclientcert)
+ {
+ x509cred.SetDH(config.dh);
+ x509cred.SetCA(config.ca, config.crl);
}
-
/** Set up the given session with the settings in this profile
*/
void SetupSession(gnutls_session_t sess)
@@ -708,7 +716,6 @@ class GnuTLSIOHook : public SSLIOHook
private:
gnutls_session_t sess;
issl_status status;
- reference<GnuTLS::Profile> profile;
#ifdef INSPIRCD_GNUTLS_HAS_CORK
size_t gbuffersize;
#endif
@@ -855,7 +862,7 @@ class GnuTLSIOHook : public SSLIOHook
issuer.clear();
}
- if ((ret = gnutls_x509_crt_get_fingerprint(cert, profile->GetHash(), digest, &digest_size)) < 0)
+ if ((ret = gnutls_x509_crt_get_fingerprint(cert, GetProfile().GetHash(), digest, &digest_size)) < 0)
{
certinfo->error = gnutls_strerror(ret);
}
@@ -1043,11 +1050,10 @@ info_done_dealloc:
#endif // INSPIRCD_GNUTLS_HAS_VECTOR_PUSH
public:
- GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags, const reference<GnuTLS::Profile>& sslprofile)
+ GnuTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, inspircd_gnutls_session_init_flags_t flags)
: SSLIOHook(hookprov)
, sess(NULL)
, status(ISSL_NONE)
- , profile(sslprofile)
#ifdef INSPIRCD_GNUTLS_HAS_CORK
, gbuffersize(0)
#endif
@@ -1060,7 +1066,7 @@ info_done_dealloc:
gnutls_transport_set_push_function(sess, gnutls_push_wrapper);
#endif
gnutls_transport_set_pull_function(sess, gnutls_pull_wrapper);
- profile->SetupSession(sess);
+ GetProfile().SetupSession(sess);
sock->AddIOHook(this);
Handshake(sock);
@@ -1132,7 +1138,7 @@ info_done_dealloc:
// GnuTLS buffer is empty but sendq is not, begin sending data from the sendq
gnutls_record_cork(this->sess);
- while ((!sendq.empty()) && (gbuffersize < profile->GetOutgoingRecordSize()))
+ while ((!sendq.empty()) && (gbuffersize < GetProfile().GetOutgoingRecordSize()))
{
const StreamSocket::SendQueue::Element& elem = sendq.front();
gbuffersize += elem.length();
@@ -1150,7 +1156,7 @@ info_done_dealloc:
while (!sendq.empty())
{
- FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
+ FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
const StreamSocket::SendQueue::Element& buffer = sendq.front();
ret = HandleWriteRet(user, gnutls_record_send(this->sess, buffer.data(), buffer.length()));
@@ -1201,7 +1207,7 @@ info_done_dealloc:
return true;
}
- GnuTLS::Profile* GetProfile() { return profile; }
+ GnuTLS::Profile& GetProfile();
bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
};
@@ -1214,7 +1220,7 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
st->key_type = GNUTLS_PRIVKEY_X509;
#endif
StreamSocket* sock = reinterpret_cast<StreamSocket*>(gnutls_transport_get_ptr(sess));
- GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile()->GetX509Credentials();
+ GnuTLS::X509Credentials& cred = static_cast<GnuTLSIOHook*>(sock->GetModHook(thismod))->GetProfile().GetX509Credentials();
st->ncerts = cred.certs.size();
st->cert.x509 = cred.certs.raw();
@@ -1224,14 +1230,14 @@ int GnuTLS::X509Credentials::cert_callback(gnutls_session_t sess, const gnutls_d
return 0;
}
-class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider
+class GnuTLSIOHookProvider : public IOHookProvider
{
- reference<GnuTLS::Profile> profile;
+ GnuTLS::Profile profile;
public:
- GnuTLSIOHookProvider(Module* mod, reference<GnuTLS::Profile>& prof)
- : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
- , profile(prof)
+ GnuTLSIOHookProvider(Module* mod, GnuTLS::Profile::Config& config)
+ : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
+ , profile(config)
{
ServerInstance->Modules->AddService(*this);
}
@@ -1243,15 +1249,23 @@ class GnuTLSIOHookProvider : public refcountbase, public IOHookProvider
void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
{
- new GnuTLSIOHook(this, sock, GNUTLS_SERVER, profile);
+ new GnuTLSIOHook(this, sock, GNUTLS_SERVER);
}
void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
{
- new GnuTLSIOHook(this, sock, GNUTLS_CLIENT, profile);
+ new GnuTLSIOHook(this, sock, GNUTLS_CLIENT);
}
+
+ GnuTLS::Profile& GetProfile() { return profile; }
};
+GnuTLS::Profile& GnuTLSIOHook::GetProfile()
+{
+ IOHookProvider* hookprov = prov;
+ return static_cast<GnuTLSIOHookProvider*>(hookprov)->GetProfile();
+}
+
class ModuleSSLGnuTLS : public Module
{
typedef std::vector<reference<GnuTLSIOHookProvider> > ProfileList;
@@ -1278,8 +1292,8 @@ class ModuleSSLGnuTLS : public Module
try
{
- reference<GnuTLS::Profile> profile(GnuTLS::Profile::Create(defname, tag));
- newprofiles.push_back(new GnuTLSIOHookProvider(this, profile));
+ GnuTLS::Profile::Config profileconfig(defname, tag);
+ newprofiles.push_back(new GnuTLSIOHookProvider(this, profileconfig));
}
catch (CoreException& ex)
{
@@ -1300,21 +1314,28 @@ class ModuleSSLGnuTLS : public Module
continue;
}
- reference<GnuTLS::Profile> profile;
+ reference<GnuTLSIOHookProvider> prov;
try
{
- profile = GnuTLS::Profile::Create(name, tag);
+ GnuTLS::Profile::Config profileconfig(name, tag);
+ prov = new GnuTLSIOHookProvider(this, profileconfig);
}
catch (CoreException& ex)
{
throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
}
- newprofiles.push_back(new GnuTLSIOHookProvider(this, profile));
+ newprofiles.push_back(prov);
}
// New profiles are ok, begin using them
// Old profiles are deleted when their refcount drops to zero
+ for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+ {
+ GnuTLSIOHookProvider& prov = **i;
+ ServerInstance->Modules.DelService(prov);
+ }
+
profiles.swap(newprofiles);
}
diff --git a/src/modules/extra/m_ssl_mbedtls.cpp b/src/modules/extra/m_ssl_mbedtls.cpp
index 4e0032fdc..8c15342f2 100644
--- a/src/modules/extra/m_ssl_mbedtls.cpp
+++ b/src/modules/extra/m_ssl_mbedtls.cpp
@@ -345,7 +345,7 @@ namespace mbedTLS
}
};
- class Profile : public refcountbase
+ class Profile
{
/** Name of this profile
*/
@@ -378,29 +378,71 @@ namespace mbedTLS
*/
const unsigned int outrecsize;
- Profile(const std::string& profilename, const std::string& certstr, const std::string& keystr,
- const std::string& dhstr, unsigned int mindh, const std::string& hashstr,
- const std::string& ciphersuitestr, const std::string& curvestr,
- const std::string& castr, const std::string& crlstr,
- unsigned int recsize,
- CTRDRBG& ctrdrbg,
- int minver, int maxver,
- bool requestclientcert
- )
- : name(profilename)
- , x509cred(certstr, keystr)
- , ciphersuites(ciphersuitestr)
- , curves(curvestr)
- , serverctx(ctrdrbg, MBEDTLS_SSL_IS_SERVER)
- , clientctx(ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
- , cacerts(castr, true)
- , crl(crlstr)
- , hash(hashstr)
- , outrecsize(recsize)
+ public:
+ struct Config
+ {
+ const std::string name;
+
+ CTRDRBG& ctrdrbg;
+
+ const std::string certstr;
+ const std::string keystr;
+ const std::string dhstr;
+
+ const std::string ciphersuitestr;
+ const std::string curvestr;
+ const unsigned int mindh;
+ const std::string hashstr;
+
+ std::string crlstr;
+ std::string castr;
+
+ const int minver;
+ const int maxver;
+ const unsigned int outrecsize;
+ const bool requestclientcert;
+
+ Config(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
+ : name(profilename)
+ , ctrdrbg(ctr_drbg)
+ , certstr(ReadFile(tag->getString("certfile", "cert.pem")))
+ , keystr(ReadFile(tag->getString("keyfile", "key.pem")))
+ , dhstr(ReadFile(tag->getString("dhfile", "dhparams.pem")))
+ , ciphersuitestr(tag->getString("ciphersuites"))
+ , curvestr(tag->getString("curves"))
+ , mindh(tag->getInt("mindhbits", 2048))
+ , hashstr(tag->getString("hash", "sha256"))
+ , castr(tag->getString("cafile"))
+ , minver(tag->getInt("minver"))
+ , maxver(tag->getInt("maxver"))
+ , outrecsize(tag->getInt("outrecsize", 2048, 512, 16384))
+ , requestclientcert(tag->getBool("requestclientcert", true))
+ {
+ if (!castr.empty())
+ {
+ castr = ReadFile(castr);
+ crlstr = tag->getString("crlfile");
+ if (!crlstr.empty())
+ crlstr = ReadFile(crlstr);
+ }
+ }
+ };
+
+ Profile(Config& config)
+ : name(config.name)
+ , x509cred(config.certstr, config.keystr)
+ , ciphersuites(config.ciphersuitestr)
+ , curves(config.curvestr)
+ , serverctx(config.ctrdrbg, MBEDTLS_SSL_IS_SERVER)
+ , clientctx(config.ctrdrbg, MBEDTLS_SSL_IS_CLIENT)
+ , cacerts(config.castr, true)
+ , crl(config.crlstr)
+ , hash(config.hashstr)
+ , outrecsize(config.outrecsize)
{
serverctx.SetX509CertAndKey(x509cred);
clientctx.SetX509CertAndKey(x509cred);
- clientctx.SetMinDHBits(mindh);
+ clientctx.SetMinDHBits(config.mindh);
if (!ciphersuites.empty())
{
@@ -414,19 +456,19 @@ namespace mbedTLS
clientctx.SetCurves(curves);
}
- serverctx.SetVersion(minver, maxver);
- clientctx.SetVersion(minver, maxver);
+ serverctx.SetVersion(config.minver, config.maxver);
+ clientctx.SetVersion(config.minver, config.maxver);
- if (!dhstr.empty())
+ if (!config.dhstr.empty())
{
- dhparams.set(dhstr);
+ dhparams.set(config.dhstr);
serverctx.SetDHParams(dhparams);
}
clientctx.SetOptionalVerifyCert();
clientctx.SetCA(cacerts, crl);
// The default for servers is to not request a client certificate from the peer
- if (requestclientcert)
+ if (config.requestclientcert)
{
serverctx.SetOptionalVerifyCert();
serverctx.SetCA(cacerts, crl);
@@ -442,35 +484,6 @@ namespace mbedTLS
return ret;
}
- public:
- static reference<Profile> Create(const std::string& profilename, ConfigTag* tag, CTRDRBG& ctr_drbg)
- {
- const std::string certstr = ReadFile(tag->getString("certfile", "cert.pem"));
- const std::string keystr = ReadFile(tag->getString("keyfile", "key.pem"));
- const std::string dhstr = ReadFile(tag->getString("dhfile", "dhparams.pem"));
-
- const std::string ciphersuitestr = tag->getString("ciphersuites");
- const std::string curvestr = tag->getString("curves");
- unsigned int mindh = tag->getInt("mindhbits", 2048);
- std::string hashstr = tag->getString("hash", "sha256");
-
- std::string crlstr;
- std::string castr = tag->getString("cafile");
- if (!castr.empty())
- {
- castr = ReadFile(castr);
- crlstr = tag->getString("crlfile");
- if (!crlstr.empty())
- crlstr = ReadFile(crlstr);
- }
-
- int minver = tag->getInt("minver");
- int maxver = tag->getInt("maxver");
- unsigned int outrecsize = tag->getInt("outrecsize", 2048, 512, 16384);
- const bool requestclientcert = tag->getBool("requestclientcert", true);
- return new Profile(profilename, certstr, keystr, dhstr, mindh, hashstr, ciphersuitestr, curvestr, castr, crlstr, outrecsize, ctr_drbg, minver, maxver, requestclientcert);
- }
-
/** Set up the given session with the settings in this profile
*/
void SetupClientSession(mbedtls_ssl_context* sess)
@@ -501,7 +514,6 @@ class mbedTLSIOHook : public SSLIOHook
mbedtls_ssl_context sess;
Status status;
- reference<mbedTLS::Profile> profile;
void CloseSession()
{
@@ -575,7 +587,7 @@ class mbedTLSIOHook : public SSLIOHook
}
// If there is a certificate we can always generate a fingerprint
- certificate->fingerprint = profile->GetHash().hash(cert->raw.p, cert->raw.len);
+ certificate->fingerprint = GetProfile().GetHash().hash(cert->raw.p, cert->raw.len);
// At this point mbedTLS verified the cert already, we just need to check the results
const uint32_t flags = mbedtls_ssl_get_verify_result(&sess);
@@ -649,16 +661,15 @@ class mbedTLSIOHook : public SSLIOHook
}
public:
- mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver, mbedTLS::Profile* sslprofile)
+ mbedTLSIOHook(IOHookProvider* hookprov, StreamSocket* sock, bool isserver)
: SSLIOHook(hookprov)
, status(ISSL_NONE)
- , profile(sslprofile)
{
mbedtls_ssl_init(&sess);
if (isserver)
- profile->SetupServerSession(&sess);
+ GetProfile().SetupServerSession(&sess);
else
- profile->SetupClientSession(&sess);
+ GetProfile().SetupClientSession(&sess);
mbedtls_ssl_set_bio(&sess, reinterpret_cast<void*>(sock), Push, Pull, NULL);
@@ -725,7 +736,7 @@ class mbedTLSIOHook : public SSLIOHook
// Session is ready for transferring application data
while (!sendq.empty())
{
- FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
+ FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
const StreamSocket::SendQueue::Element& buffer = sendq.front();
int ret = mbedtls_ssl_write(&sess, reinterpret_cast<const unsigned char*>(buffer.data()), buffer.length());
if (ret == (int)buffer.length())
@@ -788,17 +799,18 @@ class mbedTLSIOHook : public SSLIOHook
return false;
}
+ mbedTLS::Profile& GetProfile();
bool IsHandshakeDone() const { return (status == ISSL_HANDSHAKEN); }
};
-class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
+class mbedTLSIOHookProvider : public IOHookProvider
{
- reference<mbedTLS::Profile> profile;
+ mbedTLS::Profile profile;
public:
- mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile* prof)
- : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
- , profile(prof)
+ mbedTLSIOHookProvider(Module* mod, mbedTLS::Profile::Config& config)
+ : IOHookProvider(mod, "ssl/" + config.name, IOHookProvider::IOH_SSL)
+ , profile(config)
{
ServerInstance->Modules->AddService(*this);
}
@@ -810,15 +822,23 @@ class mbedTLSIOHookProvider : public refcountbase, public IOHookProvider
void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
{
- new mbedTLSIOHook(this, sock, true, profile);
+ new mbedTLSIOHook(this, sock, true);
}
void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
{
- new mbedTLSIOHook(this, sock, false, profile);
+ new mbedTLSIOHook(this, sock, false);
}
+
+ mbedTLS::Profile& GetProfile() { return profile; }
};
+mbedTLS::Profile& mbedTLSIOHook::GetProfile()
+{
+ IOHookProvider* hookprov = prov;
+ return static_cast<mbedTLSIOHookProvider*>(hookprov)->GetProfile();
+}
+
class ModuleSSLmbedTLS : public Module
{
typedef std::vector<reference<mbedTLSIOHookProvider> > ProfileList;
@@ -844,8 +864,8 @@ class ModuleSSLmbedTLS : public Module
try
{
- reference<mbedTLS::Profile> profile(mbedTLS::Profile::Create(defname, tag, ctr_drbg));
- newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
+ mbedTLS::Profile::Config profileconfig(defname, tag, ctr_drbg);
+ newprofiles.push_back(new mbedTLSIOHookProvider(this, profileconfig));
}
catch (CoreException& ex)
{
@@ -866,21 +886,28 @@ class ModuleSSLmbedTLS : public Module
continue;
}
- reference<mbedTLS::Profile> profile;
+ reference<mbedTLSIOHookProvider> prov;
try
{
- profile = mbedTLS::Profile::Create(name, tag, ctr_drbg);
+ mbedTLS::Profile::Config profileconfig(name, tag, ctr_drbg);
+ prov = new mbedTLSIOHookProvider(this, profileconfig);
}
catch (CoreException& ex)
{
throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
}
- newprofiles.push_back(new mbedTLSIOHookProvider(this, profile));
+ newprofiles.push_back(prov);
}
// New profiles are ok, begin using them
// Old profiles are deleted when their refcount drops to zero
+ for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+ {
+ mbedTLSIOHookProvider& prov = **i;
+ ServerInstance->Modules.DelService(prov);
+ }
+
profiles.swap(newprofiles);
}
diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp
index 9b7e608a2..ae5e213b7 100644
--- a/src/modules/extra/m_ssl_openssl.cpp
+++ b/src/modules/extra/m_ssl_openssl.cpp
@@ -240,7 +240,7 @@ namespace OpenSSL
}
};
- class Profile : public refcountbase
+ class Profile
{
/** Name of this profile
*/
@@ -459,7 +459,6 @@ class OpenSSLIOHook : public SSLIOHook
SSL* sess;
issl_status status;
bool data_to_write;
- reference<OpenSSL::Profile> profile;
// Returns 1 if handshake succeeded, 0 if it is still in progress, -1 if it failed
int Handshake(StreamSocket* user)
@@ -559,7 +558,7 @@ class OpenSSLIOHook : public SSLIOHook
if (certinfo->issuer.find_first_of("\r\n") != std::string::npos)
certinfo->issuer.clear();
- if (!X509_digest(cert, profile->GetDigest(), md, &n))
+ if (!X509_digest(cert, GetProfile().GetDigest(), md, &n))
{
certinfo->error = "Out of memory generating fingerprint";
}
@@ -580,7 +579,7 @@ class OpenSSLIOHook : public SSLIOHook
{
if ((where & SSL_CB_HANDSHAKE_START) && (status == ISSL_OPEN))
{
- if (profile->AllowRenegotiation())
+ if (GetProfile().AllowRenegotiation())
return;
// The other side is trying to renegotiate, kill the connection and change status
@@ -622,12 +621,11 @@ class OpenSSLIOHook : public SSLIOHook
friend void StaticSSLInfoCallback(const SSL* ssl, int where, int rc);
public:
- OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session, const reference<OpenSSL::Profile>& sslprofile)
+ OpenSSLIOHook(IOHookProvider* hookprov, StreamSocket* sock, SSL* session)
: SSLIOHook(hookprov)
, sess(session)
, status(ISSL_NONE)
, data_to_write(false)
- , profile(sslprofile)
{
// Create BIO instance and store a pointer to the socket in it which will be used by the read and write functions
#ifdef INSPIRCD_OPENSSL_OPAQUE_BIO
@@ -721,7 +719,7 @@ class OpenSSLIOHook : public SSLIOHook
while (!sendq.empty())
{
ERR_clear_error();
- FlattenSendQueue(sendq, profile->GetOutgoingRecordSize());
+ FlattenSendQueue(sendq, GetProfile().GetOutgoingRecordSize());
const StreamSocket::SendQueue::Element& buffer = sendq.front();
int ret = SSL_write(sess, buffer.data(), buffer.size());
@@ -790,6 +788,7 @@ class OpenSSLIOHook : public SSLIOHook
}
bool IsHandshakeDone() const { return (status == ISSL_OPEN); }
+ OpenSSL::Profile& GetProfile();
};
static void StaticSSLInfoCallback(const SSL* ssl, int where, int rc)
@@ -844,14 +843,14 @@ static int OpenSSL::BIOMethod::read(BIO* bio, char* buffer, int size)
return ret;
}
-class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider
+class OpenSSLIOHookProvider : public IOHookProvider
{
- reference<OpenSSL::Profile> profile;
+ OpenSSL::Profile profile;
public:
- OpenSSLIOHookProvider(Module* mod, reference<OpenSSL::Profile>& prof)
- : IOHookProvider(mod, "ssl/" + prof->GetName(), IOHookProvider::IOH_SSL)
- , profile(prof)
+ OpenSSLIOHookProvider(Module* mod, const std::string& profilename, ConfigTag* tag)
+ : IOHookProvider(mod, "ssl/" + profilename, IOHookProvider::IOH_SSL)
+ , profile(profilename, tag)
{
ServerInstance->Modules->AddService(*this);
}
@@ -863,15 +862,23 @@ class OpenSSLIOHookProvider : public refcountbase, public IOHookProvider
void OnAccept(StreamSocket* sock, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) CXX11_OVERRIDE
{
- new OpenSSLIOHook(this, sock, profile->CreateServerSession(), profile);
+ new OpenSSLIOHook(this, sock, profile.CreateServerSession());
}
void OnConnect(StreamSocket* sock) CXX11_OVERRIDE
{
- new OpenSSLIOHook(this, sock, profile->CreateClientSession(), profile);
+ new OpenSSLIOHook(this, sock, profile.CreateClientSession());
}
+
+ OpenSSL::Profile& GetProfile() { return profile; }
};
+OpenSSL::Profile& OpenSSLIOHook::GetProfile()
+{
+ IOHookProvider* hookprov = prov;
+ return static_cast<OpenSSLIOHookProvider*>(hookprov)->GetProfile();
+}
+
class ModuleSSLOpenSSL : public Module
{
typedef std::vector<reference<OpenSSLIOHookProvider> > ProfileList;
@@ -891,8 +898,7 @@ class ModuleSSLOpenSSL : public Module
try
{
- reference<OpenSSL::Profile> profile(new OpenSSL::Profile(defname, tag));
- newprofiles.push_back(new OpenSSLIOHookProvider(this, profile));
+ newprofiles.push_back(new OpenSSLIOHookProvider(this, defname, tag));
}
catch (OpenSSL::Exception& ex)
{
@@ -913,17 +919,23 @@ class ModuleSSLOpenSSL : public Module
continue;
}
- reference<OpenSSL::Profile> profile;
+ reference<OpenSSLIOHookProvider> prov;
try
{
- profile = new OpenSSL::Profile(name, tag);
+ prov = new OpenSSLIOHookProvider(this, name, tag);
}
catch (CoreException& ex)
{
throw ModuleException("Error while initializing SSL profile \"" + name + "\" at " + tag->getTagLocation() + " - " + ex.GetReason());
}
- newprofiles.push_back(new OpenSSLIOHookProvider(this, profile));
+ newprofiles.push_back(prov);
+ }
+
+ for (ProfileList::iterator i = profiles.begin(); i != profiles.end(); ++i)
+ {
+ OpenSSLIOHookProvider& prov = **i;
+ ServerInstance->Modules.DelService(prov);
}
profiles.swap(newprofiles);
diff --git a/src/modules/m_websocket.cpp b/src/modules/m_websocket.cpp
index a7457f788..12102d215 100644
--- a/src/modules/m_websocket.cpp
+++ b/src/modules/m_websocket.cpp
@@ -376,12 +376,12 @@ void WebSocketHookProvider::OnAccept(StreamSocket* sock, irc::sockets::sockaddrs
class ModuleWebSocket : public Module
{
dynamic_reference_nocheck<HashProvider> hash;
- WebSocketHookProvider hookprov;
+ reference<WebSocketHookProvider> hookprov;
public:
ModuleWebSocket()
: hash(this, "hash/sha1")
- , hookprov(this)
+ , hookprov(new WebSocketHookProvider(this))
{
sha1 = &hash;
}