diff options
Diffstat (limited to 'src/stls/stls_run.c')
-rw-r--r-- | src/stls/stls_run.c | 360 |
1 files changed, 186 insertions, 174 deletions
diff --git a/src/stls/stls_run.c b/src/stls/stls_run.c index 7385c4e..2456e22 100644 --- a/src/stls/stls_run.c +++ b/src/stls/stls_run.c @@ -1,6 +1,7 @@ /* ISC license. */ #include <sys/uio.h> +#include <stdint.h> #include <errno.h> #include <unistd.h> @@ -14,248 +15,259 @@ #include <s6-networking/stls.h> -typedef struct tlsbuf_s tlsbuf_t, *tlsbuf_t_ref ; -struct tlsbuf_s + +typedef struct stls_buffer_s stls_buffer, *stls_buffer_ref ; +struct stls_buffer_s { buffer b ; - unsigned char blockedonother : 1 ; char buf[STLS_BUFSIZE] ; + uint8_t flags ; /* 0x1: flush/fill wants opposite IO; 0x2: close_notify initiated */ } ; -static inline int buffer_tls_flush (struct tls *ctx, tlsbuf_t *b) + + /* + We need access to the state field of struct tls, which is private. + So we fake enough stuff that we get the correct field offset. + */ + +#define TLS_EOF_NO_CLOSE_NOTIFY 1 + +struct fake_tls_error_s { - struct iovec v[2] ; - ssize_t r, w ; - buffer_rpeek(&b[0].b, v) ; - r = tls_write(ctx, v[0].iov_base, v[0].iov_len) ; - switch (r) - { - case -1 : return -1 ; - case TLS_WANT_POLLIN : - if (b[1].blockedonother) strerr_dief1x(101, "TLS deadlock") ; - b[0].blockedonother = 1 ; - case TLS_WANT_POLLOUT : return 0 ; - default : break ; - } - w = r ; - if ((size_t)w == v[0].iov_len && v[1].iov_len) + char *msg ; + int num ; + int tls ; +} ; + +struct fake_tls_s +{ + void *config ; + void *keypair ; + struct fake_tls_error_s error ; + uint32_t flags ; + uint32_t state ; +} ; + + /* All because there's no accessor for this in the official libtls API: */ + +static inline int tls_eof_got_close_notify (struct tls *ctx) +{ + return !(((struct fake_tls_s *)ctx)->state & TLS_EOF_NO_CLOSE_NOTIFY) ; +} + + /* We want tls_read/write to behave l */ + +static int tls_allwrite (struct tls *ctx, char const *s, size_t len, size_t *w) +{ + while (*w < len) { - r = tls_write(ctx, v[1].iov_base, v[1].iov_len) ; + ssize_t r = tls_write(ctx, s + *w, len - *w) ; switch (r) { - case TLS_WANT_POLLIN : - if (b[1].blockedonother) strerr_dief1x(101, "TLS deadlock") ; - b[0].blockedonother = 1 ; - case -1 : - case TLS_WANT_POLLOUT : - buffer_rseek(&b[0].b, w) ; - return 0 ; + case -1 : strerr_diefu2x(98, "tls_write: ", tls_error(ctx)) ; + case TLS_WANT_POLLIN : return 1 ; + case TLS_WANT_POLLOUT : return 0 ; default : break ; } - w += r ; + *w += r ; } - buffer_rseek(&b[0].b, w) ; - return 1 ; + return 0 ; } -static inline int buffer_tls_fill (struct tls *ctx, tlsbuf_t *b) +static void tls_flush (struct tls *ctx, stls_buffer *b) { struct iovec v[2] ; - ssize_t r, w ; - int ok = 1 ; - buffer_wpeek(&b[1].b, v) ; - r = tls_read(ctx, v[0].iov_base, v[0].iov_len) ; - switch (r) - { - case 0 : return -2 ; - case -1 : return -1 ; - case TLS_WANT_POLLOUT : - if (b[0].blockedonother) strerr_dief1x(101, "TLS deadlock") ; - b[1].blockedonother = 1 ; - case TLS_WANT_POLLIN : return 0 ; - default : break ; - } - w = r ; - if ((size_t)w == v[0].iov_len && v[1].iov_len) + size_t w = 0 ; + int r ; + buffer_rpeek(&b[0].b, v) ; + r = tls_allwrite(ctx, v[0].iov_base, v[0].iov_len, &w) ; + buffer_rseek(&b[0].b, w) ; + if (w < v[0].iov_len || !v[1].iov_len) goto out ; + w = 0 ; + r = tls_allwrite(ctx, v[1].iov_base, v[1].iov_len, &w) ; + buffer_rseek(&b[0].b, w) ; + out: + if (r) b[1].flags |= 1 ; else b[1].flags &= ~1 ; +} + +static int tls_allread (struct tls *ctx, char *s, size_t len, size_t *w) +{ + while (*w < len) { - r = tls_read(ctx, v[1].iov_base, v[1].iov_len) ; + ssize_t r = tls_read(ctx, s + *w, len - *w) ; switch (r) { - case TLS_WANT_POLLOUT : - if (b[0].blockedonother) strerr_dief1x(101, "TLS deadlock") ; - b[1].blockedonother = 1 ; - case -1 : - case TLS_WANT_POLLIN : - buffer_wseek(&b[1].b, w) ; - return 0 ; - case 0 : ok = -1 ; errno = EPIPE ; + case -1 : strerr_diefu2x(98, "tls_read: ", tls_error(ctx)) ; + case 0 : return -1 ; + case TLS_WANT_POLLIN : return 0 ; + case TLS_WANT_POLLOUT : return 1 ; default : break ; } - w += r ; + *w += r ; } - buffer_wseek(&b[1].b, w) ; - return ok ; + return 0 ; } -static void send_closenotify (struct tls *ctx, int const *fds) +static int tls_fill (struct tls *ctx, stls_buffer *b) { - iopause_fd x = { .fd = fds[3], .events = IOPAUSE_WRITE } ; - while (tls_close(ctx) == TLS_WANT_POLLOUT) - iopause_g(&x, 1, 0) ; + struct iovec v[2] ; + size_t w = 0 ; + int r ; + buffer_wpeek(&b[1].b, v) ; + r = tls_allread(ctx, v[0].iov_base, v[0].iov_len, &w) ; + buffer_wseek(&b[1].b, w) ; + if (w < v[0].iov_len || !v[1].iov_len) goto out ; + w = 0 ; + r = tls_allread(ctx, v[1].iov_base, v[1].iov_len, &w) ; + buffer_wseek(&b[1].b, w) ; + out: + if (r == -1) return 1 ; + if (r) b[0].flags |= 1 ; else b[0].flags &= ~1 ; + return 0 ; } -static void closeit (struct tls *ctx, int *fds, int closenotify) +static int tls_close_nb (struct tls *ctx, stls_buffer *b) { - if (!closenotify) fd_shutdown(fds[3], 1) ; - else if (fds[2] >= 0) send_closenotify(ctx, fds) ; - fd_close(fds[3]) ; fds[3] = -1 ; + switch (tls_close(ctx)) + { + case 0 : b[0].flags &= ~2 ; b[1].flags &= ~2 ; b[1].flags |= 4 ; return 1 ; + case TLS_WANT_POLLIN : b[0].flags &= ~2 ; b[1].flags |= 2 ; break ; + case TLS_WANT_POLLOUT : b[0].flags |= 2 ; b[1].flags &= ~2 ; break ; + default : strerr_diefu2x(98, "tls_close: ", tls_error(ctx)) ; + } + return 0 ; } -void stls_run (struct tls *ctx, int *fds, uint32_t options, unsigned int verbosity) + /* The engine. */ + +void stls_run (struct tls *ctx, int const *fds, uint32_t options, unsigned int verbosity) { - tlsbuf_t b[2] = { { .blockedonother = 0 }, { .blockedonother = 0 } } ; - iopause_fd x[4] ; - unsigned int xindex[4] ; - - if (ndelay_on(fds[0]) < 0 - || ndelay_on(fds[1]) < 0 - || ndelay_on(fds[2]) < 0 - || ndelay_on(fds[3]) < 0) + stls_buffer b[2] = + { + { .b = BUFFER_INIT(&buffer_read, fds[0], b[0].buf, STLS_BUFSIZE), .flags = 0 }, + { .b = BUFFER_INIT(&buffer_write, fds[1], b[1].buf, STLS_BUFSIZE), .flags = 0 }, + } ; + iopause_fd x[4] = { { .fd = fds[0] }, { .fd = fds[1] }, { .fd = fds[2] }, { .fd = fds[3] } } ; + + if (ndelay_on(x[0].fd) == -1 + || ndelay_on(x[1].fd) == -1 + || ndelay_on(x[2].fd) == -1 + || ndelay_on(x[3].fd) == -1) strerr_diefu1sys(111, "set fds non-blocking") ; - buffer_init(&b[0].b, &buffer_read, fds[0], b[0].buf, STLS_BUFSIZE) ; - buffer_init(&b[1].b, &buffer_write, fds[1], b[1].buf, STLS_BUFSIZE) ; - - for (;;) + while (x[0].fd >= 0 || x[1].fd >= 0 || x[3].fd >= 0) { - unsigned int j = 0 ; - int r ; - - - /* poll() preparation */ - - if (fds[0] >= 0 && buffer_isreadable(&b[0].b)) - { - x[j].fd = fds[0] ; - x[j].events = IOPAUSE_READ ; - xindex[0] = j++ ; - } - else xindex[0] = 4 ; - - if (fds[1] >= 0 && buffer_iswritable(&b[1].b)) - { - x[j].fd = fds[1] ; - x[j].events = IOPAUSE_WRITE ; - xindex[1] = j++ ; - } - else xindex[1] = 4 ; - - if (fds[2] >= 0 && !b[1].blockedonother && buffer_isreadable(&b[1].b)) - { - x[j].fd = fds[2] ; - x[j].events = IOPAUSE_READ ; - xindex[2] = j++ ; - } - else xindex[2] = 4 ; - - if (fds[3] >= 0 && !b[0].blockedonother && buffer_iswritable(&b[0].b)) - { - x[j].fd = fds[3] ; - x[j].events = IOPAUSE_WRITE ; - xindex[3] = j++ ; - } - else xindex[3] = 4 ; - - if (xindex[0] == 4 && xindex[1] == 4 && xindex[3] == 4) break ; + x[0].events = x[0].fd >= 0 && buffer_isreadable(&b[0].b) ? IOPAUSE_READ : 0 ; + x[1].events = x[1].fd >= 0 && buffer_iswritable(&b[1].b) ? IOPAUSE_WRITE : 0 ; + x[2].events = x[2].fd >= 0 && (buffer_isreadable(&b[1].b) || (b[1].flags & 1 && buffer_iswritable(&b[0].b))) ? IOPAUSE_READ : 0 ; + x[3].events = x[3].fd >= 0 && (buffer_iswritable(&b[0].b) || (b[0].flags & 1 && buffer_isreadable(&b[1].b))) ? IOPAUSE_WRITE : 0 ; - /* poll() */ - - r = iopause_g(x, j, 0) ; - if (r < 0) strerr_diefu1sys(111, "iopause") ; - else if (!r) break ; - - while (j--) - if (x[j].revents & IOPAUSE_EXCEPT) - x[j].revents |= IOPAUSE_READ | IOPAUSE_WRITE ; - + if (iopause_g(x, 4, 0) == -1) strerr_diefu1sys(111, "iopause") ; /* Flush to local */ - if (xindex[1] < 4 && x[xindex[1]].revents & IOPAUSE_WRITE) + if (x[1].revents) { - r = buffer_flush(&b[1].b) ; - if (!r && !error_isagain(errno)) + if (!buffer_flush(&b[1].b)) { - strerr_warnwu1sys("write to application") ; - if (fds[2] >= 0) - { - if (options & 1) fd_shutdown(fds[2], 0) ; - fd_close(fds[2]) ; fds[2] = -1 ; - xindex[2] = 4 ; - } - r = 1 ; + if (!error_isagain(errno)) strerr_diefu1sys(111, "write to local") ; } - if (r && fds[2] < 0) + else if (x[2].fd == -1) { - fd_close(fds[1]) ; fds[1] = -1 ; + fd_close(x[1].fd) ; + x[1].fd = -1 ; } } - /* Flush to remote */ + /* Flush to remote: do everything that had TLS_WANT_POLLOUT */ - if (xindex[3] < 4 && x[xindex[3]].revents & IOPAUSE_WRITE) + if (x[3].revents) { - r = buffer_tls_flush(ctx, b) ; - if (r < 0) + if (buffer_len(&b[0].b)) tls_flush(ctx, b) ; /* normal write */ + if ((b[0].flags & 1 && tls_fill(ctx, b)) /* peer sent close_notify and it just completed */ + || (b[0].flags & 2 && tls_close_nb(ctx, b))) /* we send close_notify and it instantly succeeds */ { - strerr_warnwu2x("write to peer: ", tls_error(ctx)) ; - fd_close(fds[0]) ; fds[0] = -1 ; - xindex[0] = 4 ; + if (buffer_isempty(&b[1].b)) break ; + fd_close(x[3].fd) ; x[3].fd = -1 ; + fd_close(x[2].fd) ; x[2].fd = -1 ; + if (x[0].fd >= 0) { fd_close(x[0].fd) ; x[0].fd = -1 ; } + continue ; + } + if (x[0].fd == -1 && buffer_isempty(&b[0].b)) + { + if (!(options & 1) || tls_close_nb(ctx, b)) + { + fd_shutdown(x[3].fd, 1) ; + fd_close(x[3].fd) ; + x[3].fd = -1 ; + } } - if (r && fds[0] < 0) - closeit(ctx, fds, options & 1) ; } /* Fill from local */ - if (xindex[0] < 4 && x[xindex[0]].revents & IOPAUSE_READ) + if (x[0].revents) { - r = sanitize_read(buffer_fill(&b[0].b)) ; - if (r < 0) + ssize_t r = buffer_fill(&b[0].b) ; + if (r == -1 && !error_isagain(errno)) + strerr_diefu1sys(111, "read from local") ; + else if (!r) { - if (errno != EPIPE) strerr_warnwu1sys("read from application") ; - fd_close(fds[0]) ; fds[0] = -1 ; + fd_close(x[0].fd) ; + x[0].fd = -1 ; if (buffer_isempty(&b[0].b)) - closeit(ctx, fds, options & 1) ; + { + if (!(options & 1) || tls_close_nb(ctx, b)) + { + fd_shutdown(x[3].fd, 1) ; + fd_close(x[3].fd) ; + x[3].fd = -1 ; + } + } } } - /* Fill from remote */ + /* Fill from remote: do everything that had TLS_WANT_POLLIN */ - if (xindex[2] < 4 && x[xindex[2]].revents & IOPAUSE_READ) + if (x[2].revents) { - r = buffer_tls_fill(ctx, b) ; - if (r < 0) - { - if (r == -1) strerr_warnwu2x("read from peer: ", tls_error(ctx)) ; - if (!(options & 1)) fd_shutdown(fds[2], 0) ; - /* - XXX: We need a way to detect when we've received a close_notify, - because then we need to trigger a write and then shut the engine - down. This is orthogonal to options&1, it only means that the - peer sent a close_notify. - As for now, libtls doesn't offer an API to detect that, so we - do nothing special - we just wait until our app sends EOF. - */ - fd_close(fds[2]) ; fds[2] = -1 ; + if (buffer_isreadable(&b[1].b) && tls_fill(ctx, b)) + { /* connection closed */ + fd_shutdown(x[2].fd, 0) ; + fd_close(x[2].fd) ; + x[2].fd = -1 ; if (buffer_isempty(&b[1].b)) { - fd_close(fds[1]) ; fds[1] = -1 ; + if (tls_eof_got_close_notify(ctx)) break ; + fd_close(x[1].fd) ; + x[1].fd = -1 ; + } + if (options & 2) + { + if (!tls_eof_got_close_notify(ctx)) + strerr_dief1x(98, "remote closed connection without a close_notify") ; + else if (x[3].fd >= 0) + { + fd_shutdown(x[3].fd, 1) ; + fd_close(x[3].fd) ; + x[3].fd = -1 ; + } + } + } + else + { /* normal case */ + if (b[1].flags & 1) tls_flush(ctx, b) ; + if (b[1].flags & 2 && tls_close_nb(ctx, b)) + { + if (buffer_isempty(&b[1].b)) break ; + if (x[3].fd >= 0) { fd_close(x[3].fd) ; x[3].fd = -1 ; } + if (x[0].fd >= 0) { fd_close(x[0].fd) ; x[0].fd = -1 ; } + fd_close(x[2].fd) ; x[2].fd = -1 ; } } } |