diff options
Diffstat (limited to 'src/modules/extra/m_ssl_gnutls.cpp')
-rw-r--r-- | src/modules/extra/m_ssl_gnutls.cpp | 89 |
1 files changed, 54 insertions, 35 deletions
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp index e72666062..c65b8528a 100644 --- a/src/modules/extra/m_ssl_gnutls.cpp +++ b/src/modules/extra/m_ssl_gnutls.cpp @@ -44,6 +44,34 @@ static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_c return 0; } +static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t user_wrap, void* buffer, size_t size) +{ + StreamSocket* user = reinterpret_cast<StreamSocket*>(user_wrap); + if (user->GetEventMask() & FD_READ_WILL_BLOCK) + { + errno = EAGAIN; + return -1; + } + int rv = recv(user->GetFd(), buffer, size, 0); + if (rv < (int)size) + ServerInstance->SE->ChangeEventMask(user, FD_READ_WILL_BLOCK); + return rv; +} + +static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t user_wrap, const void* buffer, size_t size) +{ + StreamSocket* user = reinterpret_cast<StreamSocket*>(user_wrap); + if (user->GetEventMask() & FD_WRITE_WILL_BLOCK) + { + errno = EAGAIN; + return -1; + } + int rv = send(user->GetFd(), buffer, size, 0); + if (rv < (int)size) + ServerInstance->SE->ChangeEventMask(user, FD_WRITE_WILL_BLOCK); + return rv; +} + /** Represents an SSL user's extra data */ class issl_session : public classbase @@ -388,7 +416,9 @@ class ModuleSSLGnuTLS : public Module gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); - gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(fd)); // Give gnutls the fd for the socket. + gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user)); + gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); + gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any. @@ -404,7 +434,9 @@ class ModuleSSLGnuTLS : public Module gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate. gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_dh_set_prime_bits(session->sess, dh_bits); - gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user->GetFd())); + gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user)); + gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper); + gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper); Handshake(session, user); } @@ -425,7 +457,7 @@ class ModuleSSLGnuTLS : public Module return -1; } - if (session->status == ISSL_HANDSHAKING_READ) + if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE) { // The handshake isn't finished, try to finish it. @@ -437,11 +469,6 @@ class ModuleSSLGnuTLS : public Module return -1; } } - else if (session->status == ISSL_HANDSHAKING_WRITE) - { - ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_POLL_WRITE); - return 0; - } // If we resumed the handshake then session->status will be ISSL_HANDSHAKEN. @@ -449,35 +476,27 @@ class ModuleSSLGnuTLS : public Module { char* buffer = ServerInstance->GetReadBuffer(); size_t bufsiz = ServerInstance->Config->NetBufferSize; - size_t len = 0; - while (len < bufsiz) + int ret = gnutls_record_recv(session->sess, buffer, bufsiz); + if (ret > 0) { - int ret = gnutls_record_recv(session->sess, buffer + len, bufsiz - len); - if (ret > 0) - { - len += ret; - } - else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) - { - break; - } - else if (ret == 0) - { - user->SetError("SSL Connection closed"); - CloseSession(session); - return -1; - } - else - { - user->SetError(gnutls_strerror(ret)); - CloseSession(session); - return -1; - } + recvq.append(buffer, ret); + return 1; } - if (len) + else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) { - recvq.append(buffer, len); - return 1; + return 0; + } + else if (ret == 0) + { + user->SetError("SSL Connection closed"); + CloseSession(session); + return -1; + } + else + { + user->SetError(gnutls_strerror(ret)); + CloseSession(session); + return -1; } } else if (session->status == ISSL_CLOSING) @@ -546,7 +565,7 @@ class ModuleSSLGnuTLS : public Module return 0; } - bool Handshake(issl_session* session, EventHandler* user) + bool Handshake(issl_session* session, StreamSocket* user) { int ret = gnutls_handshake(session->sess); |