summaryrefslogtreecommitdiff
path: root/src/modules/extra/m_ssl_gnutls.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/modules/extra/m_ssl_gnutls.cpp')
-rw-r--r--src/modules/extra/m_ssl_gnutls.cpp89
1 files changed, 54 insertions, 35 deletions
diff --git a/src/modules/extra/m_ssl_gnutls.cpp b/src/modules/extra/m_ssl_gnutls.cpp
index e72666062..c65b8528a 100644
--- a/src/modules/extra/m_ssl_gnutls.cpp
+++ b/src/modules/extra/m_ssl_gnutls.cpp
@@ -44,6 +44,34 @@ static int cert_callback (gnutls_session_t session, const gnutls_datum_t * req_c
return 0;
}
+static ssize_t gnutls_pull_wrapper(gnutls_transport_ptr_t user_wrap, void* buffer, size_t size)
+{
+ StreamSocket* user = reinterpret_cast<StreamSocket*>(user_wrap);
+ if (user->GetEventMask() & FD_READ_WILL_BLOCK)
+ {
+ errno = EAGAIN;
+ return -1;
+ }
+ int rv = recv(user->GetFd(), buffer, size, 0);
+ if (rv < (int)size)
+ ServerInstance->SE->ChangeEventMask(user, FD_READ_WILL_BLOCK);
+ return rv;
+}
+
+static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t user_wrap, const void* buffer, size_t size)
+{
+ StreamSocket* user = reinterpret_cast<StreamSocket*>(user_wrap);
+ if (user->GetEventMask() & FD_WRITE_WILL_BLOCK)
+ {
+ errno = EAGAIN;
+ return -1;
+ }
+ int rv = send(user->GetFd(), buffer, size, 0);
+ if (rv < (int)size)
+ ServerInstance->SE->ChangeEventMask(user, FD_WRITE_WILL_BLOCK);
+ return rv;
+}
+
/** Represents an SSL user's extra data
*/
class issl_session : public classbase
@@ -388,7 +416,9 @@ class ModuleSSLGnuTLS : public Module
gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
gnutls_dh_set_prime_bits(session->sess, dh_bits);
- gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(fd)); // Give gnutls the fd for the socket.
+ gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user));
+ gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper);
+ gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper);
gnutls_certificate_server_set_request(session->sess, GNUTLS_CERT_REQUEST); // Request client certificate if any.
@@ -404,7 +434,9 @@ class ModuleSSLGnuTLS : public Module
gnutls_set_default_priority(session->sess); // Avoid calling all the priority functions, defaults are adequate.
gnutls_credentials_set(session->sess, GNUTLS_CRD_CERTIFICATE, x509_cred);
gnutls_dh_set_prime_bits(session->sess, dh_bits);
- gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user->GetFd()));
+ gnutls_transport_set_ptr(session->sess, reinterpret_cast<gnutls_transport_ptr_t>(user));
+ gnutls_transport_set_push_function(session->sess, gnutls_push_wrapper);
+ gnutls_transport_set_pull_function(session->sess, gnutls_pull_wrapper);
Handshake(session, user);
}
@@ -425,7 +457,7 @@ class ModuleSSLGnuTLS : public Module
return -1;
}
- if (session->status == ISSL_HANDSHAKING_READ)
+ if (session->status == ISSL_HANDSHAKING_READ || session->status == ISSL_HANDSHAKING_WRITE)
{
// The handshake isn't finished, try to finish it.
@@ -437,11 +469,6 @@ class ModuleSSLGnuTLS : public Module
return -1;
}
}
- else if (session->status == ISSL_HANDSHAKING_WRITE)
- {
- ServerInstance->SE->ChangeEventMask(user, FD_WANT_NO_READ | FD_WANT_POLL_WRITE);
- return 0;
- }
// If we resumed the handshake then session->status will be ISSL_HANDSHAKEN.
@@ -449,35 +476,27 @@ class ModuleSSLGnuTLS : public Module
{
char* buffer = ServerInstance->GetReadBuffer();
size_t bufsiz = ServerInstance->Config->NetBufferSize;
- size_t len = 0;
- while (len < bufsiz)
+ int ret = gnutls_record_recv(session->sess, buffer, bufsiz);
+ if (ret > 0)
{
- int ret = gnutls_record_recv(session->sess, buffer + len, bufsiz - len);
- if (ret > 0)
- {
- len += ret;
- }
- else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
- {
- break;
- }
- else if (ret == 0)
- {
- user->SetError("SSL Connection closed");
- CloseSession(session);
- return -1;
- }
- else
- {
- user->SetError(gnutls_strerror(ret));
- CloseSession(session);
- return -1;
- }
+ recvq.append(buffer, ret);
+ return 1;
}
- if (len)
+ else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED)
{
- recvq.append(buffer, len);
- return 1;
+ return 0;
+ }
+ else if (ret == 0)
+ {
+ user->SetError("SSL Connection closed");
+ CloseSession(session);
+ return -1;
+ }
+ else
+ {
+ user->SetError(gnutls_strerror(ret));
+ CloseSession(session);
+ return -1;
}
}
else if (session->status == ISSL_CLOSING)
@@ -546,7 +565,7 @@ class ModuleSSLGnuTLS : public Module
return 0;
}
- bool Handshake(issl_session* session, EventHandler* user)
+ bool Handshake(issl_session* session, StreamSocket* user)
{
int ret = gnutls_handshake(session->sess);