diff options
Diffstat (limited to 'src/modules/extra')
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 227 | ||||
-rw-r--r-- | src/modules/extra/m_ssl_openssl.cpp | 182 | ||||
-rw-r--r-- | src/modules/extra/m_ziplink.cpp | 111 |
3 files changed, 206 insertions, 314 deletions
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index f458f5da1..27c466573 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -56,7 +56,6 @@ public: gnutls_session_t sess; issl_status status; - std::string outbuf; }; class CommandStartTLS : public Command @@ -83,7 +82,7 @@ class CommandStartTLS : public Command { user->WriteNumeric(670, "%s :STARTTLS successful, go ahead with TLS handshake", user->nick.c_str()); user->AddIOHook(creator); - creator->OnRawSocketAccept(user->GetFd(), NULL, NULL); + creator->OnStreamSocketAccept(user, NULL, NULL); } else user->WriteNumeric(691, "%s :STARTTLS failure", user->nick.c_str()); @@ -133,16 +132,14 @@ class ModuleSSLGnuTLS : public Module // Void return, guess we assume success gnutls_certificate_set_dh_params(x509_cred, dh_params); - Implementation eventlist[] = { I_On005Numeric, I_OnRawSocketConnect, I_OnRawSocketAccept, - I_OnRawSocketClose, I_OnRawSocketRead, I_OnRawSocketWrite, I_OnCleanup, - I_OnBufferFlushed, I_OnRequest, I_OnRehash, I_OnModuleRehash, I_OnPostConnect, + Implementation eventlist[] = { I_On005Numeric, I_OnRequest, I_OnRehash, I_OnModuleRehash, I_OnPostConnect, I_OnEvent, I_OnHookIO }; ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); ServerInstance->AddCommand(&starttls); } - virtual void OnRehash(User* user) + void OnRehash(User* user) { ConfigReader Conf(ServerInstance); @@ -168,7 +165,7 @@ class ModuleSSLGnuTLS : public Module sslports.erase(sslports.end() - 1); } - virtual void OnModuleRehash(User* user, const std::string ¶m) + void OnModuleRehash(User* user, const std::string ¶m) { if(param != "ssl") return; @@ -278,7 +275,7 @@ class ModuleSSLGnuTLS : public Module ServerInstance->Logs->Log("m_ssl_gnutls",DEFAULT, "m_ssl_gnutls.so: Failed to generate DH parameters (%d bits): %s", dh_bits, gnutls_strerror(ret)); } - virtual ~ModuleSSLGnuTLS() + ~ModuleSSLGnuTLS() { gnutls_x509_crt_deinit(x509_cert); gnutls_x509_privkey_deinit(x509_key); @@ -289,7 +286,7 @@ class ModuleSSLGnuTLS : public Module delete[] sessions; } - virtual void OnCleanup(int target_type, void* item) + void OnCleanup(int target_type, void* item) { if(target_type == TYPE_USER) { @@ -305,20 +302,20 @@ class ModuleSSLGnuTLS : public Module } } - virtual Version GetVersion() + Version GetVersion() { return Version("$Id$", VF_VENDOR, API_VERSION); } - virtual void On005Numeric(std::string &output) + void On005Numeric(std::string &output) { if (!sslports.empty()) output.append(" SSL=" + sslports); output.append(" STARTTLS"); } - virtual void OnHookIO(EventHandler* user, ListenSocketBase* lsb) + void OnHookIO(StreamSocket* user, ListenSocketBase* lsb) { if (!user->GetIOHook() && listenports.find(lsb) != listenports.end()) { @@ -327,7 +324,7 @@ class ModuleSSLGnuTLS : public Module } } - virtual const char* OnRequest(Request* request) + const char* OnRequest(Request* request) { ISHRequest* ISR = static_cast<ISHRequest*>(request); if (strcmp("IS_NAME", request->GetId()) == 0) @@ -336,20 +333,13 @@ class ModuleSSLGnuTLS : public Module } else if (strcmp("IS_HOOK", request->GetId()) == 0) { - const char* ret = "OK"; - try - { - ret = ISR->Sock->AddIOHook(this) ? "OK" : NULL; - } - catch (ModuleException &e) - { - return NULL; - } - return ret; + ISR->Sock->AddIOHook(this); + return "OK"; } else if (strcmp("IS_UNHOOK", request->GetId()) == 0) { - return ISR->Sock->DelIOHook() ? "OK" : NULL; + ISR->Sock->DelIOHook(); + return "OK"; } else if (strcmp("IS_HSDONE", request->GetId()) == 0) { @@ -383,12 +373,9 @@ class ModuleSSLGnuTLS : public Module } - virtual void OnRawSocketAccept(int fd, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) + void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) { - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return; - + int fd = user->GetFd(); issl_session* session = &sessions[fd]; /* For STARTTLS: Don't try and init a session on a socket that already has a session */ @@ -405,77 +392,67 @@ class ModuleSSLGnuTLS : public Module gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. - Handshake(session, fd); + Handshake(session, user); } - virtual void OnRawSocketConnect(int fd) + void OnStreamSocketConnect(StreamSocket* user) { - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return; - - issl_session* session = &sessions[fd]; + issl_session* session = &sessions[user->GetFd()]; gnutls_init(&session->sess, GNUTLS_CLIENT); gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate. gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); - gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(fd)); // Give gnutls the fd for the socket. + gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user->GetFd())); - Handshake(session, fd); + Handshake(session, user); } - virtual void OnRawSocketClose(int fd) + void OnStreamSocketClose(StreamSocket* user) { - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds())) - return; - - CloseSession(&sessions[fd]); + CloseSession(&sessions[user->GetFd()]); } - virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) + int OnStreamSocketRead(StreamSocket* user, std::string& recvq) { - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return 0; - - issl_session* session = &sessions[fd]; + issl_session* session = &sessions[user->GetFd()]; if (!session->sess) { - readresult = 0; CloseSession(session); - return 1; + user->SetError("No SSL session"); + return -1; } if (session->status == ISSL_HANDSHAKING_READ) { // The handshake isn't finished, try to finish it. - if(!Handshake(session, fd)) + if(!Handshake(session, user)) { - errno = session->status == ISSL_CLOSING ? EIO : EAGAIN; - // Couldn't resume handshake. + if (session->status != ISSL_CLOSING) + return 0; + user->SetError("Handshake Failed"); return -1; } } else if (session->status == ISSL_HANDSHAKING_WRITE) { - errno = EAGAIN; - MakePollWrite(fd); - return -1; + MakePollWrite(user); + return 0; } // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. if (session->status == ISSL_HANDSHAKEN) { - unsigned int len = 0; - while (len < count) + char* buffer = ServerInstance->GetReadBuffer(); + size_t bufsiz = ServerInstance->Config->NetBufferSize; + size_t len = 0; + while (len < bufsiz) { - int ret = gnutls_record_recv(session->sess, buffer + len, count - len); + int ret = gnutls_record_recv(session->sess, buffer + len, bufsiz - len); if (ret > 0) { len += ret; @@ -484,60 +461,49 @@ class ModuleSSLGnuTLS : public Module { break; } + else if (ret == 0) + { + user->SetError("SSL Connection closed"); + CloseSession(session); + return -1; + } else { - if (ret != 0) - ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, - "m_ssl_gnutls.so: Error while reading on fd %d: %s", - fd, gnutls_strerror(ret)); - - // if ret == 0, client closed connection. - readresult = 0; + user->SetError(gnutls_strerror(ret)); CloseSession(session); - return 1; + return -1; } } - readresult = len; if (len) { + recvq.append(buffer, len); return 1; } - else - { - errno = EAGAIN; - return -1; - } } else if (session->status == ISSL_CLOSING) - readresult = 0; + return -1; - return 1; + return 0; } - virtual int OnRawSocketWrite(int fd, const char* buffer, int count) + int OnStreamSocketWrite(StreamSocket* user, std::string& sendq) { - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return 0; - - issl_session* session = &sessions[fd]; - const char* sendbuffer = buffer; + issl_session* session = &sessions[user->GetFd()]; if (!session->sess) { CloseSession(session); - return 1; + user->SetError("No SSL session"); + return -1; } - session->outbuf.append(sendbuffer, count); - sendbuffer = session->outbuf.c_str(); - count = session->outbuf.size(); - if (session->status == ISSL_HANDSHAKING_WRITE || session->status == ISSL_HANDSHAKING_READ) { // The handshake isn't finished, try to finish it. - Handshake(session, fd); - errno = session->status == ISSL_CLOSING ? EIO : EAGAIN; + Handshake(session, user); + if (session->status != ISSL_CLOSING) + return 0; + user->SetError("Handshake Failed"); return -1; } @@ -545,42 +511,41 @@ class ModuleSSLGnuTLS : public Module if (session->status == ISSL_HANDSHAKEN) { - ret = gnutls_record_send(session->sess, sendbuffer, count); + ret = gnutls_record_send(session->sess, sendq.data(), sendq.length()); - if (ret == 0) + if (ret == (int)sendq.length()) { - CloseSession(session); + return 1; } - else if (ret < 0) + else if (ret > 0) { - if(ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED) - { - ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, - "m_ssl_gnutls.so: Error while writing to fd %d: %s", - fd, gnutls_strerror(ret)); - CloseSession(session); - } - else - { - errno = EAGAIN; - } + sendq = sendq.substr(ret); + MakePollWrite(user); + return 0; } - else + else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) + { + MakePollWrite(user); + return 0; + } + else if (ret == 0) { - session->outbuf = session->outbuf.substr(ret); + CloseSession(session); + user->SetError("SSL Connection closed"); + return -1; + } + else // (ret < 0) + { + user->SetError(gnutls_strerror(ret)); + CloseSession(session); + return -1; } } - if (!session->outbuf.empty()) - MakePollWrite(fd); - - /* Who's smart idea was it to return 1 when we havent written anything? - * This fucks the buffer up in BufferedSocket :p - */ - return ret < 1 ? 0 : ret; + return 0; } - bool Handshake(issl_session* session, int fd) + bool Handshake(issl_session* session, EventHandler* user) { int ret = gnutls_handshake(session->sess); @@ -599,15 +564,11 @@ class ModuleSSLGnuTLS : public Module { // gnutls_handshake() wants to write() again. session->status = ISSL_HANDSHAKING_WRITE; - MakePollWrite(fd); + MakePollWrite(user); } } else { - // Handshake failed. - ServerInstance->Logs->Log("m_ssl_gnutls", DEFAULT, - "m_ssl_gnutls.so: Handshake failed on fd %d: %s", - fd, gnutls_strerror(ret)); CloseSession(session); session->status = ISSL_CLOSING; } @@ -619,18 +580,16 @@ class ModuleSSLGnuTLS : public Module // Change the seesion state session->status = ISSL_HANDSHAKEN; - EventHandler* user = ServerInstance->SE->GetRef(fd); - VerifyCertificate(session,user); // Finish writing, if any left - MakePollWrite(fd); + MakePollWrite(user); return true; } } - virtual void OnPostConnect(User* user) + void OnPostConnect(User* user) { // This occurs AFTER OnUserConnect so we can be sure the // protocol module has propagated the NICK message. @@ -646,22 +605,9 @@ class ModuleSSLGnuTLS : public Module } } - void MakePollWrite(int fd) + void MakePollWrite(EventHandler* eh) { - //OnRawSocketWrite(fd, NULL, 0); - EventHandler* eh = ServerInstance->SE->GetRef(fd); - if (eh) - ServerInstance->SE->WantWrite(eh); - } - - virtual void OnBufferFlushed(User* user) - { - if (user->GetIOHook() == this) - { - issl_session* session = &sessions[user->GetFd()]; - if (session && session->outbuf.size()) - OnRawSocketWrite(user->GetFd(), NULL, 0); - } + ServerInstance->SE->WantWrite(eh); } void CloseSession(issl_session* session) @@ -672,7 +618,6 @@ class ModuleSSLGnuTLS : public Module gnutls_deinit(session->sess); } - session->outbuf.clear(); session->sess = NULL; session->status = ISSL_NONE; } diff --git a/src/modules/extra/m_ssl_openssl.cpp b/src/modules/extra/m_ssl_openssl.cpp index a33cf6bc2..e77fa23ff 100644 --- a/src/modules/extra/m_ssl_openssl.cpp +++ b/src/modules/extra/m_ssl_openssl.cpp @@ -59,7 +59,7 @@ public: unsigned int inbufoffset; char* inbuf; // Buffer OpenSSL reads into. - std::string outbuf; // Buffer for outgoing data that OpenSSL will not take. + std::string outbuf; int fd; bool outbound; @@ -95,7 +95,6 @@ class ModuleSSLOpenSSL : public Module SSL_CTX* ctx; SSL_CTX* clictx; - char* dummy; char cipher[MAXBUF]; std::string keyfile; @@ -137,14 +136,13 @@ class ModuleSSLOpenSSL : public Module // Needs the flag as it ignores a plain /rehash OnModuleRehash(NULL,"ssl"); - Implementation eventlist[] = { I_OnRawSocketConnect, I_OnRawSocketAccept, - I_OnRawSocketClose, I_OnRawSocketRead, I_OnRawSocketWrite, I_OnCleanup, I_On005Numeric, - I_OnBufferFlushed, I_OnRequest, I_OnRehash, I_OnModuleRehash, I_OnPostConnect, + Implementation eventlist[] = { + I_On005Numeric, I_OnBufferFlushed, I_OnRequest, I_OnRehash, I_OnModuleRehash, I_OnPostConnect, I_OnHookIO }; ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); } - virtual void OnHookIO(EventHandler* user, ListenSocketBase* lsb) + void OnHookIO(StreamSocket* user, ListenSocketBase* lsb) { if (!user->GetIOHook() && listenports.find(lsb) != listenports.end()) { @@ -153,7 +151,7 @@ class ModuleSSLOpenSSL : public Module } } - virtual void OnRehash(User* user) + void OnRehash(User* user) { ConfigReader Conf(ServerInstance); @@ -179,7 +177,7 @@ class ModuleSSLOpenSSL : public Module sslports.erase(sslports.end() - 1); } - virtual void OnModuleRehash(User* user, const std::string ¶m) + void OnModuleRehash(User* user, const std::string ¶m) { if (param != "ssl") return; @@ -266,13 +264,13 @@ class ModuleSSLOpenSSL : public Module fclose(dhpfile); } - virtual void On005Numeric(std::string &output) + void On005Numeric(std::string &output) { if (!sslports.empty()) output.append(" SSL=" + sslports); } - virtual ~ModuleSSLOpenSSL() + ~ModuleSSLOpenSSL() { SSL_CTX_free(ctx); SSL_CTX_free(clictx); @@ -280,7 +278,7 @@ class ModuleSSLOpenSSL : public Module delete[] sessions; } - virtual void OnCleanup(int target_type, void* item) + void OnCleanup(int target_type, void* item) { if (target_type == TYPE_USER) { @@ -296,13 +294,13 @@ class ModuleSSLOpenSSL : public Module } } - virtual Version GetVersion() + Version GetVersion() { return Version("$Id$", VF_VENDOR, API_VERSION); } - virtual const char* OnRequest(Request* request) + const char* OnRequest(Request* request) { ISHRequest* ISR = (ISHRequest*)request; if (strcmp("IS_NAME", request->GetId()) == 0) @@ -311,21 +309,13 @@ class ModuleSSLOpenSSL : public Module } else if (strcmp("IS_HOOK", request->GetId()) == 0) { - const char* ret = "OK"; - try - { - ret = ISR->Sock->AddIOHook((Module*)this) ? "OK" : NULL; - } - catch (ModuleException &e) - { - return NULL; - } - - return ret; + ISR->Sock->AddIOHook(this); + return "OK"; } else if (strcmp("IS_UNHOOK", request->GetId()) == 0) { - return ISR->Sock->DelIOHook() ? "OK" : NULL; + ISR->Sock->DelIOHook(); + return "OK"; } else if (strcmp("IS_HSDONE", request->GetId()) == 0) { @@ -353,11 +343,9 @@ class ModuleSSLOpenSSL : public Module } - virtual void OnRawSocketAccept(int fd, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) + void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) { - /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return; + int fd = user->GetFd(); issl_session* session = &sessions[fd]; @@ -377,11 +365,12 @@ class ModuleSSLOpenSSL : public Module return; } - Handshake(session); + Handshake(user, session); } - virtual void OnRawSocketConnect(int fd) + void OnStreamSocketConnect(StreamSocket* user) { + int fd = user->GetFd(); /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() -1)) return; @@ -404,11 +393,12 @@ class ModuleSSLOpenSSL : public Module return; } - Handshake(session); + Handshake(user, session); } - virtual void OnRawSocketClose(int fd) + void OnStreamSocketClose(StreamSocket* user) { + int fd = user->GetFd(); /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) return; @@ -416,19 +406,19 @@ class ModuleSSLOpenSSL : public Module CloseSession(&sessions[fd]); } - virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) + int OnStreamSocketRead(StreamSocket* user, std::string& recvq) { + int fd = user->GetFd(); /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return 0; + return -1; issl_session* session = &sessions[fd]; if (!session->sess) { - readresult = 0; CloseSession(session); - return 1; + return -1; } if (session->status == ISSL_HANDSHAKING) @@ -436,17 +426,17 @@ class ModuleSSLOpenSSL : public Module if (session->rstat == ISSL_READ || session->wstat == ISSL_READ) { // The handshake isn't finished and it wants to read, try to finish it. - if (!Handshake(session)) + if (!Handshake(user, session)) { // Couldn't resume handshake. - errno = session->status == ISSL_NONE ? EIO : EAGAIN; - return -1; + if (session->status == ISSL_NONE) + return -1; + return 0; } } else { - errno = EAGAIN; - return -1; + return 0; } } @@ -456,51 +446,37 @@ class ModuleSSLOpenSSL : public Module { if (session->wstat == ISSL_READ) { - if(DoWrite(session) == 0) + if(DoWrite(user, session) == 0) return 0; } if (session->rstat == ISSL_READ) { - int ret = DoRead(session); + int ret = DoRead(user, session); if (ret > 0) { - if (count <= session->inbufoffset) - { - memcpy(buffer, session->inbuf, count); - // Move the stuff left in inbuf to the beginning of it - memmove(session->inbuf, session->inbuf + count, (session->inbufoffset - count)); - // Now we need to set session->inbufoffset to the amount of data still waiting to be handed to insp. - session->inbufoffset -= count; - // Insp uses readresult as the count of how much data there is in buffer, so: - readresult = count; - } - else - { - // There's not as much in the inbuf as there is space in the buffer, so just copy the whole thing. - memcpy(buffer, session->inbuf, session->inbufoffset); - - readresult = session->inbufoffset; - // Zero the offset, as there's nothing there.. - session->inbufoffset = 0; - } + recvq.append(session->inbuf, session->inbufoffset); + session->inbufoffset = 0; return 1; } - return ret; + else if (errno == EAGAIN || errno == EINTR) + return 0; + else + return -1; } } - return -1; + return 0; } - virtual int OnRawSocketWrite(int fd, const char* buffer, int count) + int OnStreamSocketWrite(StreamSocket* user, std::string& buffer) { + int fd = user->GetFd(); /* Are there any possibilities of an out of range fd? Hope not, but lets be paranoid */ if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return 0; + return -1; - errno = EAGAIN; issl_session* session = &sessions[fd]; if (!session->sess) @@ -509,40 +485,53 @@ class ModuleSSLOpenSSL : public Module return -1; } - session->outbuf.append(buffer, count); - MakePollWrite(session); - if (session->status == ISSL_HANDSHAKING) { // The handshake isn't finished, try to finish it. if (session->rstat == ISSL_WRITE || session->wstat == ISSL_WRITE) { - if (!Handshake(session)) + if (!Handshake(user, session)) { // Couldn't resume handshake. - errno = session->status == ISSL_NONE ? EIO : EAGAIN; - return -1; + if (session->status == ISSL_NONE) + return -1; + return 0; } } } + int rv = 0; + + // don't pull items into the output buffer until they are + // unlikely to block; this allows sendq exceeded to continue + // to work for SSL users. + // TODO better signaling for I/O requests so this isn't needed + if (session->outbuf.empty()) + { + session->outbuf = buffer; + rv = 1; + } + if (session->status == ISSL_OPEN) { if (session->rstat == ISSL_WRITE) { - DoRead(session); + DoRead(user, session); } if (session->wstat == ISSL_WRITE) { - return DoWrite(session); + DoWrite(user, session); } } - return 1; + if (rv == 0 || !session->outbuf.empty()) + ServerInstance->SE->WantWrite(user); + + return rv; } - int DoWrite(issl_session* session) + int DoWrite(StreamSocket* user, issl_session* session) { if (!session->outbuf.size()) return -1; @@ -561,6 +550,7 @@ class ModuleSSLOpenSSL : public Module if (err == SSL_ERROR_WANT_WRITE) { session->wstat = ISSL_WRITE; + ServerInstance->SE->WantWrite(user); return -1; } else if (err == SSL_ERROR_WANT_READ) @@ -581,7 +571,7 @@ class ModuleSSLOpenSSL : public Module } } - int DoRead(issl_session* session) + int DoRead(StreamSocket* user, issl_session* session) { // Is this right? Not sure if the unencrypted data is garaunteed to be the same length. // Read into the inbuffer, offset from the beginning by the amount of data we have that insp hasn't taken yet. @@ -606,7 +596,7 @@ class ModuleSSLOpenSSL : public Module else if (err == SSL_ERROR_WANT_WRITE) { session->rstat = ISSL_WRITE; - MakePollWrite(session); + ServerInstance->SE->WantWrite(user); return -1; } else @@ -627,7 +617,7 @@ class ModuleSSLOpenSSL : public Module } } - bool Handshake(issl_session* session) + bool Handshake(EventHandler* user, issl_session* session) { int ret; @@ -650,7 +640,7 @@ class ModuleSSLOpenSSL : public Module { session->wstat = ISSL_WRITE; session->status = ISSL_HANDSHAKING; - MakePollWrite(session); + ServerInstance->SE->WantWrite(user); return true; } else @@ -669,7 +659,7 @@ class ModuleSSLOpenSSL : public Module session->status = ISSL_OPEN; - MakePollWrite(session); + ServerInstance->SE->WantWrite(user); return true; } @@ -682,34 +672,14 @@ class ModuleSSLOpenSSL : public Module return true; } - virtual void OnPostConnect(User* user) - { - // This occurs AFTER OnUserConnect so we can be sure the - // protocol module has propagated the NICK message. - if ((user->GetIOHook() == this) && (IS_LOCAL(user))) - { - if (sessions[user->GetFd()].sess) - user->WriteServ("NOTICE %s :*** You are connected using SSL cipher \"%s\"", user->nick.c_str(), SSL_get_cipher(sessions[user->GetFd()].sess)); - } - } - - void MakePollWrite(issl_session* session) - { - //OnRawSocketWrite(session->fd, NULL, 0); - EventHandler* eh = ServerInstance->SE->GetRef(session->fd); - if (eh) - { - ServerInstance->SE->WantWrite(eh); - } - } - - virtual void OnBufferFlushed(User* user) + void OnBufferFlushed(User* user) { if (user->GetIOHook() == this) { + std::string dummy; issl_session* session = &sessions[user->GetFd()]; if (session && session->outbuf.size()) - OnRawSocketWrite(user->GetFd(), NULL, 0); + OnStreamSocketWrite(user, dummy); } } diff --git a/src/modules/extra/m_ziplink.cpp b/src/modules/extra/m_ziplink.cpp index c220460bd..7d090d80a 100644 --- a/src/modules/extra/m_ziplink.cpp +++ b/src/modules/extra/m_ziplink.cpp @@ -67,29 +67,29 @@ class ModuleZLib : public Module total_out_compressed = total_in_compressed = 0; total_out_uncompressed = total_in_uncompressed = 0; - Implementation eventlist[] = { I_OnRawSocketConnect, I_OnRawSocketAccept, I_OnRawSocketClose, I_OnRawSocketRead, I_OnRawSocketWrite, I_OnStats, I_OnRequest }; - ServerInstance->Modules->Attach(eventlist, this, 7); + Implementation eventlist[] = { I_OnStats, I_OnRequest }; + ServerInstance->Modules->Attach(eventlist, this, 2); // Allocate a buffer which is used for reading and writing data net_buffer_size = ServerInstance->Config->NetBufferSize; net_buffer = new char[net_buffer_size]; } - virtual ~ModuleZLib() + ~ModuleZLib() { ServerInstance->Modules->UnpublishInterface("BufferedSocketHook", this); delete[] sessions; delete[] net_buffer; } - virtual Version GetVersion() + Version GetVersion() { return Version("$Id$", VF_VENDOR, API_VERSION); } /* Handle BufferedSocketHook API requests */ - virtual const char* OnRequest(Request* request) + const char* OnRequest(Request* request) { ISHRequest* ISR = (ISHRequest*)request; if (strcmp("IS_NAME", request->GetId()) == 0) @@ -99,22 +99,13 @@ class ModuleZLib : public Module } else if (strcmp("IS_HOOK", request->GetId()) == 0) { - /* Attach to an inspsocket */ - const char* ret = "OK"; - try - { - ret = ISR->Sock->AddIOHook((Module*)this) ? "OK" : NULL; - } - catch (ModuleException& e) - { - return NULL; - } - return ret; + ISR->Sock->AddIOHook(this); + return "OK"; } else if (strcmp("IS_UNHOOK", request->GetId()) == 0) { - /* Detach from an inspsocket */ - return ISR->Sock->DelIOHook() ? "OK" : NULL; + ISR->Sock->DelIOHook(); + return "OK"; } else if (strcmp("IS_HSDONE", request->GetId()) == 0) { @@ -134,7 +125,7 @@ class ModuleZLib : public Module } /* Handle stats z (misc stats) */ - virtual ModResult OnStats(char symbol, User* user, string_list &results) + ModResult OnStats(char symbol, User* user, string_list &results) { if (symbol == 'z') { @@ -174,10 +165,14 @@ class ModuleZLib : public Module return MOD_RES_PASSTHRU; } - virtual void OnRawSocketConnect(int fd) + void OnStreamSocketConnect(StreamSocket* user) { - if ((fd < 0) || (fd > ServerInstance->SE->GetMaxFds() - 1)) - return; + OnStreamSocketAccept(user, 0, 0); + } + + void OnRawSocketAccept(StreamSocket* user, irc::sockets::sockaddrs*, irc::sockets::sockaddrs*) + { + int fd = user->GetFd(); izip_session* session = &sessions[fd]; @@ -211,39 +206,33 @@ class ModuleZLib : public Module session->status = IZIP_OPEN; } - virtual void OnRawSocketAccept(int fd, irc::sockets::sockaddrs*, irc::sockets::sockaddrs*) - { - /* Nothing special needs doing here compared to connect() */ - OnRawSocketConnect(fd); - } - - virtual void OnRawSocketClose(int fd) + void OnStreamSocketClose(StreamSocket* user) { + int fd = user->GetFd(); CloseSession(&sessions[fd]); } - virtual int OnRawSocketRead(int fd, char* buffer, unsigned int count, int &readresult) + int OnStreamSocketRead(StreamSocket* user, std::string& recvq) { + int fd = user->GetFd(); /* Find the sockets session */ izip_session* session = &sessions[fd]; if (session->status == IZIP_CLOSED) - return 0; + return -1; - if (session->inbuf.length()) - { - /* Our input buffer is filling up. This is *BAD*. - * We can't return more data than fits into buffer - * (count bytes), so we will generate another read - * event on purpose by *NOT* reading from 'fd' at all - * for now. - */ - readresult = 0; - } - else + if (session->inbuf.empty()) { /* Read read_buffer_size bytes at a time to the buffer (usually 2.5k) */ - readresult = read(fd, net_buffer, net_buffer_size); + int readresult = read(fd, net_buffer, net_buffer_size); + + if (readresult < 0) + { + if (errno == EINTR || errno == EAGAIN) + return 0; + } + if (readresult <= 0) + return -1; total_in_compressed += readresult; @@ -252,10 +241,8 @@ class ModuleZLib : public Module } size_t in_len = session->inbuf.length(); - - /* Do we have anything to do? */ - if (in_len <= 0) - return 0; + char* buffer = ServerInstance->GetReadBuffer(); + int count = ServerInstance->Config->NetBufferSize; /* Prepare decompression */ session->d_stream.next_in = (Bytef *)session->inbuf.c_str(); @@ -302,8 +289,7 @@ class ModuleZLib : public Module } if (ret != Z_OK) { - readresult = 0; - return 0; + return -1; } /* Update the inbut buffer */ @@ -315,24 +301,18 @@ class ModuleZLib : public Module total_in_uncompressed += uncompressed_length; /* Null-terminate the buffer -- this doesnt harm binary data */ - buffer[uncompressed_length] = 0; - - /* Set the read size to the correct total size */ - readresult = uncompressed_length; - + recvq.append(buffer, uncompressed_length); return 1; } - virtual int OnRawSocketWrite(int fd, const char* buffer, int count) + int OnStreamSocketWrite(StreamSocket* user, std::string& sendq) { + int fd = user->GetFd(); izip_session* session = &sessions[fd]; - if (!count) /* Nothing to do! */ - return 0; - if(session->status != IZIP_OPEN) /* Seriously, wtf? */ - return 0; + return -1; int ret; @@ -343,8 +323,8 @@ class ModuleZLib : public Module do { /* Prepare compression */ - session->c_stream.next_in = (Bytef*)buffer + offset; - session->c_stream.avail_in = count - offset; + session->c_stream.next_in = (Bytef*)sendq.data() + offset; + session->c_stream.avail_in = sendq.length() - offset; session->c_stream.next_out = (Bytef*)net_buffer; session->c_stream.avail_out = net_buffer_size; @@ -378,7 +358,7 @@ class ModuleZLib : public Module /* Space before - space after stuff was added to this */ unsigned int compressed = net_buffer_size - session->c_stream.avail_out; - unsigned int uncompressed = count - session->c_stream.avail_in; + unsigned int uncompressed = sendq.length() - session->c_stream.avail_in; /* Make it skip the data which was compressed already */ offset += uncompressed; @@ -404,14 +384,11 @@ class ModuleZLib : public Module else { session->outbuf.clear(); - return 0; + return -1; } } - /* ALL LIES the lot of it, we havent really written - * this amount, but the layer above doesnt need to know. - */ - return count; + return 1; } void Error(izip_session* session, const std::string &text) |