From eaace5ed7cef3a02f905689a1b44a092ca99a6e1 Mon Sep 17 00:00:00 2001 From: danieldg Date: Fri, 6 Nov 2009 22:37:52 +0000 Subject: Remove Extensible parent from EventHandler This also fixes SSL certificate support when m_sslinfo is not loaded git-svn-id: http://svn.inspircd.org/repository/trunk/inspircd@12048 e03df62e-2008-0410-955e-edbf42e46eb7 --- include/socketengine.h | 2 +- src/modules/extra/m_ssl_gnutls.cpp | 43 +++++++++++++++------------------ src/modules/extra/m_ssl_openssl.cpp | 40 ++++++++++++++++++------------ src/modules/m_spanningtree/hmac.cpp | 2 +- src/modules/m_spanningtree/main.cpp | 8 ++++++ src/modules/m_spanningtree/netburst.cpp | 4 +-- src/modules/m_sslinfo.cpp | 14 +++++++---- src/modules/m_sslmodes.cpp | 6 ++--- src/modules/ssl.h | 31 ++++++++++++++++++++---- 9 files changed, 95 insertions(+), 55 deletions(-) diff --git a/include/socketengine.h b/include/socketengine.h index 01f66ef21..b411b394a 100644 --- a/include/socketengine.h +++ b/include/socketengine.h @@ -151,7 +151,7 @@ enum EventMask * must have a file descriptor. What this file descriptor * is actually attached to is completely up to you. */ -class CoreExport EventHandler : public Extensible +class CoreExport EventHandler : public classbase { private: /** Private state maintained by socket engine */ diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index bd22404b3..8ec787465 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -77,13 +77,10 @@ static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t user_wrap, const void* class issl_session { public: - issl_session() - { - sess = NULL; - } - gnutls_session_t sess; issl_status status; + reference cert; + issl_session() : sess(NULL) {} }; class CommandStartTLS : public SplitCommand @@ -332,11 +329,15 @@ class ModuleSSLGnuTLS : public Module void OnRequest(Request& request) { - Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); - if (sslinfo) - sslinfo->OnRequest(request); - } + if (strcmp("GET_SSL_CERT", request.id) == 0) + { + SocketCertificateRequest& req = static_cast(request); + int fd = req.sock->GetFd(); + issl_session* session = &sessions[fd]; + req.cert = session->cert; + } + } void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) { @@ -548,10 +549,11 @@ class ModuleSSLGnuTLS : public Module void OnUserConnect(LocalUser* user) { - if (user->GetIOHook() == this) + if (user->eh.GetIOHook() == this) { if (sessions[user->GetFd()].sess) { + SSLCertSubmission(user, this, ServerInstance->Modules->Find("m_sslinfo.so"), sessions[user->GetFd()].cert); std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess)); cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-"); cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess))); @@ -562,23 +564,19 @@ class ModuleSSLGnuTLS : public Module void CloseSession(issl_session* session) { - if(session->sess) + if (session->sess) { gnutls_bye(session->sess, GNUTLS_SHUT_WR); gnutls_deinit(session->sess); } - session->sess = NULL; + session->cert = NULL; session->status = ISSL_NONE; } - void VerifyCertificate(issl_session* session, Extensible* user) + void VerifyCertificate(issl_session* session, StreamSocket* user) { - if (!session->sess || !user) - return; - - Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); - if (!sslinfo) + if (!session->sess || !user || session->cert) return; unsigned int status; @@ -591,6 +589,7 @@ class ModuleSSLGnuTLS : public Module size_t digest_size = sizeof(digest); size_t name_size = sizeof(name); ssl_cert* certinfo = new ssl_cert; + session->cert = certinfo; /* This verification function uses the trusted CAs in the credentials * structure. So you must have installed one or more CA certificates. @@ -600,7 +599,7 @@ class ModuleSSLGnuTLS : public Module if (ret < 0) { certinfo->error = std::string(gnutls_strerror(ret)); - goto info_done; + return; } certinfo->invalid = (status & GNUTLS_CERT_INVALID); @@ -615,14 +614,14 @@ class ModuleSSLGnuTLS : public Module if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) { certinfo->error = "No X509 keys sent"; - goto info_done; + return; } ret = gnutls_x509_crt_init(&cert); if (ret < 0) { certinfo->error = gnutls_strerror(ret); - goto info_done; + return; } cert_list_size = 0; @@ -668,8 +667,6 @@ class ModuleSSLGnuTLS : public Module info_done_dealloc: gnutls_x509_crt_deinit(cert); -info_done: - SSLCertSubmission(user, this, sslinfo, certinfo); } void OnEvent(Event& ev) diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index 5eac93af6..03c460be2 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -53,6 +53,7 @@ class issl_session public: SSL* sess; issl_status status; + reference cert; int fd; bool outbound; @@ -125,7 +126,7 @@ class ModuleSSLOpenSSL : public Module // Needs the flag as it ignores a plain /rehash OnModuleRehash(NULL,"ssl"); - Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnHookIO }; + Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnHookIO, I_OnUserConnect }; ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); } @@ -244,6 +245,17 @@ class ModuleSSLOpenSSL : public Module delete[] sessions; } + void OnUserConnect(LocalUser* user) + { + if (user->eh.GetIOHook() == this) + { + if (sessions[user->GetFd()].sess) + { + SSLCertSubmission(user, this, ServerInstance->Modules->Find("m_sslinfo.so"), sessions[user->GetFd()].cert); + } + } + } + void OnCleanup(int target_type, void* item) { if (target_type == TYPE_USER) @@ -264,14 +276,17 @@ class ModuleSSLOpenSSL : public Module return Version("Provides SSL support for clients", VF_VENDOR); } - void OnRequest(Request& request) { - Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); - if (sslinfo) - sslinfo->OnRequest(request); - } + if (strcmp("GET_SSL_CERT", request.id) == 0) + { + SocketCertificateRequest& req = static_cast(request); + int fd = req.sock->GetFd(); + issl_session* session = &sessions[fd]; + req.cert = session->cert; + } + } void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) { @@ -472,7 +487,7 @@ class ModuleSSLOpenSSL : public Module return 0; } - bool Handshake(EventHandler* user, issl_session* session) + bool Handshake(StreamSocket* user, issl_session* session) { int ret; @@ -537,17 +552,14 @@ class ModuleSSLOpenSSL : public Module errno = EIO; } - void VerifyCertificate(issl_session* session, Extensible* user) + void VerifyCertificate(issl_session* session, StreamSocket* user) { - if (!session->sess || !user) - return; - - Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); - if (!sslinfo) + if (!session->sess || !user || session->cert) return; X509* cert; ssl_cert* certinfo = new ssl_cert; + session->cert = certinfo; unsigned int n; unsigned char md[EVP_MAX_MD_SIZE]; const EVP_MD *digest = EVP_md5(); @@ -557,7 +569,6 @@ class ModuleSSLOpenSSL : public Module if (!cert) { certinfo->error = "Could not get peer certificate: "+std::string(get_error()); - SSLCertSubmission(user, this, sslinfo, certinfo); return; } @@ -592,7 +603,6 @@ class ModuleSSLOpenSSL : public Module } X509_free(cert); - SSLCertSubmission(user, this, sslinfo, certinfo); } }; diff --git a/src/modules/m_spanningtree/hmac.cpp b/src/modules/m_spanningtree/hmac.cpp index bf98b16d1..b5a5fa228 100644 --- a/src/modules/m_spanningtree/hmac.cpp +++ b/src/modules/m_spanningtree/hmac.cpp @@ -128,7 +128,7 @@ bool TreeSocket::ComparePass(const Link& link, const std::string &theirs) std::string fp; if (GetIOHook()) { - SSLCertificateRequest req(this, Utils->Creator); + SocketCertificateRequest req(this, Utils->Creator, GetIOHook()); if (req.cert) { fp = req.cert->GetFingerprint(); diff --git a/src/modules/m_spanningtree/main.cpp b/src/modules/m_spanningtree/main.cpp index 8da34af53..b00107023 100644 --- a/src/modules/m_spanningtree/main.cpp +++ b/src/modules/m_spanningtree/main.cpp @@ -610,6 +610,14 @@ void ModuleSpanningTree::OnUserConnect(LocalUser* user) params.push_back(":"+std::string(user->fullname)); Utils->DoOneToMany(ServerInstance->Config->GetSID(), "UID", params); + for(Extensible::ExtensibleStore::const_iterator i = user->GetExtList().begin(); i != user->GetExtList().end(); i++) + { + ExtensionItem* item = i->first; + std::string value = item->serialize(FORMAT_NETWORK, user, i->second); + if (!value.empty()) + ProtoSendMetaData(this, user, item->key, value); + } + Utils->TreeRoot->SetUserCount(1); // increment by 1 } diff --git a/src/modules/m_spanningtree/netburst.cpp b/src/modules/m_spanningtree/netburst.cpp index 981c903a1..52ce5897d 100644 --- a/src/modules/m_spanningtree/netburst.cpp +++ b/src/modules/m_spanningtree/netburst.cpp @@ -220,7 +220,7 @@ void TreeSocket::SendChannelModes(TreeServer* Current) this->WriteLine(data); } - for(ExtensibleStore::const_iterator i = c->second->GetExtList().begin(); i != c->second->GetExtList().end(); i++) + for(Extensible::ExtensibleStore::const_iterator i = c->second->GetExtList().begin(); i != c->second->GetExtList().end(); i++) { ExtensionItem* item = i->first; std::string value = item->serialize(FORMAT_NETWORK, c->second, i->second); @@ -269,7 +269,7 @@ void TreeSocket::SendUsers(TreeServer* Current) } } - for(ExtensibleStore::const_iterator i = u->second->GetExtList().begin(); i != u->second->GetExtList().end(); i++) + for(Extensible::ExtensibleStore::const_iterator i = u->second->GetExtList().begin(); i != u->second->GetExtList().end(); i++) { ExtensionItem* item = i->first; std::string value = item->serialize(FORMAT_NETWORK, u->second, i->second); diff --git a/src/modules/m_sslinfo.cpp b/src/modules/m_sslinfo.cpp index 0ab749703..7457ce296 100644 --- a/src/modules/m_sslinfo.cpp +++ b/src/modules/m_sslinfo.cpp @@ -25,8 +25,10 @@ class SSLCertExt : public ExtensionItem { } void set(Extensible* item, ssl_cert* value) { + value->refcount_inc(); ssl_cert* old = static_cast(set_raw(item, value)); - delete old; + if (old && old->refcount_dec()) + delete old; } std::string serialize(SerializeFormat format, const Extensible* container, void* item) const @@ -61,7 +63,9 @@ class SSLCertExt : public ExtensionItem { void free(void* item) { - delete static_cast(item); + ssl_cert* old = static_cast(item); + if (old && old->refcount_dec()) + delete old; } }; @@ -228,10 +232,10 @@ class ModuleSSLInfo : public Module void OnRequest(Request& request) { - if (strcmp("GET_CERT", request.id) == 0) + if (strcmp("GET_USER_CERT", request.id) == 0) { - SSLCertificateRequest& req = static_cast(request); - req.cert = cmd.CertExt.get(req.item); + UserCertificateRequest& req = static_cast(request); + req.cert = cmd.CertExt.get(req.user); } else if (strcmp("SET_CERT", request.id) == 0) { diff --git a/src/modules/m_sslmodes.cpp b/src/modules/m_sslmodes.cpp index 5a3a5e712..0748c5bd5 100644 --- a/src/modules/m_sslmodes.cpp +++ b/src/modules/m_sslmodes.cpp @@ -34,7 +34,7 @@ class SSLMode : public ModeHandler const UserMembList* userlist = channel->GetUsers(); for(UserMembCIter i = userlist->begin(); i != userlist->end(); i++) { - SSLCertificateRequest req(i->first, creator); + UserCertificateRequest req(i->first, creator); req.Send(); if(!req.cert && !ServerInstance->ULine(i->first->server)) { @@ -83,7 +83,7 @@ class ModuleSSLModes : public Module { if(chan && chan->IsModeSet('z')) { - SSLCertificateRequest req(user, this); + UserCertificateRequest req(user, this); req.Send(); if (req.cert) { @@ -105,7 +105,7 @@ class ModuleSSLModes : public Module { if (mask[0] == 'z' && mask[1] == ':') { - SSLCertificateRequest req(user, this); + UserCertificateRequest req(user, this); req.Send(); if (req.cert && InspIRCd::Match(req.cert->GetFingerprint(), mask.substr(2))) return MOD_RES_DENY; diff --git a/src/modules/ssl.h b/src/modules/ssl.h index 68f1910ff..a01d91430 100644 --- a/src/modules/ssl.h +++ b/src/modules/ssl.h @@ -25,7 +25,7 @@ * in a unified manner. These classes are attached to ssl- * connected local users using SSLCertExt */ -class ssl_cert +class ssl_cert : public refcountbase { public: std::string dn; @@ -118,13 +118,34 @@ class ssl_cert } }; -struct SSLCertificateRequest : public Request +/** Get certificate from a socket (only useful with an SSL module) */ +struct SocketCertificateRequest : public Request { - Extensible* const item; + StreamSocket* const sock; + ssl_cert* cert; + + SocketCertificateRequest(StreamSocket* ss, Module* Me, Module* hook) + : Request(Me, hook, "GET_SSL_CERT"), sock(ss), cert(NULL) + { + Send(); + } + + std::string GetFingerprint() + { + if (cert) + return cert->GetFingerprint(); + return ""; + } +}; + +/** Get certificate from a user (requires m_sslinfo) */ +struct UserCertificateRequest : public Request +{ + User* const user; ssl_cert* cert; - SSLCertificateRequest(Extensible* e, Module* Me, Module* info = ServerInstance->Modules->Find("m_sslinfo.so")) - : Request(Me, info, "GET_CERT"), item(e), cert(NULL) + UserCertificateRequest(User* u, Module* Me, Module* info = ServerInstance->Modules->Find("m_sslinfo.so")) + : Request(Me, info, "GET_USER_CERT"), user(u), cert(NULL) { Send(); } -- cgit v1.2.3