summaryrefslogtreecommitdiff
path: root/src/stls/stls_run.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/stls/stls_run.c')
-rw-r--r--src/stls/stls_run.c360
1 files changed, 186 insertions, 174 deletions
diff --git a/src/stls/stls_run.c b/src/stls/stls_run.c
index 7385c4e..2456e22 100644
--- a/src/stls/stls_run.c
+++ b/src/stls/stls_run.c
@@ -1,6 +1,7 @@
/* ISC license. */
#include <sys/uio.h>
+#include <stdint.h>
#include <errno.h>
#include <unistd.h>
@@ -14,248 +15,259 @@
#include <s6-networking/stls.h>
-typedef struct tlsbuf_s tlsbuf_t, *tlsbuf_t_ref ;
-struct tlsbuf_s
+
+typedef struct stls_buffer_s stls_buffer, *stls_buffer_ref ;
+struct stls_buffer_s
{
buffer b ;
- unsigned char blockedonother : 1 ;
char buf[STLS_BUFSIZE] ;
+ uint8_t flags ; /* 0x1: flush/fill wants opposite IO; 0x2: close_notify initiated */
} ;
-static inline int buffer_tls_flush (struct tls *ctx, tlsbuf_t *b)
+
+ /*
+ We need access to the state field of struct tls, which is private.
+ So we fake enough stuff that we get the correct field offset.
+ */
+
+#define TLS_EOF_NO_CLOSE_NOTIFY 1
+
+struct fake_tls_error_s
{
- struct iovec v[2] ;
- ssize_t r, w ;
- buffer_rpeek(&b[0].b, v) ;
- r = tls_write(ctx, v[0].iov_base, v[0].iov_len) ;
- switch (r)
- {
- case -1 : return -1 ;
- case TLS_WANT_POLLIN :
- if (b[1].blockedonother) strerr_dief1x(101, "TLS deadlock") ;
- b[0].blockedonother = 1 ;
- case TLS_WANT_POLLOUT : return 0 ;
- default : break ;
- }
- w = r ;
- if ((size_t)w == v[0].iov_len && v[1].iov_len)
+ char *msg ;
+ int num ;
+ int tls ;
+} ;
+
+struct fake_tls_s
+{
+ void *config ;
+ void *keypair ;
+ struct fake_tls_error_s error ;
+ uint32_t flags ;
+ uint32_t state ;
+} ;
+
+ /* All because there's no accessor for this in the official libtls API: */
+
+static inline int tls_eof_got_close_notify (struct tls *ctx)
+{
+ return !(((struct fake_tls_s *)ctx)->state & TLS_EOF_NO_CLOSE_NOTIFY) ;
+}
+
+ /* We want tls_read/write to behave l */
+
+static int tls_allwrite (struct tls *ctx, char const *s, size_t len, size_t *w)
+{
+ while (*w < len)
{
- r = tls_write(ctx, v[1].iov_base, v[1].iov_len) ;
+ ssize_t r = tls_write(ctx, s + *w, len - *w) ;
switch (r)
{
- case TLS_WANT_POLLIN :
- if (b[1].blockedonother) strerr_dief1x(101, "TLS deadlock") ;
- b[0].blockedonother = 1 ;
- case -1 :
- case TLS_WANT_POLLOUT :
- buffer_rseek(&b[0].b, w) ;
- return 0 ;
+ case -1 : strerr_diefu2x(98, "tls_write: ", tls_error(ctx)) ;
+ case TLS_WANT_POLLIN : return 1 ;
+ case TLS_WANT_POLLOUT : return 0 ;
default : break ;
}
- w += r ;
+ *w += r ;
}
- buffer_rseek(&b[0].b, w) ;
- return 1 ;
+ return 0 ;
}
-static inline int buffer_tls_fill (struct tls *ctx, tlsbuf_t *b)
+static void tls_flush (struct tls *ctx, stls_buffer *b)
{
struct iovec v[2] ;
- ssize_t r, w ;
- int ok = 1 ;
- buffer_wpeek(&b[1].b, v) ;
- r = tls_read(ctx, v[0].iov_base, v[0].iov_len) ;
- switch (r)
- {
- case 0 : return -2 ;
- case -1 : return -1 ;
- case TLS_WANT_POLLOUT :
- if (b[0].blockedonother) strerr_dief1x(101, "TLS deadlock") ;
- b[1].blockedonother = 1 ;
- case TLS_WANT_POLLIN : return 0 ;
- default : break ;
- }
- w = r ;
- if ((size_t)w == v[0].iov_len && v[1].iov_len)
+ size_t w = 0 ;
+ int r ;
+ buffer_rpeek(&b[0].b, v) ;
+ r = tls_allwrite(ctx, v[0].iov_base, v[0].iov_len, &w) ;
+ buffer_rseek(&b[0].b, w) ;
+ if (w < v[0].iov_len || !v[1].iov_len) goto out ;
+ w = 0 ;
+ r = tls_allwrite(ctx, v[1].iov_base, v[1].iov_len, &w) ;
+ buffer_rseek(&b[0].b, w) ;
+ out:
+ if (r) b[1].flags |= 1 ; else b[1].flags &= ~1 ;
+}
+
+static int tls_allread (struct tls *ctx, char *s, size_t len, size_t *w)
+{
+ while (*w < len)
{
- r = tls_read(ctx, v[1].iov_base, v[1].iov_len) ;
+ ssize_t r = tls_read(ctx, s + *w, len - *w) ;
switch (r)
{
- case TLS_WANT_POLLOUT :
- if (b[0].blockedonother) strerr_dief1x(101, "TLS deadlock") ;
- b[1].blockedonother = 1 ;
- case -1 :
- case TLS_WANT_POLLIN :
- buffer_wseek(&b[1].b, w) ;
- return 0 ;
- case 0 : ok = -1 ; errno = EPIPE ;
+ case -1 : strerr_diefu2x(98, "tls_read: ", tls_error(ctx)) ;
+ case 0 : return -1 ;
+ case TLS_WANT_POLLIN : return 0 ;
+ case TLS_WANT_POLLOUT : return 1 ;
default : break ;
}
- w += r ;
+ *w += r ;
}
- buffer_wseek(&b[1].b, w) ;
- return ok ;
+ return 0 ;
}
-static void send_closenotify (struct tls *ctx, int const *fds)
+static int tls_fill (struct tls *ctx, stls_buffer *b)
{
- iopause_fd x = { .fd = fds[3], .events = IOPAUSE_WRITE } ;
- while (tls_close(ctx) == TLS_WANT_POLLOUT)
- iopause_g(&x, 1, 0) ;
+ struct iovec v[2] ;
+ size_t w = 0 ;
+ int r ;
+ buffer_wpeek(&b[1].b, v) ;
+ r = tls_allread(ctx, v[0].iov_base, v[0].iov_len, &w) ;
+ buffer_wseek(&b[1].b, w) ;
+ if (w < v[0].iov_len || !v[1].iov_len) goto out ;
+ w = 0 ;
+ r = tls_allread(ctx, v[1].iov_base, v[1].iov_len, &w) ;
+ buffer_wseek(&b[1].b, w) ;
+ out:
+ if (r == -1) return 1 ;
+ if (r) b[0].flags |= 1 ; else b[0].flags &= ~1 ;
+ return 0 ;
}
-static void closeit (struct tls *ctx, int *fds, int closenotify)
+static int tls_close_nb (struct tls *ctx, stls_buffer *b)
{
- if (!closenotify) fd_shutdown(fds[3], 1) ;
- else if (fds[2] >= 0) send_closenotify(ctx, fds) ;
- fd_close(fds[3]) ; fds[3] = -1 ;
+ switch (tls_close(ctx))
+ {
+ case 0 : b[0].flags &= ~2 ; b[1].flags &= ~2 ; b[1].flags |= 4 ; return 1 ;
+ case TLS_WANT_POLLIN : b[0].flags &= ~2 ; b[1].flags |= 2 ; break ;
+ case TLS_WANT_POLLOUT : b[0].flags |= 2 ; b[1].flags &= ~2 ; break ;
+ default : strerr_diefu2x(98, "tls_close: ", tls_error(ctx)) ;
+ }
+ return 0 ;
}
-void stls_run (struct tls *ctx, int *fds, uint32_t options, unsigned int verbosity)
+ /* The engine. */
+
+void stls_run (struct tls *ctx, int const *fds, uint32_t options, unsigned int verbosity)
{
- tlsbuf_t b[2] = { { .blockedonother = 0 }, { .blockedonother = 0 } } ;
- iopause_fd x[4] ;
- unsigned int xindex[4] ;
-
- if (ndelay_on(fds[0]) < 0
- || ndelay_on(fds[1]) < 0
- || ndelay_on(fds[2]) < 0
- || ndelay_on(fds[3]) < 0)
+ stls_buffer b[2] =
+ {
+ { .b = BUFFER_INIT(&buffer_read, fds[0], b[0].buf, STLS_BUFSIZE), .flags = 0 },
+ { .b = BUFFER_INIT(&buffer_write, fds[1], b[1].buf, STLS_BUFSIZE), .flags = 0 },
+ } ;
+ iopause_fd x[4] = { { .fd = fds[0] }, { .fd = fds[1] }, { .fd = fds[2] }, { .fd = fds[3] } } ;
+
+ if (ndelay_on(x[0].fd) == -1
+ || ndelay_on(x[1].fd) == -1
+ || ndelay_on(x[2].fd) == -1
+ || ndelay_on(x[3].fd) == -1)
strerr_diefu1sys(111, "set fds non-blocking") ;
- buffer_init(&b[0].b, &buffer_read, fds[0], b[0].buf, STLS_BUFSIZE) ;
- buffer_init(&b[1].b, &buffer_write, fds[1], b[1].buf, STLS_BUFSIZE) ;
-
- for (;;)
+ while (x[0].fd >= 0 || x[1].fd >= 0 || x[3].fd >= 0)
{
- unsigned int j = 0 ;
- int r ;
-
-
- /* poll() preparation */
-
- if (fds[0] >= 0 && buffer_isreadable(&b[0].b))
- {
- x[j].fd = fds[0] ;
- x[j].events = IOPAUSE_READ ;
- xindex[0] = j++ ;
- }
- else xindex[0] = 4 ;
-
- if (fds[1] >= 0 && buffer_iswritable(&b[1].b))
- {
- x[j].fd = fds[1] ;
- x[j].events = IOPAUSE_WRITE ;
- xindex[1] = j++ ;
- }
- else xindex[1] = 4 ;
-
- if (fds[2] >= 0 && !b[1].blockedonother && buffer_isreadable(&b[1].b))
- {
- x[j].fd = fds[2] ;
- x[j].events = IOPAUSE_READ ;
- xindex[2] = j++ ;
- }
- else xindex[2] = 4 ;
-
- if (fds[3] >= 0 && !b[0].blockedonother && buffer_iswritable(&b[0].b))
- {
- x[j].fd = fds[3] ;
- x[j].events = IOPAUSE_WRITE ;
- xindex[3] = j++ ;
- }
- else xindex[3] = 4 ;
-
- if (xindex[0] == 4 && xindex[1] == 4 && xindex[3] == 4) break ;
+ x[0].events = x[0].fd >= 0 && buffer_isreadable(&b[0].b) ? IOPAUSE_READ : 0 ;
+ x[1].events = x[1].fd >= 0 && buffer_iswritable(&b[1].b) ? IOPAUSE_WRITE : 0 ;
+ x[2].events = x[2].fd >= 0 && (buffer_isreadable(&b[1].b) || (b[1].flags & 1 && buffer_iswritable(&b[0].b))) ? IOPAUSE_READ : 0 ;
+ x[3].events = x[3].fd >= 0 && (buffer_iswritable(&b[0].b) || (b[0].flags & 1 && buffer_isreadable(&b[1].b))) ? IOPAUSE_WRITE : 0 ;
- /* poll() */
-
- r = iopause_g(x, j, 0) ;
- if (r < 0) strerr_diefu1sys(111, "iopause") ;
- else if (!r) break ;
-
- while (j--)
- if (x[j].revents & IOPAUSE_EXCEPT)
- x[j].revents |= IOPAUSE_READ | IOPAUSE_WRITE ;
-
+ if (iopause_g(x, 4, 0) == -1) strerr_diefu1sys(111, "iopause") ;
/* Flush to local */
- if (xindex[1] < 4 && x[xindex[1]].revents & IOPAUSE_WRITE)
+ if (x[1].revents)
{
- r = buffer_flush(&b[1].b) ;
- if (!r && !error_isagain(errno))
+ if (!buffer_flush(&b[1].b))
{
- strerr_warnwu1sys("write to application") ;
- if (fds[2] >= 0)
- {
- if (options & 1) fd_shutdown(fds[2], 0) ;
- fd_close(fds[2]) ; fds[2] = -1 ;
- xindex[2] = 4 ;
- }
- r = 1 ;
+ if (!error_isagain(errno)) strerr_diefu1sys(111, "write to local") ;
}
- if (r && fds[2] < 0)
+ else if (x[2].fd == -1)
{
- fd_close(fds[1]) ; fds[1] = -1 ;
+ fd_close(x[1].fd) ;
+ x[1].fd = -1 ;
}
}
- /* Flush to remote */
+ /* Flush to remote: do everything that had TLS_WANT_POLLOUT */
- if (xindex[3] < 4 && x[xindex[3]].revents & IOPAUSE_WRITE)
+ if (x[3].revents)
{
- r = buffer_tls_flush(ctx, b) ;
- if (r < 0)
+ if (buffer_len(&b[0].b)) tls_flush(ctx, b) ; /* normal write */
+ if ((b[0].flags & 1 && tls_fill(ctx, b)) /* peer sent close_notify and it just completed */
+ || (b[0].flags & 2 && tls_close_nb(ctx, b))) /* we send close_notify and it instantly succeeds */
{
- strerr_warnwu2x("write to peer: ", tls_error(ctx)) ;
- fd_close(fds[0]) ; fds[0] = -1 ;
- xindex[0] = 4 ;
+ if (buffer_isempty(&b[1].b)) break ;
+ fd_close(x[3].fd) ; x[3].fd = -1 ;
+ fd_close(x[2].fd) ; x[2].fd = -1 ;
+ if (x[0].fd >= 0) { fd_close(x[0].fd) ; x[0].fd = -1 ; }
+ continue ;
+ }
+ if (x[0].fd == -1 && buffer_isempty(&b[0].b))
+ {
+ if (!(options & 1) || tls_close_nb(ctx, b))
+ {
+ fd_shutdown(x[3].fd, 1) ;
+ fd_close(x[3].fd) ;
+ x[3].fd = -1 ;
+ }
}
- if (r && fds[0] < 0)
- closeit(ctx, fds, options & 1) ;
}
/* Fill from local */
- if (xindex[0] < 4 && x[xindex[0]].revents & IOPAUSE_READ)
+ if (x[0].revents)
{
- r = sanitize_read(buffer_fill(&b[0].b)) ;
- if (r < 0)
+ ssize_t r = buffer_fill(&b[0].b) ;
+ if (r == -1 && !error_isagain(errno))
+ strerr_diefu1sys(111, "read from local") ;
+ else if (!r)
{
- if (errno != EPIPE) strerr_warnwu1sys("read from application") ;
- fd_close(fds[0]) ; fds[0] = -1 ;
+ fd_close(x[0].fd) ;
+ x[0].fd = -1 ;
if (buffer_isempty(&b[0].b))
- closeit(ctx, fds, options & 1) ;
+ {
+ if (!(options & 1) || tls_close_nb(ctx, b))
+ {
+ fd_shutdown(x[3].fd, 1) ;
+ fd_close(x[3].fd) ;
+ x[3].fd = -1 ;
+ }
+ }
}
}
- /* Fill from remote */
+ /* Fill from remote: do everything that had TLS_WANT_POLLIN */
- if (xindex[2] < 4 && x[xindex[2]].revents & IOPAUSE_READ)
+ if (x[2].revents)
{
- r = buffer_tls_fill(ctx, b) ;
- if (r < 0)
- {
- if (r == -1) strerr_warnwu2x("read from peer: ", tls_error(ctx)) ;
- if (!(options & 1)) fd_shutdown(fds[2], 0) ;
- /*
- XXX: We need a way to detect when we've received a close_notify,
- because then we need to trigger a write and then shut the engine
- down. This is orthogonal to options&1, it only means that the
- peer sent a close_notify.
- As for now, libtls doesn't offer an API to detect that, so we
- do nothing special - we just wait until our app sends EOF.
- */
- fd_close(fds[2]) ; fds[2] = -1 ;
+ if (buffer_isreadable(&b[1].b) && tls_fill(ctx, b))
+ { /* connection closed */
+ fd_shutdown(x[2].fd, 0) ;
+ fd_close(x[2].fd) ;
+ x[2].fd = -1 ;
if (buffer_isempty(&b[1].b))
{
- fd_close(fds[1]) ; fds[1] = -1 ;
+ if (tls_eof_got_close_notify(ctx)) break ;
+ fd_close(x[1].fd) ;
+ x[1].fd = -1 ;
+ }
+ if (options & 2)
+ {
+ if (!tls_eof_got_close_notify(ctx))
+ strerr_dief1x(98, "remote closed connection without a close_notify") ;
+ else if (x[3].fd >= 0)
+ {
+ fd_shutdown(x[3].fd, 1) ;
+ fd_close(x[3].fd) ;
+ x[3].fd = -1 ;
+ }
+ }
+ }
+ else
+ { /* normal case */
+ if (b[1].flags & 1) tls_flush(ctx, b) ;
+ if (b[1].flags & 2 && tls_close_nb(ctx, b))
+ {
+ if (buffer_isempty(&b[1].b)) break ;
+ if (x[3].fd >= 0) { fd_close(x[3].fd) ; x[3].fd = -1 ; }
+ if (x[0].fd >= 0) { fd_close(x[0].fd) ; x[0].fd = -1 ; }
+ fd_close(x[2].fd) ; x[2].fd = -1 ;
}
}
}