Use a tailq of write buffers instead of a single one per connection.

This allows us to queue up multiple messages for writing like the
sudoers client supports.  Currently, each connection has its own
free list.  In the future we may want a single free list with low
and high water marks.
This commit is contained in:
Todd C. Miller
2021-04-06 14:30:16 -06:00
parent e3ff4e663c
commit 7bb5eef9d9
3 changed files with 76 additions and 28 deletions

View File

@@ -27,11 +27,13 @@
#define MESSAGE_SIZE_MAX (2 * 1024 * 1024) #define MESSAGE_SIZE_MAX (2 * 1024 * 1024)
struct connection_buffer { struct connection_buffer {
TAILQ_ENTRY(connection_buffer) entries;
uint8_t *data; uint8_t *data;
unsigned int size; unsigned int size;
unsigned int len; unsigned int len;
unsigned int off; unsigned int off;
}; };
TAILQ_HEAD(connection_buffer_list, connection_buffer);
/* logsrv_util.c */ /* logsrv_util.c */
struct iolog_file; struct iolog_file;

View File

@@ -108,6 +108,7 @@ connection_closure_free(struct connection_closure *closure)
if (closure != NULL) { if (closure != NULL) {
bool shutting_down = closure->state == SHUTDOWN; bool shutting_down = closure->state == SHUTDOWN;
struct sudo_event_base *evbase = closure->evbase; struct sudo_event_base *evbase = closure->evbase;
struct connection_buffer *buf;
TAILQ_REMOVE(&connections, closure, entries); TAILQ_REMOVE(&connections, closure, entries);
#if defined(HAVE_OPENSSL) #if defined(HAVE_OPENSSL)
@@ -126,7 +127,16 @@ connection_closure_free(struct connection_closure *closure)
#endif #endif
eventlog_free(closure->evlog); eventlog_free(closure->evlog);
free(closure->read_buf.data); free(closure->read_buf.data);
free(closure->write_buf.data); while ((buf = TAILQ_FIRST(&closure->write_bufs)) != NULL) {
TAILQ_REMOVE(&closure->write_bufs, buf, entries);
free(buf->data);
free(buf);
}
while ((buf = TAILQ_FIRST(&closure->free_bufs)) != NULL) {
TAILQ_REMOVE(&closure->free_bufs, buf, entries);
free(buf->data);
free(buf);
}
free(closure); free(closure);
if (shutting_down && TAILQ_EMPTY(&connections)) if (shutting_down && TAILQ_EMPTY(&connections))
@@ -136,18 +146,34 @@ connection_closure_free(struct connection_closure *closure)
debug_return; debug_return;
} }
static bool static struct connection_buffer *
fmt_server_message(struct connection_buffer *buf, ServerMessage *msg) get_free_buf(struct connection_closure *closure)
{ {
struct connection_buffer *buf;
debug_decl(get_free_buf, SUDO_DEBUG_UTIL);
buf = TAILQ_FIRST(&closure->free_bufs);
if (buf != NULL)
TAILQ_REMOVE(&closure->free_bufs, buf, entries);
else
buf = calloc(1, sizeof(*buf));
debug_return_ptr(buf);
}
static bool
fmt_server_message(struct connection_closure *closure, ServerMessage *msg)
{
struct connection_buffer *buf;
uint32_t msg_len; uint32_t msg_len;
bool ret = false; bool ret = false;
size_t len; size_t len;
debug_decl(fmt_server_message, SUDO_DEBUG_UTIL); debug_decl(fmt_server_message, SUDO_DEBUG_UTIL);
if (buf->len != 0) { if ((buf = get_free_buf(closure)) == NULL) {
sudo_debug_printf(SUDO_DEBUG_ERROR|SUDO_DEBUG_LINENO, sudo_debug_printf(SUDO_DEBUG_ERROR|SUDO_DEBUG_LINENO,
"pending write, unable to format ServerMessage"); "unable to allocate connection_buffer");
debug_return_bool(false); goto done;
} }
len = server_message__get_packed_size(msg); len = server_message__get_packed_size(msg);
@@ -178,14 +204,21 @@ fmt_server_message(struct connection_buffer *buf, ServerMessage *msg)
memcpy(buf->data, &msg_len, sizeof(msg_len)); memcpy(buf->data, &msg_len, sizeof(msg_len));
server_message__pack(msg, buf->data + sizeof(msg_len)); server_message__pack(msg, buf->data + sizeof(msg_len));
buf->len = len; buf->len = len;
TAILQ_INSERT_TAIL(&closure->write_bufs, buf, entries);
buf = NULL;
ret = true; ret = true;
done: done:
if (buf != NULL) {
free(buf->data);
free(buf);
}
debug_return_bool(ret); debug_return_bool(ret);
} }
static bool static bool
fmt_hello_message(struct connection_buffer *buf) fmt_hello_message(struct connection_closure *closure)
{ {
ServerMessage msg = SERVER_MESSAGE__INIT; ServerMessage msg = SERVER_MESSAGE__INIT;
ServerHello hello = SERVER_HELLO__INIT; ServerHello hello = SERVER_HELLO__INIT;
@@ -196,11 +229,11 @@ fmt_hello_message(struct connection_buffer *buf)
msg.u.hello = &hello; msg.u.hello = &hello;
msg.type_case = SERVER_MESSAGE__TYPE_HELLO; msg.type_case = SERVER_MESSAGE__TYPE_HELLO;
debug_return_bool(fmt_server_message(buf, &msg)); debug_return_bool(fmt_server_message(closure, &msg));
} }
static bool static bool
fmt_log_id_message(const char *id, struct connection_buffer *buf) fmt_log_id_message(const char *id, struct connection_closure *closure)
{ {
ServerMessage msg = SERVER_MESSAGE__INIT; ServerMessage msg = SERVER_MESSAGE__INIT;
debug_decl(fmt_log_id_message, SUDO_DEBUG_UTIL); debug_decl(fmt_log_id_message, SUDO_DEBUG_UTIL);
@@ -208,11 +241,11 @@ fmt_log_id_message(const char *id, struct connection_buffer *buf)
msg.u.log_id = (char *)id; msg.u.log_id = (char *)id;
msg.type_case = SERVER_MESSAGE__TYPE_LOG_ID; msg.type_case = SERVER_MESSAGE__TYPE_LOG_ID;
debug_return_bool(fmt_server_message(buf, &msg)); debug_return_bool(fmt_server_message(closure, &msg));
} }
static bool static bool
fmt_error_message(const char *errstr, struct connection_buffer *buf) fmt_error_message(const char *errstr, struct connection_closure *closure)
{ {
ServerMessage msg = SERVER_MESSAGE__INIT; ServerMessage msg = SERVER_MESSAGE__INIT;
debug_decl(fmt_error_message, SUDO_DEBUG_UTIL); debug_decl(fmt_error_message, SUDO_DEBUG_UTIL);
@@ -220,7 +253,7 @@ fmt_error_message(const char *errstr, struct connection_buffer *buf)
msg.u.error = (char *)errstr; msg.u.error = (char *)errstr;
msg.type_case = SERVER_MESSAGE__TYPE_ERROR; msg.type_case = SERVER_MESSAGE__TYPE_ERROR;
debug_return_bool(fmt_server_message(buf, &msg)); debug_return_bool(fmt_server_message(closure, &msg));
} }
struct logsrvd_info_closure { struct logsrvd_info_closure {
@@ -328,7 +361,7 @@ handle_accept(AcceptMessage *msg, struct connection_closure *closure)
if (msg->expect_iobufs) { if (msg->expect_iobufs) {
/* Send log ID to client for restarting connections. */ /* Send log ID to client for restarting connections. */
if (!fmt_log_id_message(closure->evlog->iolog_path, &closure->write_buf)) if (!fmt_log_id_message(closure->evlog->iolog_path, closure))
debug_return_bool(false); debug_return_bool(false);
if (sudo_ev_add(closure->evbase, closure->write_ev, if (sudo_ev_add(closure->evbase, closure->write_ev,
logsrvd_conf_get_sock_timeout(), false) == -1) { logsrvd_conf_get_sock_timeout(), false) == -1) {
@@ -459,7 +492,7 @@ handle_restart(RestartMessage *msg, struct connection_closure *closure)
if (!iolog_restart(msg, closure)) { if (!iolog_restart(msg, closure)) {
sudo_debug_printf(SUDO_DEBUG_WARN, "%s: unable to restart I/O log", __func__); sudo_debug_printf(SUDO_DEBUG_WARN, "%s: unable to restart I/O log", __func__);
/* XXX - structured error message so client can send from beginning */ /* XXX - structured error message so client can send from beginning */
if (!fmt_error_message(closure->errstr, &closure->write_buf)) if (!fmt_error_message(closure->errstr, closure))
debug_return_bool(false); debug_return_bool(false);
sudo_ev_del(closure->evbase, closure->read_ev); sudo_ev_del(closure->evbase, closure->read_ev);
if (sudo_ev_add(closure->evbase, closure->write_ev, if (sudo_ev_add(closure->evbase, closure->write_ev,
@@ -784,7 +817,7 @@ static void
server_msg_cb(int fd, int what, void *v) server_msg_cb(int fd, int what, void *v)
{ {
struct connection_closure *closure = v; struct connection_closure *closure = v;
struct connection_buffer *buf = &closure->write_buf; struct connection_buffer *buf;
ssize_t nwritten; ssize_t nwritten;
debug_decl(server_msg_cb, SUDO_DEBUG_UTIL); debug_decl(server_msg_cb, SUDO_DEBUG_UTIL);
@@ -802,12 +835,18 @@ server_msg_cb(int fd, int what, void *v)
if (what == SUDO_EV_TIMEOUT) { if (what == SUDO_EV_TIMEOUT) {
sudo_debug_printf(SUDO_DEBUG_ERROR|SUDO_DEBUG_LINENO, sudo_debug_printf(SUDO_DEBUG_ERROR|SUDO_DEBUG_LINENO,
"Writing to client timed out"); "timed out writing to client (%s)", closure->ipaddr);
goto finished; goto finished;
} }
sudo_debug_printf(SUDO_DEBUG_INFO, "%s: sending %u bytes to client", if ((buf = TAILQ_FIRST(&closure->write_bufs)) == NULL) {
__func__, buf->len - buf->off); sudo_debug_printf(SUDO_DEBUG_ERROR|SUDO_DEBUG_LINENO,
"missing write buffer");
goto finished;
}
sudo_debug_printf(SUDO_DEBUG_INFO, "%s: sending %u bytes to client (%s)",
__func__, buf->len - buf->off, closure->ipaddr);
#if defined(HAVE_OPENSSL) #if defined(HAVE_OPENSSL)
if (closure->ssl != NULL) { if (closure->ssl != NULL) {
@@ -854,15 +893,20 @@ server_msg_cb(int fd, int what, void *v)
buf->off += nwritten; buf->off += nwritten;
if (buf->off == buf->len) { if (buf->off == buf->len) {
/* sent entire message */ /* sent entire message, move buf to free list */
sudo_debug_printf(SUDO_DEBUG_INFO, sudo_debug_printf(SUDO_DEBUG_INFO,
"%s: finished sending %u bytes to client", __func__, buf->len); "%s: finished sending %u bytes to client", __func__, buf->len);
buf->off = 0; buf->off = 0;
buf->len = 0; buf->len = 0;
sudo_ev_del(closure->evbase, closure->write_ev); TAILQ_REMOVE(&closure->write_bufs, buf, entries);
if (closure->state == FINISHED || closure->state == SHUTDOWN || TAILQ_INSERT_TAIL(&closure->free_bufs, buf, entries);
closure->state == ERROR) if (TAILQ_EMPTY(&closure->write_bufs)) {
goto finished; /* Write queue empty, check state. */
sudo_ev_del(closure->evbase, closure->write_ev);
if (closure->state == FINISHED || closure->state == SHUTDOWN ||
closure->state == ERROR)
goto finished;
}
} }
debug_return; debug_return;
@@ -1007,7 +1051,7 @@ client_msg_cb(int fd, int what, void *v)
send_error: send_error:
if (closure->errstr == NULL) if (closure->errstr == NULL)
goto finished; goto finished;
if (fmt_error_message(closure->errstr, &closure->write_buf)) { if (fmt_error_message(closure->errstr, closure)) {
sudo_ev_del(closure->evbase, closure->read_ev); sudo_ev_del(closure->evbase, closure->read_ev);
if (sudo_ev_add(closure->evbase, closure->write_ev, if (sudo_ev_add(closure->evbase, closure->write_ev,
logsrvd_conf_get_sock_timeout(), false) == -1) { logsrvd_conf_get_sock_timeout(), false) == -1) {
@@ -1042,8 +1086,7 @@ server_commit_cb(int unused, int what, void *v)
__func__, (long long)closure->elapsed_time.tv_sec, __func__, (long long)closure->elapsed_time.tv_sec,
closure->elapsed_time.tv_nsec); closure->elapsed_time.tv_nsec);
/* XXX - assumes no other server message pending, use a queue instead? */ if (!fmt_server_message(closure, &msg)) {
if (!fmt_server_message(&closure->write_buf, &msg)) {
sudo_debug_printf(SUDO_DEBUG_ERROR|SUDO_DEBUG_LINENO, sudo_debug_printf(SUDO_DEBUG_ERROR|SUDO_DEBUG_LINENO,
"unable to format ServerMessage (commit point)"); "unable to format ServerMessage (commit point)");
goto bad; goto bad;
@@ -1074,7 +1117,7 @@ start_protocol(struct connection_closure *closure)
const struct timespec *timeout = logsrvd_conf_get_sock_timeout(); const struct timespec *timeout = logsrvd_conf_get_sock_timeout();
debug_decl(start_protocol, SUDO_DEBUG_UTIL); debug_decl(start_protocol, SUDO_DEBUG_UTIL);
if (!fmt_hello_message(&closure->write_buf)) if (!fmt_hello_message(closure))
debug_return_bool(false); debug_return_bool(false);
if (sudo_ev_add(closure->evbase, closure->write_ev, timeout, false) == -1) if (sudo_ev_add(closure->evbase, closure->write_ev, timeout, false) == -1)
@@ -1506,6 +1549,8 @@ connection_closure_alloc(int sock, bool tls, struct sudo_event_base *base)
closure->iolog_dir_fd = -1; closure->iolog_dir_fd = -1;
closure->sock = sock; closure->sock = sock;
closure->evbase = base; closure->evbase = base;
TAILQ_INIT(&closure->write_bufs);
TAILQ_INIT(&closure->free_bufs);
TAILQ_INSERT_TAIL(&connections, closure, entries); TAILQ_INSERT_TAIL(&connections, closure, entries);

View File

@@ -61,7 +61,8 @@ struct connection_closure {
struct eventlog *evlog; struct eventlog *evlog;
struct timespec elapsed_time; struct timespec elapsed_time;
struct connection_buffer read_buf; struct connection_buffer read_buf;
struct connection_buffer write_buf; struct connection_buffer_list write_bufs;
struct connection_buffer_list free_bufs;
struct sudo_event_base *evbase; struct sudo_event_base *evbase;
struct sudo_event *commit_ev; struct sudo_event *commit_ev;
struct sudo_event *read_ev; struct sudo_event *read_ev;