From 79e23317fff58c2c7eaace6df781bccc08775f28 Mon Sep 17 00:00:00 2001 From: Simon Rettberg Date: Wed, 31 Mar 2021 12:31:13 +0200 Subject: [KERNEL] Deduplicate code, clean up, split into functions --- src/kernel/net.c | 732 ++++++++++++++++++++++++++----------------------------- src/kernel/net.h | 9 - 2 files changed, 339 insertions(+), 402 deletions(-) diff --git a/src/kernel/net.c b/src/kernel/net.c index 3c19f8d..b07e8dc 100644 --- a/src/kernel/net.c +++ b/src/kernel/net.c @@ -39,19 +39,28 @@ #endif #ifdef CONFIG_DEBUG_DRIVER -#define ASSERT(x) \ - do { \ - if (!(x)) { \ - printk(KERN_EMERG "assertion failed %s: %d: %s\n", __FILE__, __LINE__, #x); \ - BUG(); \ - } \ +#define ASSERT(x) \ + do { \ + if (!(x)) { \ + printk(KERN_EMERG "assertion failed %s: %d: %s\n", __FILE__, __LINE__, #x); \ + BUG(); \ + } \ } while (0) #else -#define ASSERT(x) \ - do { \ +#define ASSERT(x) \ + do { \ } while (0) #endif +#define init_msghdr(h) \ + do { \ + h.msg_name = NULL; \ + h.msg_namelen = 0; \ + h.msg_control = NULL; \ + h.msg_controllen = 0; \ + h.msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \ + } while (0) + // cmd_flags and cmd_type are merged into cmd_flags now #if REQ_FLAG_BITS > 24 #error "Fix CMD bitshift" @@ -63,18 +72,23 @@ #define DNBD3_DEV_READ REQ_OP_READ #define DNBD3_REQ_OP_SPECIAL REQ_OP_DRV_IN +#define dnbd3_dev_dbg_host(dev, host, fmt, ...) \ + dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__) +#define dnbd3_dev_err_host(dev, host, fmt, ...) \ + dev_err(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__) + #define dnbd3_dev_dbg_host_cur(dev, fmt, ...) \ - dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->cur_server.host, ##__VA_ARGS__) + dnbd3_dev_dbg_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__) #define dnbd3_dev_err_host_cur(dev, fmt, ...) \ - dev_err(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->cur_server.host, ##__VA_ARGS__) - -#define dnbd3_dev_dbg_host_alt(dev, fmt, ...) \ - dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->alt_servers[i].host, ##__VA_ARGS__) -#define dnbd3_dev_err_host_alt(dev, fmt, ...) \ - dev_err(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->alt_servers[i].host, ##__VA_ARGS__) + dnbd3_dev_err_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__) static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr); +static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, + struct sockaddr_storage *addr, uint16_t *remote_version); + +int dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock); + static void dnbd3_net_heartbeat(struct timer_list *arg) { dnbd3_device_t *dev = (dnbd3_device_t *)container_of(arg, dnbd3_device_t, hb_timer); @@ -122,23 +136,10 @@ static void dnbd3_net_heartbeat(struct timer_list *arg) static int dnbd3_net_discover(void *data) { dnbd3_device_t *dev = data; - struct sockaddr_in sin4; - struct sockaddr_in6 sin6; struct socket *sock, *best_sock = NULL; - - dnbd3_request_t dnbd3_request; - dnbd3_reply_t dnbd3_reply; dnbd3_alt_server_t *alt; struct sockaddr_storage host_compare, best_server; - struct msghdr msg; - struct kvec iov[2]; - - char *buf, *name; - serialized_buffer_t *payload; - uint64_t filesize; - uint16_t rid; uint16_t remote_version; - ktime_t start = 0, end = 0; unsigned long rtt, best_rtt = 0; unsigned long irqflags; @@ -146,25 +147,9 @@ static int dnbd3_net_discover(void *data) int turn = 0; int ready = 0, do_change = 0; char check_order[NUMBER_SERVERS]; - int mlen; struct request *last_request = (struct request *)123, *cur_request = (struct request *)456; - memset(&sin4, 0, sizeof(sin4)); - memset(&sin6, 0, sizeof(sin6)); - - init_msghdr(msg); - - BUILD_BUG_ON(sizeof(serialized_buffer_t) > DNBD3_BLOCK_SIZE); - - buf = kmalloc(DNBD3_BLOCK_SIZE, GFP_KERNEL); - if (!buf) - return -ENOMEM; - - payload = (serialized_buffer_t *)buf; // Reuse this buffer to save kernel mem - - dnbd3_request.magic = dnbd3_packet_magic; - for (i = 0; i < NUMBER_SERVERS; ++i) check_order[i] = i; @@ -194,9 +179,9 @@ static int dnbd3_net_discover(void *data) for (i = 0; i < isize; ++i) { j = ((ktime_to_s(start) >> i) ^ (ktime_to_us(start) >> j)) % NUMBER_SERVERS; if (j != i) { - mlen = check_order[i]; + int tmp = check_order[i]; check_order[i] = check_order[j]; - check_order[j] = mlen; + check_order[j] = tmp; } } } @@ -220,93 +205,13 @@ static int dnbd3_net_discover(void *data) if (sock == NULL) goto error; - // Request filesize - dnbd3_request.cmd = CMD_SELECT_IMAGE; - iov[0].iov_base = &dnbd3_request; - iov[0].iov_len = sizeof(dnbd3_request); - serializer_reset_write(payload); - serializer_put_uint16(payload, PROTOCOL_VERSION); // DNBD3 protocol version - serializer_put_string(payload, dev->imgname); // image name - serializer_put_uint16(payload, dev->rid); // revision id - serializer_put_uint8(payload, 0); // are we a server? (no!) - iov[1].iov_base = payload; - dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(payload); - fixup_request(dnbd3_request); - mlen = iov[1].iov_len + sizeof(dnbd3_request); - if (kernel_sendmsg(sock, &msg, iov, 2, mlen) != mlen) { - dnbd3_dev_err_host_alt(dev, "requesting image size failed\n"); + if (!dnbd3_execute_handshake(dev, sock, &host_compare, &remote_version)) goto error; - } - // receive net reply - iov[0].iov_base = &dnbd3_reply; - iov[0].iov_len = sizeof(dnbd3_reply); - if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != - sizeof(dnbd3_reply)) { - dnbd3_dev_err_host_alt(dev, "receiving image size packet (header) failed (discover)\n"); - goto error; - } - fixup_reply(dnbd3_reply); - if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_SELECT_IMAGE || - dnbd3_reply.size < 4) { - dnbd3_dev_err_host_alt(dev, - "content of image size packet (header) mismatched (discover)\n"); - goto error; - } - - // receive data - iov[0].iov_base = payload; - iov[0].iov_len = dnbd3_reply.size; - if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) { - dnbd3_dev_err_host_alt(dev, - "receiving image size packet (payload) failed (discover)\n"); - goto error; - } - serializer_reset_read(payload, dnbd3_reply.size); - - remote_version = serializer_get_uint16(payload); - if (remote_version < MIN_SUPPORTED_SERVER) { - dnbd3_dev_err_host_alt( - dev, "server version too old (client: %d, server: %d, min supported: %d)\n", - (int)PROTOCOL_VERSION, (int)remote_version, - (int)MIN_SUPPORTED_SERVER); - goto error; - } - - name = serializer_get_string(payload); - if (name == NULL) { - dnbd3_dev_err_host_alt(dev, "server did not supply an image name (discover)\n"); - goto error; - } - - if (strcmp(name, dev->imgname) != 0) { - dnbd3_dev_err_host_alt( - dev, - "image name does not match requested one (client: '%s', server: '%s') (discover)\n", - dev->imgname, name); - goto error; - } - - rid = serializer_get_uint16(payload); - if (rid != dev->rid) { - dnbd3_dev_err_host_alt( - dev, "server supplied wrong rid (client: '%d', server: '%d') (discover)\n", - (int)dev->rid, (int)rid); - goto error; - } - - filesize = serializer_get_uint64(payload); - if (filesize != dev->reported_size) { - dnbd3_dev_err_host_alt( - dev, - "reported image size of %llu does not match expected value %llu (discover)\n", - (unsigned long long)filesize, (unsigned long long)dev->reported_size); - goto error; - } // panic mode, take first responding server if (dev->panic) { - dnbd3_dev_dbg_host_alt(dev, "panic mode, changing server ...\n"); + dnbd3_dev_dbg_host(dev, &host_compare, "panic mode, changing to new server\n"); while (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0) schedule(); @@ -317,7 +222,6 @@ static int dnbd3_net_discover(void *data) sock_release(best_sock); dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect(); - kfree(buf); put_task_struct(dev->thread_discover); dev->thread_discover = NULL; dnbd3_net_disconnect(dev); @@ -331,48 +235,11 @@ static int dnbd3_net_discover(void *data) atomic_set(&dev->connection_lock, 0); } - // Request block - dnbd3_request.cmd = CMD_GET_BLOCK; - // Do *NOT* pick a random block as it has proven to cause severe - // cache thrashing on the server - dnbd3_request.offset = 0; - dnbd3_request.size = RTT_BLOCK_SIZE; - fixup_request(dnbd3_request); - iov[0].iov_base = &dnbd3_request; - iov[0].iov_len = sizeof(dnbd3_request); - // start rtt measurement start = ktime_get_real(); - if (kernel_sendmsg(sock, &msg, iov, 1, sizeof(dnbd3_request)) <= 0) { - dnbd3_dev_err_host_alt(dev, "requesting test block failed (discover)\n"); - goto error; - } - - // receive net reply - iov[0].iov_base = &dnbd3_reply; - iov[0].iov_len = sizeof(dnbd3_reply); - if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != - sizeof(dnbd3_reply)) { - dnbd3_dev_err_host_alt(dev, "receiving test block header packet failed (discover)\n"); - goto error; - } - fixup_reply(dnbd3_reply); - if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_GET_BLOCK || - dnbd3_reply.size != RTT_BLOCK_SIZE) { - dnbd3_dev_err_host_alt( - dev, "unexpected reply to block request: cmd=%d, size=%d (discover)\n", - (int)dnbd3_reply.cmd, (int)dnbd3_reply.size); - goto error; - } - - // receive data - iov[0].iov_base = buf; - iov[0].iov_len = RTT_BLOCK_SIZE; - if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != RTT_BLOCK_SIZE) { - dnbd3_dev_err_host_alt(dev, "receiving test block payload failed (discover)\n"); + if (!dnbd3_request_test_block(dev, &host_compare, sock)) goto error; - } end = ktime_get_real(); // end rtt measurement @@ -481,7 +348,6 @@ error: dev_info(dnbd3_device_to_dev(dev), "server %pISpc is faster (%lluµs vs. %lluµs)\n", &best_server, (unsigned long long)best_rtt, (unsigned long long)dev->cur_server.rtt); - kfree(buf); dev->better_sock = best_sock; // Take shortcut by continuing to use open connection put_task_struct(dev->thread_discover); dev->thread_discover = NULL; @@ -508,7 +374,6 @@ error: ready = 1; } - kfree(buf); if (kthread_should_stop()) dev_dbg(dnbd3_device_to_dev(dev), "kthread %s terminated normally\n", __func__); else @@ -880,6 +745,7 @@ static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage ktime_t start; int ret, connect_time_ms; struct socket *sock; + int retries = 4; if (sock_create_kern(&init_net, addr->ss_family, SOCK_STREAM, IPPROTO_TCP, &sock) < 0) { dev_err(dnbd3_device_to_dev(dev), "couldn't create socket\n"); @@ -914,40 +780,296 @@ static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage connect_time_ms = SOCKET_TIMEOUT_CLIENT_DATA * 1000; set_socket_timeouts(sock, connect_time_ms); start = ktime_get_real(); - ret = kernel_connect(sock, (struct sockaddr *)addr, sizeof(*addr), 0); - connect_time_ms = (int)ktime_ms_delta(ktime_get_real(), start); - if (connect_time_ms > 2 * SOCKET_TIMEOUT_CLIENT_DATA * 1000) { - /* Either I'm losing my mind or there was a specific build of kernel - * 5.x where SO_RCVTIMEO didn't affect the connect call above, so - * this function would hang for over a minute for unreachable hosts. - * Leave in this debug check for twice the configured timeout - */ - dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect call took %dms\n", - addr, connect_time_ms); - } - if (ret != 0) { - dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect failed (%d, blocked %dms)\n", - addr, ret, connect_time_ms); - goto error; + while (--retries > 0) { + ret = kernel_connect(sock, (struct sockaddr *)addr, sizeof(*addr), 0); + connect_time_ms = (int)ktime_ms_delta(ktime_get_real(), start); + if (connect_time_ms > 2 * SOCKET_TIMEOUT_CLIENT_DATA * 1000) { + /* Either I'm losing my mind or there was a specific build of kernel + * 5.x where SO_RCVTIMEO didn't affect the connect call above, so + * this function would hang for over a minute for unreachable hosts. + * Leave in this debug check for twice the configured timeout + */ + dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect call took %dms\n", + addr, connect_time_ms); + } + if (ret != 0) { + if (ret == -EINTR) + continue; + dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect failed (%d, blocked %dms)\n", + addr, ret, connect_time_ms); + goto error; + } + return sock; } - return sock; error: sock_release(sock); return NULL; } +/** + * Execute protocol handshake on a newly connected socket. + * If this is the initial connection to any server, ie. we're being called + * through the initial ioctl() to open a device, we'll store the rid, filesize + * etc. in the dev struct., otherwise, this is a potential switch to another + * server, so we validate the filesize, rid, name against what we expect. + * The server's protocol version is returned in 'remote_version' + */ +static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, + struct sockaddr_storage *addr, uint16_t *remote_version) +{ + const char *name; + uint64_t filesize; + int mlen; + uint16_t rid, initial_connect; + struct msghdr msg; + struct kvec iov[2]; + serialized_buffer_t *payload; + dnbd3_reply_t dnbd3_reply; + dnbd3_request_t dnbd3_request = { .magic = dnbd3_packet_magic }; + + payload = kmalloc(sizeof(*payload), GFP_KERNEL); + if (payload == NULL) + goto error; + + initial_connect = (dev->reported_size == 0); + init_msghdr(msg); + // Request filesize + dnbd3_request.cmd = CMD_SELECT_IMAGE; + iov[0].iov_base = &dnbd3_request; + iov[0].iov_len = sizeof(dnbd3_request); + serializer_reset_write(payload); + serializer_put_uint16(payload, PROTOCOL_VERSION); // DNBD3 protocol version + serializer_put_string(payload, dev->imgname); // image name + serializer_put_uint16(payload, dev->rid); // revision id + serializer_put_uint8(payload, 0); // are we a server? (no!) + iov[1].iov_base = payload; + dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(payload); + fixup_request(dnbd3_request); + mlen = iov[0].iov_len + iov[1].iov_len; + if (kernel_sendmsg(sock, &msg, iov, 2, mlen) != mlen) { + dnbd3_dev_err_host(dev, addr, "requesting image size failed\n"); + goto error; + } + + // receive net reply + iov[0].iov_base = &dnbd3_reply; + iov[0].iov_len = sizeof(dnbd3_reply); + if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply)) { + dnbd3_dev_err_host(dev, addr, "receiving image size packet (header) failed\n"); + goto error; + } + fixup_reply(dnbd3_reply); + if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_SELECT_IMAGE || dnbd3_reply.size < 4) { + dnbd3_dev_err_host(dev, addr, + "corrupted CMD_SELECT_IMAGE reply\n"); + goto error; + } + + // receive data + iov[0].iov_base = payload; + iov[0].iov_len = dnbd3_reply.size; + if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) { + dnbd3_dev_err_host(dev, addr, + "receiving payload of CMD_SELECT_IMAGE reply failed\n"); + goto error; + } + serializer_reset_read(payload, dnbd3_reply.size); + + *remote_version = serializer_get_uint16(payload); + name = serializer_get_string(payload); + rid = serializer_get_uint16(payload); + filesize = serializer_get_uint64(payload); + + if (*remote_version < MIN_SUPPORTED_SERVER) { + dnbd3_dev_err_host(dev, addr, + "server version too old (client: %d, server: %d, min supported: %d)\n", + (int)PROTOCOL_VERSION, (int)*remote_version, + (int)MIN_SUPPORTED_SERVER); + goto error; + } + + if (name == NULL) { + dnbd3_dev_err_host(dev, addr, "server did not supply an image name\n"); + goto error; + } + if (rid == 0) { + dnbd3_dev_err_host(dev, addr, "server did not supply a revision id\n"); + goto error; + } + + /* only check image name if this isn't the initial connect */ + if (initial_connect && dev->rid != 0 && strcmp(name, dev->imgname) != 0) { + dnbd3_dev_err_host(dev, addr, "server offers image '%s', requested '%s'\n", name, dev->imgname); + goto error; + } + + if (initial_connect) { + if (filesize < DNBD3_BLOCK_SIZE) { + dnbd3_dev_err_host(dev, addr, "reported size by server is < 4096\n"); + goto error; + } + if (strlen(dev->imgname) < strlen(name)) { + dev->imgname = krealloc(dev->imgname, strlen(name) + 1, GFP_KERNEL); + if (dev->imgname == NULL) { + dnbd3_dev_err_host(dev, addr, "reallocating buffer for new image name failed\n"); + goto error; + } + } + strcpy(dev->imgname, name); + dev->rid = rid; + // store image information + dev->reported_size = filesize; + set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */ + dnbd3_dev_dbg_host(dev, addr, "image size: %llu\n", dev->reported_size); + dev->update_available = 0; + } else { + /* switching connection, sanity checks */ + if (rid != dev->rid) { + dnbd3_dev_err_host(dev, addr, + "server supplied wrong rid (client: '%d', server: '%d')\n", + (int)dev->rid, (int)rid); + goto error; + } + + if (filesize != dev->reported_size) { + dnbd3_dev_err_host(dev, addr, + "reported image size of %llu does not match expected value %llu\n", + (unsigned long long)filesize, (unsigned long long)dev->reported_size); + goto error; + } + } + kfree(payload); + return 1; + +error: + kfree(payload); + return 0; +} + +int dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock) +{ + dnbd3_request_t dnbd3_request = { .magic = dnbd3_packet_magic }; + dnbd3_reply_t dnbd3_reply; + struct kvec iov; + struct msghdr msg; + char *buf = NULL; + char smallbuf[256]; + int remaining, buffer_size, ret, func_return; + + init_msghdr(msg); + + func_return = 0; + // Request block + dnbd3_request.cmd = CMD_GET_BLOCK; + // Do *NOT* pick a random block as it has proven to cause severe + // cache thrashing on the server + dnbd3_request.offset = 0; + dnbd3_request.size = RTT_BLOCK_SIZE; + fixup_request(dnbd3_request); + iov.iov_base = &dnbd3_request; + iov.iov_len = sizeof(dnbd3_request); + + if (kernel_sendmsg(sock, &msg, &iov, 1, sizeof(dnbd3_request)) <= 0) { + dnbd3_dev_err_host(dev, addr, "requesting test block failed\n"); + goto error; + } + + // receive net reply + iov.iov_base = &dnbd3_reply; + iov.iov_len = sizeof(dnbd3_reply); + if (kernel_recvmsg(sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags) + != sizeof(dnbd3_reply)) { + dnbd3_dev_err_host(dev, addr, "receiving test block header packet failed\n"); + goto error; + } + fixup_reply(dnbd3_reply); + if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_GET_BLOCK + || dnbd3_reply.size != RTT_BLOCK_SIZE) { + dnbd3_dev_err_host(dev, addr, + "unexpected reply to block request: cmd=%d, size=%d (discover)\n", + (int)dnbd3_reply.cmd, (int)dnbd3_reply.size); + goto error; + } + + // receive data + buf = kmalloc(DNBD3_BLOCK_SIZE, GFP_NOWAIT); + if (buf == NULL) { + /* fallback to stack if we're really memory constrained */ + buf = smallbuf; + buffer_size = sizeof(smallbuf); + } else { + buffer_size = DNBD3_BLOCK_SIZE; + } + remaining = RTT_BLOCK_SIZE; + /* TODO in either case we could build a large iovec that points to the same buffer over and over again */ + while (remaining > 0) { + iov.iov_base = buf; + iov.iov_len = buffer_size; + ret = kernel_recvmsg(sock, &msg, &iov, 1, MIN(remaining, buffer_size), msg.msg_flags); + if (ret <= 0) { + dnbd3_dev_err_host(dev, addr, "receiving test block payload failed (ret=%d)\n", ret); + goto error; + } + remaining -= ret; + } + func_return = 1; + // Fallthrough! +error: + if (buf != smallbuf) + kfree(buf); + return func_return; +} + +static int spawn_worker_thread(dnbd3_device_t *dev, struct task_struct **task, const char *name, + int (*threadfn)(void *data)) +{ + ASSERT(*task == NULL); + *task = kthread_create(threadfn, dev, "%s-%s", dev->disk->disk_name, name); + if (!IS_ERR(*task)) { + get_task_struct(*task); + wake_up_process(*task); + return 1; + } + dev_err(dnbd3_device_to_dev(dev), "failed to create %s thread (%ld)\n", + name, PTR_ERR(*task)); + /* reset possible non-NULL error value */ + *task = NULL; + return 0; +} + +static void stop_worker_thread(dnbd3_device_t *dev, struct task_struct **task, const char *name, int quiet) +{ + int ret; + + if (*task == NULL) + return; + if (!quiet) + dnbd3_dev_dbg_host_cur(dev, "stop %s thread\n", name); + ret = kthread_stop(*task); + put_task_struct(*task); + if (ret == -EINTR) { + /* thread has never been scheduled and run */ + if (!quiet) + dev_dbg(dnbd3_device_to_dev(dev), "%s thread has never run\n", name); + } else { + /* thread has run, check if it has terminated successfully */ + if (ret < 0 && !quiet) + dev_err(dnbd3_device_to_dev(dev), "%s thread was not terminated correctly\n", name); + } + *task = NULL; +} + int dnbd3_net_connect(dnbd3_device_t *dev) { - struct request *req1 = NULL; + struct request *req_alt_servers = NULL; unsigned long irqflags; ASSERT(atomic_read(&dev->connection_lock)); - // do some checks before connecting - req1 = kmalloc(sizeof(*req1), GFP_ATOMIC); - if (!req1) { - dnbd3_dev_err_host_cur(dev, "kmalloc failed\n"); - goto error; + if (dev->use_server_provided_alts) { + req_alt_servers = kmalloc(sizeof(*req_alt_servers), GFP_KERNEL); + if (req_alt_servers == NULL) + dnbd3_dev_err_host_cur(dev, "Cannot allocate memory to request list of alt servers\n"); } if (dev->cur_server.host.ss_family == 0 || dev->imgname == NULL) { @@ -964,165 +1086,50 @@ int dnbd3_net_connect(dnbd3_device_t *dev) ASSERT(dev->thread_receive == NULL); ASSERT(dev->thread_discover == NULL); - dnbd3_dev_dbg_host_cur(dev, "connecting ...\n"); - if (dev->better_sock != NULL) { // Switching server, connection is already established and size request was executed - dnbd3_dev_dbg_host_cur(dev, "on-the-fly server change ...\n"); + dnbd3_dev_dbg_host_cur(dev, "on-the-fly server change\n"); dev->sock = dev->better_sock; dev->better_sock = NULL; } else { // no established connection yet from discovery thread, start new one - uint64_t reported_size; - dnbd3_request_t dnbd3_request; - dnbd3_reply_t dnbd3_reply; - struct msghdr msg; - struct kvec iov[2]; - uint16_t rid, proto_version; - char *name; - int mlen; - - init_msghdr(msg); + uint16_t proto_version; + dnbd3_dev_dbg_host_cur(dev, "connecting\n"); dev->sock = dnbd3_connect(dev, &dev->cur_server.host); if (dev->sock == NULL) { dnbd3_dev_err_host_cur(dev, "%s: Failed\n", __func__); goto error; } - - // Request filesize - dnbd3_request.magic = dnbd3_packet_magic; - dnbd3_request.cmd = CMD_SELECT_IMAGE; - iov[0].iov_base = &dnbd3_request; - iov[0].iov_len = sizeof(dnbd3_request); - serializer_reset_write(&dev->payload_buffer); - serializer_put_uint16(&dev->payload_buffer, PROTOCOL_VERSION); - serializer_put_string(&dev->payload_buffer, dev->imgname); - serializer_put_uint16(&dev->payload_buffer, dev->rid); - serializer_put_uint8(&dev->payload_buffer, 0); // is_server = false - iov[1].iov_base = &dev->payload_buffer; - dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(&dev->payload_buffer); - fixup_request(dnbd3_request); - mlen = sizeof(dnbd3_request) + iov[1].iov_len; - if (kernel_sendmsg(dev->sock, &msg, iov, 2, mlen) != mlen) { - dnbd3_dev_err_host_cur(dev, "couldn't send CMD_SELECT_IMAGE\n"); - goto error; - } - // receive reply header - iov[0].iov_base = &dnbd3_reply; - iov[0].iov_len = sizeof(dnbd3_reply); - if (kernel_recvmsg(dev->sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != - sizeof(dnbd3_reply)) { - dnbd3_dev_err_host_cur(dev, "received corrupted reply header after CMD_SELECT_IMAGE\n"); + /* execute the "select image" handshake */ + if (!dnbd3_execute_handshake(dev, dev->sock, &dev->cur_server.host, &proto_version)) goto error; - } - // check reply header - fixup_reply(dnbd3_reply); - if (dnbd3_reply.cmd != CMD_SELECT_IMAGE || dnbd3_reply.size < 3 || dnbd3_reply.size > MAX_PAYLOAD || - dnbd3_reply.magic != dnbd3_packet_magic) { - dnbd3_dev_err_host_cur( - dev, "received invalid reply to CMD_SELECT_IMAGE, image doesn't exist on server\n"); - goto error; - } - // receive reply payload - iov[0].iov_base = &dev->payload_buffer; - iov[0].iov_len = dnbd3_reply.size; - if (kernel_recvmsg(dev->sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) { - dnbd3_dev_err_host_cur(dev, "cold not read CMD_SELECT_IMAGE payload on handshake\n"); - goto error; - } - // handle/check reply payload - serializer_reset_read(&dev->payload_buffer, dnbd3_reply.size); - proto_version = serializer_get_uint16(&dev->payload_buffer); + spin_lock_irqsave(&dev->blk_lock, irqflags); dev->cur_server.protocol_version = proto_version; spin_unlock_irqrestore(&dev->blk_lock, irqflags); - if (proto_version < MIN_SUPPORTED_SERVER) { - dnbd3_dev_err_host_cur(dev, "server version is lower than min supported version\n"); - goto error; - } - name = serializer_get_string(&dev->payload_buffer); - if (dev->rid != 0 && strcmp(name, dev->imgname) != 0) { - dnbd3_dev_err_host_cur(dev, "server offers image '%s', requested '%s'\n", name, dev->imgname); - goto error; - } - if (strlen(dev->imgname) < strlen(name)) { - dev->imgname = krealloc(dev->imgname, strlen(name) + 1, GFP_ATOMIC); - if (dev->imgname == NULL) { - dnbd3_dev_err_host_cur(dev, "reallocating buffer for new image name failed\n"); - goto error; - } - } - strcpy(dev->imgname, name); - rid = serializer_get_uint16(&dev->payload_buffer); - if (dev->rid != 0 && dev->rid != rid) { - dnbd3_dev_err_host_cur(dev, "server provides rid %d, requested was %d\n", (int)rid, - (int)dev->rid); - goto error; - } - dev->rid = rid; - reported_size = serializer_get_uint64(&dev->payload_buffer); - if (reported_size < 4096) { - dnbd3_dev_err_host_cur(dev, "reported size by server is < 4096\n"); - goto error; - } - if (dev->reported_size != 0 && dev->reported_size != reported_size) { - dnbd3_dev_err_host_cur(dev, "newly connected server reports size %llu, but expected is %llu\n", - reported_size, dev->reported_size); - goto error; - } else if (dev->reported_size == 0) { - // store image information - dev->reported_size = reported_size; - set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */ - dnbd3_dev_dbg_host_cur(dev, "image size: %llu\n", dev->reported_size); - dev->update_available = 0; - } } - // create required threads - dev->thread_send = kthread_create(dnbd3_net_send, dev, "%s-send", dev->disk->disk_name); - if (!IS_ERR(dev->thread_send)) { - get_task_struct(dev->thread_send); - wake_up_process(dev->thread_send); - } else { - dev_err(dnbd3_device_to_dev(dev), "failed to create send thread\n"); - /* reset error to cleanup thread */ - dev->thread_send = NULL; + /* create required threads */ + if (!spawn_worker_thread(dev, &dev->thread_send, "send", dnbd3_net_send)) goto error; - } - - dev->thread_receive = kthread_create(dnbd3_net_receive, dev, "%s-receive", dev->disk->disk_name); - if (!IS_ERR(dev->thread_receive)) { - get_task_struct(dev->thread_receive); - wake_up_process(dev->thread_receive); - } else { - dev_err(dnbd3_device_to_dev(dev), "failed to create receive thread\n"); - /* reset error to cleanup thread */ - dev->thread_receive = NULL; + if (!spawn_worker_thread(dev, &dev->thread_receive, "receive", dnbd3_net_receive)) goto error; - } - - dev->thread_discover = kthread_create(dnbd3_net_discover, dev, "%s-discover", dev->disk->disk_name); - if (!IS_ERR(dev->thread_discover)) { - get_task_struct(dev->thread_discover); - wake_up_process(dev->thread_discover); - } else { - dev_err(dnbd3_device_to_dev(dev), "failed to create discover thread\n"); - /* reset error to cleanup thread */ - dev->thread_discover = NULL; + if (!spawn_worker_thread(dev, &dev->thread_discover, "discover", dnbd3_net_discover)) goto error; - } + dnbd3_dev_dbg_host_cur(dev, "connection established\n"); dev->panic = 0; dev->panic_count = 0; - // Enqueue request to request_queue_send for a fresh list of alt servers - dnbd3_cmd_to_priv(req1, CMD_GET_SERVERS); - spin_lock_irqsave(&dev->blk_lock, irqflags); - list_add(&req1->queuelist, &dev->request_queue_send); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - - wake_up(&dev->process_queue_send); + if (req_alt_servers != NULL) { + // Enqueue request to request_queue_send for a fresh list of alt servers + dnbd3_cmd_to_priv(req_alt_servers, CMD_GET_SERVERS); + spin_lock_irqsave(&dev->blk_lock, irqflags); + list_add(&req_alt_servers->queuelist, &dev->request_queue_send); + spin_unlock_irqrestore(&dev->blk_lock, irqflags); + wake_up(&dev->process_queue_send); + } // add heartbeat timer // Do not goto error after creating the timer - we require that the timer exists @@ -1135,21 +1142,9 @@ int dnbd3_net_connect(dnbd3_device_t *dev) return 0; error: - if (dev->thread_send) { - kthread_stop(dev->thread_send); - put_task_struct(dev->thread_send); - dev->thread_send = NULL; - } - if (dev->thread_receive) { - kthread_stop(dev->thread_receive); - put_task_struct(dev->thread_receive); - dev->thread_receive = NULL; - } - if (dev->thread_discover) { - kthread_stop(dev->thread_discover); - put_task_struct(dev->thread_discover); - dev->thread_discover = NULL; - } + stop_worker_thread(dev, &dev->thread_send, "send", 1); + stop_worker_thread(dev, &dev->thread_receive, "receive", 1); + stop_worker_thread(dev, &dev->thread_discover, "discover", 1); if (dev->sock) { sock_release(dev->sock); dev->sock = NULL; @@ -1157,16 +1152,14 @@ error: spin_lock_irqsave(&dev->blk_lock, irqflags); dev->cur_server.host.ss_family = 0; spin_unlock_irqrestore(&dev->blk_lock, irqflags); - kfree(req1); + kfree(req_alt_servers); return -1; } int dnbd3_net_disconnect(dnbd3_device_t *dev) { - struct task_struct *thread = NULL; unsigned long irqflags; - int ret; dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n"); ASSERT(atomic_read(&dev->connection_lock)); @@ -1180,56 +1173,9 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) } // kill sending and receiving threads - if (dev->thread_send) { - dnbd3_dev_dbg_host_cur(dev, "stop send thread\n"); - thread = dev->thread_send; - ret = kthread_stop(thread); - put_task_struct(thread); - if (ret == -EINTR) { - /* thread has never been scheduled and run */ - dev_dbg(dnbd3_device_to_dev(dev), "send thread has never run\n"); - } else { - /* thread has run, check if it has terminated successfully */ - if (ret < 0) - dev_err(dnbd3_device_to_dev(dev), "send thread was not terminated correctly\n"); - } - dev->thread_send = NULL; - } - - if (dev->thread_receive) { - dnbd3_dev_dbg_host_cur(dev, "stop receive thread\n"); - thread = dev->thread_receive; - ret = kthread_stop(thread); - put_task_struct(thread); - if (ret == -EINTR) { - /* thread has never been scheduled and run */ - dev_dbg(dnbd3_device_to_dev(dev), "receive thread has never run\n"); - } else { - /* thread has run, check if it has terminated successfully */ - if (ret < 0) - dev_err(dnbd3_device_to_dev(dev), "receive thread was not terminated correctly\n"); - } - dev->thread_receive = NULL; - } - - if (dev->thread_discover) { - dnbd3_dev_dbg_host_cur(dev, "stop discover thread\n"); - thread = dev->thread_discover; - ret = kthread_stop(thread); - put_task_struct(thread); - if (ret == -EINTR) { - /* thread has never been scheduled and run */ - dev_dbg(dnbd3_device_to_dev(dev), "discover thread has never run\n"); - } else { - /* thread has run, check if it has terminated successfully */ - if (ret < 0) { - dev_err(dnbd3_device_to_dev(dev), "discover thread was not terminated correctly (%d)\n", - ret); - } - } - dev->thread_discover = NULL; - } - + stop_worker_thread(dev, &dev->thread_send, "send", 0); + stop_worker_thread(dev, &dev->thread_receive, "receive", 0); + stop_worker_thread(dev, &dev->thread_discover, "discover", 0); if (dev->sock) { sock_release(dev->sock); dev->sock = NULL; diff --git a/src/kernel/net.h b/src/kernel/net.h index d46505b..f91334e 100644 --- a/src/kernel/net.h +++ b/src/kernel/net.h @@ -24,15 +24,6 @@ #include "dnbd3_main.h" -#define init_msghdr(h) \ - do { \ - h.msg_name = NULL; \ - h.msg_namelen = 0; \ - h.msg_control = NULL; \ - h.msg_controllen = 0; \ - h.msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \ - } while (0) - int dnbd3_net_connect(dnbd3_device_t *lo); int dnbd3_net_disconnect(dnbd3_device_t *lo); -- cgit v1.2.3-55-g7522