diff options
Diffstat (limited to 'src/kernel/net.c')
-rw-r--r-- | src/kernel/net.c | 1384 |
1 files changed, 646 insertions, 738 deletions
diff --git a/src/kernel/net.c b/src/kernel/net.c index 5919832..f5806de 100644 --- a/src/kernel/net.c +++ b/src/kernel/net.c @@ -22,7 +22,6 @@ #include <dnbd3/config/client.h> #include "net.h" #include "blk.h" -#include "utils.h" #include "dnbd3_main.h" #include <dnbd3/shared/serialize.h> @@ -30,7 +29,6 @@ #include <linux/time.h> #include <linux/ktime.h> #include <linux/tcp.h> -#include <linux/sched/task.h> #ifndef MIN #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -40,7 +38,7 @@ #define ktime_to_s(kt) ktime_divns(kt, NSEC_PER_SEC) #endif -#ifdef CONFIG_DEBUG_DRIVER +#ifdef DEBUG #define ASSERT(x) \ do { \ if (!(x)) { \ @@ -54,15 +52,6 @@ } while (0) #endif -#define init_msghdr(h) \ - do { \ - h.msg_name = NULL; \ - h.msg_namelen = 0; \ - h.msg_control = NULL; \ - h.msg_controllen = 0; \ - h.msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \ - } while (0) - #define dnbd3_dev_dbg_host(dev, host, fmt, ...) \ dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__) #define dnbd3_dev_err_host(dev, host, fmt, ...) \ @@ -73,219 +62,267 @@ #define dnbd3_dev_err_host_cur(dev, fmt, ...) \ dnbd3_dev_err_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__) -static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr); +static bool dnbd3_drain_socket(dnbd3_device_t *dev, struct socket *sock, int bytes); +static int dnbd3_recv_bytes(struct socket *sock, void *buffer, size_t count); +static int dnbd3_recv_reply(struct socket *sock, dnbd3_reply_t *reply_hdr); +static bool dnbd3_send_request(struct socket *sock, u16 cmd, u64 handle, u64 offset, u32 size); + +static int dnbd3_set_primary_connection(dnbd3_device_t *dev, struct socket *sock, + struct sockaddr_storage *addr, u16 protocol_version); + +static int dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr, + struct socket **sock_out); + +static bool dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, + struct sockaddr_storage *addr, uint16_t *remote_version, bool copy_image_info); + +static bool dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, + struct socket *sock); + +static bool dnbd3_send_empty_request(dnbd3_device_t *dev, u16 cmd); + +static void dnbd3_start_discover(dnbd3_device_t *dev, bool panic); + +static void dnbd3_discover(dnbd3_device_t *dev); + +static void dnbd3_internal_discover(dnbd3_device_t *dev); -static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, - struct sockaddr_storage *addr, uint16_t *remote_version); +static void set_socket_timeout(struct socket *sock, bool set_send, int timeout_ms); -static int dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock); +// Use as write-only dump, don't care about race conditions etc. +static u8 __garbage_mem[PAGE_SIZE]; -static void dnbd3_net_heartbeat(struct timer_list *arg) +/** + * Delayed work triggering sending of keepalive packet. + */ +static void dnbd3_keepalive_workfn(struct work_struct *work) { - dnbd3_device_t *dev = (dnbd3_device_t *)container_of(arg, dnbd3_device_t, hb_timer); - - // Because different events need different intervals, the timer is called once a second. - // Other intervals can be derived using dev->heartbeat_count. -#define timeout_seconds(x) (dev->heartbeat_count % (x) == 0) - - if (!dev->panic) { - if (timeout_seconds(TIMER_INTERVAL_KEEPALIVE_PACKET)) { - struct request *req = kmalloc(sizeof(struct request), GFP_ATOMIC); - // send keepalive - if (req) { - unsigned long irqflags; - - dnbd3_cmd_to_priv(req, CMD_KEEPALIVE); - spin_lock_irqsave(&dev->blk_lock, irqflags); - list_add_tail(&req->queuelist, &dev->request_queue_send); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - wake_up(&dev->process_queue_send); - } else { - dev_err(dnbd3_device_to_dev(dev), "couldn't create keepalive request\n"); - } - } - if ((dev->heartbeat_count > STARTUP_MODE_DURATION && timeout_seconds(TIMER_INTERVAL_PROBE_NORMAL)) || - (dev->heartbeat_count <= STARTUP_MODE_DURATION && timeout_seconds(TIMER_INTERVAL_PROBE_STARTUP))) { - // Normal discovery - dev->discover = 1; - wake_up(&dev->process_queue_discover); - } - } else if (timeout_seconds(TIMER_INTERVAL_PROBE_PANIC)) { - // Panic discovery - dev->discover = 1; - wake_up(&dev->process_queue_discover); + unsigned long irqflags; + dnbd3_device_t *dev = container_of(work, dnbd3_device_t, keepalive_work.work); + + dnbd3_send_empty_request(dev, CMD_KEEPALIVE); + spin_lock_irqsave(&dev->blk_lock, irqflags); + if (device_active(dev)) { + mod_delayed_work(system_freezable_power_efficient_wq, + &dev->keepalive_work, KEEPALIVE_INTERVAL * HZ); } + spin_unlock_irqrestore(&dev->blk_lock, irqflags); +} + +/** + * Delayed work triggering discovery (alt server check) + */ +static void dnbd3_discover_workfn(struct work_struct *work) +{ + dnbd3_device_t *dev = container_of(work, dnbd3_device_t, discover_work.work); - dev->hb_timer.expires = jiffies + HZ; + dnbd3_discover(dev); +} - ++dev->heartbeat_count; - add_timer(&dev->hb_timer); +/** + * For manually triggering an immediate discovery + */ +static void dnbd3_start_discover(dnbd3_device_t *dev, bool panic) +{ + unsigned long irqflags; -#undef timeout_seconds + if (!device_active(dev)) + return; + if (panic && dnbd3_flag_get(dev->connection_lock)) { + spin_lock_irqsave(&dev->blk_lock, irqflags); + if (!dev->panic) { + // Panic freshly turned on + dev->panic = true; + dev->discover_interval = TIMER_INTERVAL_PROBE_PANIC; + } + spin_unlock_irqrestore(&dev->blk_lock, irqflags); + dnbd3_flag_reset(dev->connection_lock); + } + spin_lock_irqsave(&dev->blk_lock, irqflags); + mod_delayed_work(system_freezable_power_efficient_wq, + &dev->discover_work, 1); + spin_unlock_irqrestore(&dev->blk_lock, irqflags); } -static int dnbd3_net_discover(void *data) +/** + * Wrapper for the actual discover function below. Check run conditions + * here and re-schedule delayed task here. + */ +static void dnbd3_discover(dnbd3_device_t *dev) +{ + unsigned long irqflags; + + if (!device_active(dev) || dnbd3_flag_taken(dev->connection_lock)) + return; // device not active anymore, or just about to switch + if (!dnbd3_flag_get(dev->discover_running)) + return; // Already busy + spin_lock_irqsave(&dev->blk_lock, irqflags); + cancel_delayed_work(&dev->discover_work); + spin_unlock_irqrestore(&dev->blk_lock, irqflags); + dnbd3_internal_discover(dev); + dev->discover_count++; + // Re-queueing logic + spin_lock_irqsave(&dev->blk_lock, irqflags); + if (device_active(dev)) { + mod_delayed_work(system_freezable_power_efficient_wq, + &dev->discover_work, dev->discover_interval * HZ); + if (dev->discover_interval < TIMER_INTERVAL_PROBE_MAX + && dev->discover_count > DISCOVER_STARTUP_PHASE_COUNT) { + dev->discover_interval += 2; + } + } + spin_unlock_irqrestore(&dev->blk_lock, irqflags); + dnbd3_flag_reset(dev->discover_running); +} + +/** + * Discovery. Probe all (or some) known alt servers, + * and initiate connection switch if appropriate + */ +static void dnbd3_internal_discover(dnbd3_device_t *dev) { - dnbd3_device_t *dev = data; struct socket *sock, *best_sock = NULL; dnbd3_alt_server_t *alt; struct sockaddr_storage host_compare, best_server; uint16_t remote_version; - ktime_t start = ktime_set(0, 0), end = ktime_set(0, 0); + ktime_t start, end; unsigned long rtt = 0, best_rtt = 0; - unsigned long irqflags; - int i, j, isize, fails, rtt_threshold; - int turn = 0; - int ready = 0, do_change = 0; - char check_order[NUMBER_SERVERS]; - - struct request *last_request = (struct request *)123, *cur_request = (struct request *)456; + int i, j, k, isize, fails, rtt_threshold; + int do_change = 0; + u8 check_order[NUMBER_SERVERS]; + const bool ready = dev->discover_count > DISCOVER_STARTUP_PHASE_COUNT; + const u32 turn = dev->discover_count % DISCOVER_HISTORY_SIZE; + // Shuffle alt_servers for (i = 0; i < NUMBER_SERVERS; ++i) check_order[i] = i; - while (!kthread_should_stop()) { - wait_event_interruptible(dev->process_queue_discover, - kthread_should_stop() || dev->discover || dev->thread_discover == NULL); + for (i = 0; i < NUMBER_SERVERS; ++i) { + j = prandom_u32() % NUMBER_SERVERS; + if (j != i) { + int tmp = check_order[i]; - if (kthread_should_stop() || dev->imgname == NULL || dev->thread_discover == NULL) - break; + check_order[i] = check_order[j]; + check_order[j] = tmp; + } + } - if (!dev->discover) - continue; - dev->discover = 0; + best_server.ss_family = 0; + best_rtt = RTT_UNREACHABLE; - if (dev->reported_size < 4096) - continue; + if (!ready || dev->panic) + isize = NUMBER_SERVERS; + else + isize = 3; - best_server.ss_family = 0; - best_rtt = 0xFFFFFFFul; + for (j = 0; j < NUMBER_SERVERS; ++j) { + if (!device_active(dev)) + break; + i = check_order[j]; + mutex_lock(&dev->alt_servers_lock); + host_compare = dev->alt_servers[i].host; + fails = dev->alt_servers[i].failures; + mutex_unlock(&dev->alt_servers_lock); + if (host_compare.ss_family == 0) + continue; // Empty slot + // Reduced probability for hosts that have been unreachable + if (!dev->panic && fails > 50 && (prandom_u32() % 4) != 0) + continue; // If not in panic mode, skip server if it failed too many times + if (isize-- <= 0 && !is_same_server(&dev->cur_server.host, &host_compare)) + continue; // Only test isize servers plus current server + + // Initialize socket and connect + sock = NULL; + if (dnbd3_connect(dev, &host_compare, &sock) != 0) + goto error; - if (dev->heartbeat_count < STARTUP_MODE_DURATION || dev->panic) - isize = NUMBER_SERVERS; - else - isize = 3; + remote_version = 0; + if (!dnbd3_execute_handshake(dev, sock, &host_compare, &remote_version, false)) + goto error; - if (NUMBER_SERVERS > isize) { - for (i = 0; i < isize; ++i) { - j = ((ktime_to_s(start) >> i) ^ (ktime_to_us(start) >> j)) % NUMBER_SERVERS; - if (j != i) { - int tmp = check_order[i]; - check_order[i] = check_order[j]; - check_order[j] = tmp; - } + // panic mode, take first responding server + if (dev->panic) { + dnbd3_dev_dbg_host(dev, &host_compare, "panic mode, changing to new server\n"); + if (!dnbd3_flag_get(dev->connection_lock)) { + dnbd3_dev_dbg_host(dev, &host_compare, "...raced, ignoring\n"); + } else { + // Check global flag, a connect might have been in progress + if (best_sock != NULL) + sock_release(best_sock); + set_socket_timeout(sock, false, SOCKET_TIMEOUT_RECV * 1000 + 1000); + if (dnbd3_set_primary_connection(dev, sock, &host_compare, remote_version) != 0) + sock_release(sock); + dnbd3_flag_reset(dev->connection_lock); + return; } } - for (j = 0; j < NUMBER_SERVERS; ++j) { - i = check_order[j]; - mutex_lock(&dev->alt_servers_lock); - host_compare = dev->alt_servers[i].host; - fails = dev->alt_servers[i].failures; - mutex_unlock(&dev->alt_servers_lock); - if (host_compare.ss_family == 0) - continue; // Empty slot - if (!dev->panic && fails > 50 - && (ktime_to_us(start) & 7) != 0) - continue; // If not in panic mode, skip server if it failed too many times - if (isize-- <= 0 && !is_same_server(&dev->cur_server.host, &host_compare)) - continue; // Only test isize servers plus current server - - // Initialize socket and connect - sock = dnbd3_connect(dev, &host_compare); - if (sock == NULL) - goto error; - - if (!dnbd3_execute_handshake(dev, sock, &host_compare, &remote_version)) - goto error; - - - // panic mode, take first responding server - if (dev->panic) { - dnbd3_dev_dbg_host(dev, &host_compare, "panic mode, changing to new server\n"); - while (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0) - schedule(); - - if (dev->panic) { - // Re-check, a connect might have been in progress - dev->panic = 0; - if (best_sock != NULL) - sock_release(best_sock); - - dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect(); - put_task_struct(dev->thread_discover); - dev->thread_discover = NULL; - dnbd3_net_disconnect(dev); - spin_lock_irqsave(&dev->blk_lock, irqflags); - dev->cur_server.host = host_compare; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - dnbd3_net_connect(dev); - atomic_set(&dev->connection_lock, 0); - return 0; - } - atomic_set(&dev->connection_lock, 0); - } - - // start rtt measurement - start = ktime_get_real(); - - if (!dnbd3_request_test_block(dev, &host_compare, sock)) - goto error; - - end = ktime_get_real(); // end rtt measurement + // actual rtt measurement is just the first block requests and reply + start = ktime_get_real(); + if (!dnbd3_request_test_block(dev, &host_compare, sock)) + goto error; + end = ktime_get_real(); - mutex_lock(&dev->alt_servers_lock); - if (is_same_server(&dev->alt_servers[i].host, &host_compare)) { - dev->alt_servers[i].protocol_version = remote_version; - dev->alt_servers[i].rtts[turn] = (unsigned long)ktime_us_delta(end, start); - - rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1] - + dev->alt_servers[i].rtts[2] + dev->alt_servers[i].rtts[3]) - / 4; - dev->alt_servers[i].failures = 0; - if (dev->alt_servers[i].best_count > 1) - dev->alt_servers[i].best_count -= 2; + mutex_lock(&dev->alt_servers_lock); + if (is_same_server(&dev->alt_servers[i].host, &host_compare)) { + dev->alt_servers[i].protocol_version = remote_version; + dev->alt_servers[i].rtts[turn] = + (unsigned long)ktime_us_delta(end, start); + + rtt = 0; + for (k = 0; k < DISCOVER_HISTORY_SIZE; ++k) { + rtt += dev->alt_servers[i].rtts[k]; } - mutex_unlock(&dev->alt_servers_lock); + rtt /= DISCOVER_HISTORY_SIZE; + dev->alt_servers[i].failures = 0; + if (dev->alt_servers[i].best_count > 1) + dev->alt_servers[i].best_count -= 2; + } + mutex_unlock(&dev->alt_servers_lock); - if (best_rtt > rtt) { - // This one is better, keep socket open in case we switch - best_rtt = rtt; - best_server = host_compare; - if (best_sock != NULL) - sock_release(best_sock); - best_sock = sock; - sock = NULL; - } else { - // Not better, discard connection - sock_release(sock); - sock = NULL; - } + if (best_rtt > rtt) { + // This one is better, keep socket open in case we switch + best_rtt = rtt; + best_server = host_compare; + if (best_sock != NULL) + sock_release(best_sock); + best_sock = sock; + sock = NULL; + } else { + // Not better, discard connection + sock_release(sock); + sock = NULL; + } - // update cur servers rtt - if (is_same_server(&dev->cur_server.host, &host_compare)) - dev->cur_server.rtt = rtt; + // update cur servers rtt + if (is_same_server(&dev->cur_server.host, &host_compare)) + dev->cur_server.rtt = rtt; - continue; + continue; error: - if (sock != NULL) { - sock_release(sock); - sock = NULL; - } - mutex_lock(&dev->alt_servers_lock); - if (is_same_server(&dev->alt_servers[i].host, &host_compare)) { - ++dev->alt_servers[i].failures; - dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE; - if (dev->alt_servers[i].best_count > 2) - dev->alt_servers[i].best_count -= 3; - } - mutex_unlock(&dev->alt_servers_lock); - if (is_same_server(&dev->cur_server.host, &host_compare)) - dev->cur_server.rtt = RTT_UNREACHABLE; - } // for loop over alt_servers + if (sock != NULL) { + sock_release(sock); + sock = NULL; + } + mutex_lock(&dev->alt_servers_lock); + if (is_same_server(&dev->alt_servers[i].host, &host_compare)) { + if (remote_version) + dev->alt_servers[i].protocol_version = remote_version; + ++dev->alt_servers[i].failures; + dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE; + if (dev->alt_servers[i].best_count > 2) + dev->alt_servers[i].best_count -= 3; + } + mutex_unlock(&dev->alt_servers_lock); + if (is_same_server(&dev->cur_server.host, &host_compare)) + dev->cur_server.rtt = RTT_UNREACHABLE; + } // END - for loop over alt_servers + if (best_server.ss_family == 0) { + // No alt server could be reached + ASSERT(!best_sock); if (dev->panic) { if (dev->panic_count < 255) dev->panic_count++; @@ -293,295 +330,166 @@ error: if (PROBE_COUNT_TIMEOUT > 0 && dev->panic_count == PROBE_COUNT_TIMEOUT + 1) dnbd3_blk_fail_all_requests(dev); } + return; + } - if (best_server.ss_family == 0 || kthread_should_stop() || dev->thread_discover == NULL) { - // No alt server could be reached at all or thread should stop - if (best_sock != NULL) { - // Should never happen actually - sock_release(best_sock); - best_sock = NULL; - } - continue; - } - - // If best server was repeatedly measured best, lower the switching threshold more - mutex_lock(&dev->alt_servers_lock); - alt = get_existing_alt_from_addr(&best_server, dev); - if (alt != NULL) { - if (alt->best_count < 148) - alt->best_count += 3; - rtt_threshold = 1500 - (alt->best_count * 10); - } else { - rtt_threshold = 1500; - } - mutex_unlock(&dev->alt_servers_lock); - - do_change = ready && !is_same_server(&best_server, &dev->cur_server.host) - && (ktime_to_us(start) & 3) != 0 - && RTT_THRESHOLD_FACTOR(dev->cur_server.rtt) > best_rtt + rtt_threshold; - - if (ready && !do_change && best_sock != NULL) { - spin_lock_irqsave(&dev->blk_lock, irqflags); - if (!list_empty(&dev->request_queue_send)) { - cur_request = list_entry(dev->request_queue_send.next, struct request, queuelist); - do_change = (cur_request == last_request); - if (do_change) - dev_warn(dnbd3_device_to_dev(dev), "hung request, triggering change\n"); - } else { - cur_request = (struct request *)123; - } - last_request = cur_request; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - } - - // take server with lowest rtt - // if a (dis)connect is already in progress, we do nothing, this is not panic mode - if (do_change && atomic_cmpxchg(&dev->connection_lock, 0, 1) == 0) { - dev_info(dnbd3_device_to_dev(dev), "server %pISpc is faster (%lluµs vs. %lluµs)\n", - &best_server, - (unsigned long long)best_rtt, (unsigned long long)dev->cur_server.rtt); - dev->better_sock = best_sock; // Take shortcut by continuing to use open connection - put_task_struct(dev->thread_discover); - dev->thread_discover = NULL; - dnbd3_net_disconnect(dev); - spin_lock_irqsave(&dev->blk_lock, irqflags); - dev->cur_server.host = best_server; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - dev->cur_server.rtt = best_rtt; - dnbd3_net_connect(dev); - atomic_set(&dev->connection_lock, 0); - return 0; - } - - // Clean up connection that was held open for quicker server switch - if (best_sock != NULL) { + // If best server was repeatedly measured best, lower the switching threshold more + mutex_lock(&dev->alt_servers_lock); + alt = get_existing_alt_from_addr(&best_server, dev); + if (alt != NULL) { + if (alt->best_count < 178) + alt->best_count += 3; + rtt_threshold = 1800 - (alt->best_count * 10); + remote_version = alt->protocol_version; + } else { + rtt_threshold = 1800; + remote_version = 0; + } + mutex_unlock(&dev->alt_servers_lock); + + do_change = ready && !is_same_server(&best_server, &dev->cur_server.host) + && RTT_THRESHOLD_FACTOR(dev->cur_server.rtt) > best_rtt + rtt_threshold; + + // take server with lowest rtt + // if a (dis)connect is already in progress, we do nothing, this is not panic mode + if (do_change && device_active(dev) && dnbd3_flag_get(dev->connection_lock)) { + dev_info(dnbd3_device_to_dev(dev), "server %pISpc is faster (%lluµs vs. %lluµs)\n", + &best_server, + (unsigned long long)best_rtt, (unsigned long long)dev->cur_server.rtt); + set_socket_timeout(sock, false, // recv + MAX(best_rtt / 1000, SOCKET_TIMEOUT_RECV * 1000) + 500); + set_socket_timeout(sock, true, // send + MAX(best_rtt / 1000, SOCKET_TIMEOUT_SEND * 1000) + 500); + if (dnbd3_set_primary_connection(dev, best_sock, &best_server, remote_version) != 0) sock_release(best_sock); - best_sock = NULL; - } - - // Increase rtt array index pointer, low probability that it doesn't advance - if (!ready || (ktime_to_us(start) & 15) != 0) - turn = (turn + 1) % 4; - if (turn == 2) // Set ready when we only have 2 of 4 measurements for quicker load balancing - ready = 1; + dnbd3_flag_reset(dev->connection_lock); + return; } - if (kthread_should_stop()) - dev_dbg(dnbd3_device_to_dev(dev), "kthread %s terminated normally\n", __func__); - else - dev_dbg(dnbd3_device_to_dev(dev), "kthread %s exited unexpectedly\n", __func__); - - return 0; + // Clean up connection that was held open for quicker server switch + if (best_sock != NULL) + sock_release(best_sock); } -static int dnbd3_net_send(void *data) +/** + * Worker for sending pending requests. This will be triggered whenever + * we get a new request from the block layer. The worker will then + * work through all the requests in the send queue, request them from + * the server, and return again. + */ +static void dnbd3_send_workfn(struct work_struct *work) { - dnbd3_device_t *dev = data; - struct request *blk_request, *tmp_request; - - dnbd3_request_t dnbd3_request; - struct msghdr msg; - struct kvec iov; - + dnbd3_device_t *dev = container_of(work, dnbd3_device_t, send_work); + struct request *blk_request; + struct dnbd3_cmd *cmd; unsigned long irqflags; - int ret = 0; - - init_msghdr(msg); - - dnbd3_request.magic = dnbd3_packet_magic; - - set_user_nice(current, -20); - - // move already sent requests to request_queue_send again - spin_lock_irqsave(&dev->blk_lock, irqflags); - if (!list_empty(&dev->request_queue_receive)) { - dev_dbg(dnbd3_device_to_dev(dev), "request queue was not empty"); - list_for_each_entry_safe(blk_request, tmp_request, &dev->request_queue_receive, queuelist) { - list_del_init(&blk_request->queuelist); - list_add(&blk_request->queuelist, &dev->request_queue_send); - } - } - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - - while (!kthread_should_stop()) { - wait_event_interruptible(dev->process_queue_send, - kthread_should_stop() || !list_empty(&dev->request_queue_send)); - - if (kthread_should_stop()) - break; - - // extract block request - /* lock since we aquire a blk request from the request_queue_send */ - spin_lock_irqsave(&dev->blk_lock, irqflags); - if (list_empty(&dev->request_queue_send)) { - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - continue; - } - blk_request = list_entry(dev->request_queue_send.next, struct request, queuelist); - // what to do? - switch (dnbd3_req_op(blk_request)) { - case DNBD3_DEV_READ: - dnbd3_request.cmd = CMD_GET_BLOCK; - dnbd3_request.offset = blk_rq_pos(blk_request) << 9; // *512 - dnbd3_request.size = blk_rq_bytes(blk_request); // bytes left to complete entire request - // enqueue request to request_queue_receive - list_del_init(&blk_request->queuelist); - list_add_tail(&blk_request->queuelist, &dev->request_queue_receive); - break; - case DNBD3_REQ_OP_SPECIAL: - dnbd3_request.cmd = dnbd3_priv_to_cmd(blk_request); - dnbd3_request.size = 0; - list_del_init(&blk_request->queuelist); + mutex_lock(&dev->send_mutex); + while (dev->sock && device_active(dev)) { + // extract next block request + spin_lock_irqsave(&dev->send_queue_lock, irqflags); + if (list_empty(&dev->send_queue)) { + spin_unlock_irqrestore(&dev->send_queue_lock, irqflags); break; - - default: - if (!atomic_read(&dev->connection_lock)) - dev_err(dnbd3_device_to_dev(dev), "unknown command (send %u %u)\n", - (int)blk_request->cmd_flags, (int)dnbd3_req_op(blk_request)); - list_del_init(&blk_request->queuelist); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - continue; } - // send net request - dnbd3_request.handle = (uint64_t)(uintptr_t)blk_request; // Double cast to prevent warning on 32bit - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - fixup_request(dnbd3_request); - iov.iov_base = &dnbd3_request; - iov.iov_len = sizeof(dnbd3_request); - if (kernel_sendmsg(dev->sock, &msg, &iov, 1, sizeof(dnbd3_request)) != sizeof(dnbd3_request)) { - if (!atomic_read(&dev->connection_lock)) + blk_request = list_entry(dev->send_queue.next, struct request, queuelist); + list_del_init(&blk_request->queuelist); + spin_unlock_irqrestore(&dev->send_queue_lock, irqflags); + // append to receive queue + spin_lock_irqsave(&dev->recv_queue_lock, irqflags); + list_add_tail(&blk_request->queuelist, &dev->recv_queue); + spin_unlock_irqrestore(&dev->recv_queue_lock, irqflags); + + cmd = blk_mq_rq_to_pdu(blk_request); + if (!dnbd3_send_request(dev->sock, CMD_GET_BLOCK, cmd->handle, + blk_rq_pos(blk_request) << 9 /* sectors */, blk_rq_bytes(blk_request))) { + if (!dnbd3_flag_taken(dev->connection_lock)) { dnbd3_dev_err_host_cur(dev, "connection to server lost (send)\n"); - ret = -ESHUTDOWN; - goto cleanup; + dnbd3_start_discover(dev, true); + } + break; } } - - dev_dbg(dnbd3_device_to_dev(dev), "kthread %s terminated normally\n", __func__); - return 0; - -cleanup: - if (!atomic_read(&dev->connection_lock) && !kthread_should_stop()) { - dev_dbg(dnbd3_device_to_dev(dev), "send thread: Triggering panic mode...\n"); - if (dev->sock) - kernel_sock_shutdown(dev->sock, SHUT_RDWR); - dev->panic = 1; - dev->discover = 1; - wake_up(&dev->process_queue_discover); - } - - if (kthread_should_stop() || ret == 0 || atomic_read(&dev->connection_lock)) - dev_dbg(dnbd3_device_to_dev(dev), "kthread %s terminated normally (cleanup)\n", __func__); - else - dev_err(dnbd3_device_to_dev(dev), "kthread %s terminated abnormally (%d)\n", __func__, ret); - - return 0; + mutex_unlock(&dev->send_mutex); } -static int dnbd3_net_receive(void *data) +/** + * The receive workfn stays active for as long as the connection to a server + * lasts, i.e. it only gets restarted when we switch to a new server. + */ +static void dnbd3_recv_workfn(struct work_struct *work) { - dnbd3_device_t *dev = data; - struct request *blk_request, *tmp_request, *received_request; - - dnbd3_reply_t dnbd3_reply; - struct msghdr msg; - struct kvec iov; + dnbd3_device_t *dev = container_of(work, dnbd3_device_t, recv_work); + struct request *blk_request; + struct request *rq_iter; + struct dnbd3_cmd *cmd; + dnbd3_reply_t reply_hdr; struct req_iterator iter; struct bio_vec bvec_inst; struct bio_vec *bvec = &bvec_inst; + struct msghdr msg = { .msg_flags = MSG_NOSIGNAL | MSG_WAITALL }; + struct kvec iov; void *kaddr; unsigned long irqflags; uint16_t rid; - unsigned long recv_timeout = jiffies; - - int count, remaining, ret = 0; - - init_msghdr(msg); - set_user_nice(current, -20); + int remaining; + int ret; - while (!kthread_should_stop()) { + mutex_lock(&dev->recv_mutex); + while (dev->sock) { // receive net reply - iov.iov_base = &dnbd3_reply; - iov.iov_len = sizeof(dnbd3_reply); - ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags); - - /* end thread after socket timeout or reception of data */ - if (kthread_should_stop()) - break; - - /* check return value of kernel_recvmsg() */ + ret = dnbd3_recv_reply(dev->sock, &reply_hdr); if (ret == 0) { /* have not received any data, but remote peer is shutdown properly */ dnbd3_dev_dbg_host_cur(dev, "remote peer has performed an orderly shutdown\n"); - goto cleanup; + goto out_unlock; } else if (ret < 0) { if (ret == -EAGAIN) { - if (jiffies < recv_timeout) - recv_timeout = jiffies; // Handle overflow - if ((jiffies - recv_timeout) / HZ > SOCKET_KEEPALIVE_TIMEOUT) { - if (!atomic_read(&dev->connection_lock)) - dnbd3_dev_err_host_cur(dev, "receive timeout reached (%d of %d secs)\n", - (int)((jiffies - recv_timeout) / HZ), - (int)SOCKET_KEEPALIVE_TIMEOUT); - ret = -ETIMEDOUT; - goto cleanup; - } - continue; + if (!dnbd3_flag_taken(dev->connection_lock)) + dnbd3_dev_err_host_cur(dev, "receive timeout reached\n"); } else { - /* for all errors other than -EAGAIN, print message and abort thread */ - if (!atomic_read(&dev->connection_lock)) - dnbd3_dev_err_host_cur(dev, "connection to server lost (receive)\n"); - goto cleanup; + /* for all errors other than -EAGAIN, print errno */ + if (!dnbd3_flag_taken(dev->connection_lock)) + dnbd3_dev_err_host_cur(dev, "connection to server lost (receive, errno=%d)\n", ret); } + goto out_unlock; } /* check if arrived data is valid */ - if (ret != sizeof(dnbd3_reply)) { - if (!atomic_read(&dev->connection_lock)) - dnbd3_dev_err_host_cur(dev, "recv partial msg header (%d bytes)\n", ret); - ret = -EINVAL; - goto cleanup; + if (ret != sizeof(reply_hdr)) { + if (!dnbd3_flag_taken(dev->connection_lock)) + dnbd3_dev_err_host_cur(dev, "recv partial msg header (%d/%d bytes)\n", + ret, (int)sizeof(reply_hdr)); + goto out_unlock; } - fixup_reply(dnbd3_reply); // check error - if (dnbd3_reply.magic != dnbd3_packet_magic) { + if (reply_hdr.magic != dnbd3_packet_magic) { dnbd3_dev_err_host_cur(dev, "wrong packet magic (receive)\n"); - ret = -EINVAL; - goto cleanup; - } - if (dnbd3_reply.cmd == 0) { - dnbd3_dev_err_host_cur(dev, "command was 0 (Receive)\n"); - ret = -EINVAL; - goto cleanup; + goto out_unlock; } - // Update timeout - recv_timeout = jiffies; - // what to do? - switch (dnbd3_reply.cmd) { + switch (reply_hdr.cmd) { case CMD_GET_BLOCK: // search for replied request in queue blk_request = NULL; - spin_lock_irqsave(&dev->blk_lock, irqflags); - list_for_each_entry_safe(received_request, tmp_request, &dev->request_queue_receive, - queuelist) { - if ((uint64_t)(uintptr_t)received_request == dnbd3_reply.handle) { - // Double cast to prevent warning on 32bit - blk_request = received_request; + spin_lock_irqsave(&dev->recv_queue_lock, irqflags); + list_for_each_entry(rq_iter, &dev->recv_queue, queuelist) { + cmd = blk_mq_rq_to_pdu(rq_iter); + if (cmd->handle == reply_hdr.handle) { + blk_request = rq_iter; list_del_init(&blk_request->queuelist); break; } } - spin_unlock_irqrestore(&dev->blk_lock, irqflags); + spin_unlock_irqrestore(&dev->recv_queue_lock, irqflags); if (blk_request == NULL) { - dnbd3_dev_err_host_cur(dev, "received block data for unrequested handle (%llu: %llu)\n", - (unsigned long long)dnbd3_reply.handle, - (unsigned long long)dnbd3_reply.size); - ret = -EINVAL; - goto cleanup; + dnbd3_dev_err_host_cur(dev, "received block data for unrequested handle (%llx: len=%llu)\n", + reply_hdr.handle, + (u64)reply_hdr.size); + goto out_unlock; } // receive data and answer to block layer #if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 14, 0) @@ -599,45 +507,36 @@ static int dnbd3_net_receive(void *data) /* have not received any data, but remote peer is shutdown properly */ dnbd3_dev_dbg_host_cur( dev, "remote peer has performed an orderly shutdown\n"); - ret = 0; } else if (ret < 0) { - if (!atomic_read(&dev->connection_lock)) + if (!dnbd3_flag_taken(dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "disconnect: receiving from net to block layer\n"); } else { - if (!atomic_read(&dev->connection_lock)) + if (!dnbd3_flag_taken(dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "receiving from net to block layer (%d bytes)\n", ret); - ret = -EINVAL; } // Requeue request - spin_lock_irqsave(&dev->blk_lock, irqflags); - list_add(&blk_request->queuelist, &dev->request_queue_send); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - goto cleanup; + spin_lock_irqsave(&dev->send_queue_lock, irqflags); + list_add(&blk_request->queuelist, &dev->send_queue); + spin_unlock_irqrestore(&dev->send_queue_lock, irqflags); + goto out_unlock; } } -#ifdef DNBD3_BLK_MQ blk_mq_end_request(blk_request, BLK_STS_OK); -#else - blk_end_request_all(blk_request, 0); -#endif - continue; + break; case CMD_GET_SERVERS: - remaining = dnbd3_reply.size; + remaining = reply_hdr.size; if (dev->use_server_provided_alts) { dnbd3_server_entry_t new_server; while (remaining >= sizeof(dnbd3_server_entry_t)) { - iov.iov_base = &new_server; - iov.iov_len = sizeof(dnbd3_server_entry_t); - if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, - msg.msg_flags) != sizeof(dnbd3_server_entry_t)) { - if (!atomic_read(&dev->connection_lock)) + if (dnbd3_recv_bytes(dev->sock, &new_server, sizeof(new_server)) + != sizeof(new_server)) { + if (!dnbd3_flag_taken(dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n"); - ret = -EINVAL; - goto cleanup; + goto out_unlock; } // TODO: Log if (new_server.failures == 0) { // ADD @@ -645,36 +544,20 @@ static int dnbd3_net_receive(void *data) } else { // REM dnbd3_rem_server(dev, &new_server.host); } - remaining -= sizeof(dnbd3_server_entry_t); - } - } - // Drain any payload still on the wire - while (remaining > 0) { - count = MIN(sizeof(dnbd3_reply), - remaining); // Abuse the reply struct as the receive buffer - iov.iov_base = &dnbd3_reply; - iov.iov_len = count; - ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); - if (ret <= 0) { - if (!atomic_read(&dev->connection_lock)) - dnbd3_dev_err_host_cur( - dev, "recv additional payload from CMD_GET_SERVERS\n"); - ret = -EINVAL; - goto cleanup; + remaining -= sizeof(new_server); } - remaining -= ret; } - continue; + if (!dnbd3_drain_socket(dev, dev->sock, remaining)) + goto out_unlock; + break; case CMD_LATEST_RID: - if (dnbd3_reply.size != 2) { - dev_err(dnbd3_device_to_dev(dev), "CMD_LATEST_RID.size != 2\n"); + if (reply_hdr.size < 2) { + dev_err(dnbd3_device_to_dev(dev), "CMD_LATEST_RID.size < 2\n"); continue; } - iov.iov_base = &rid; - iov.iov_len = sizeof(rid); - if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags) <= 0) { - if (!atomic_read(&dev->connection_lock)) + if (dnbd3_recv_bytes(dev->sock, &rid, 2) != 2) { + if (!dnbd3_flag_taken(dev->connection_lock)) dev_err(dnbd3_device_to_dev(dev), "could not receive CMD_LATEST_RID payload\n"); } else { rid = net_order_16(rid); @@ -682,70 +565,52 @@ static int dnbd3_net_receive(void *data) dev->imgname, (int)rid, (int)dev->rid); dev->update_available = (rid > dev->rid ? 1 : 0); } + if (reply_hdr.size > 2) + dnbd3_drain_socket(dev, dev->sock, reply_hdr.size - 2); continue; case CMD_KEEPALIVE: - if (dnbd3_reply.size != 0) - dev_err(dnbd3_device_to_dev(dev), "keep alive packet with payload\n"); + if (reply_hdr.size != 0) { + dev_dbg(dnbd3_device_to_dev(dev), "keep alive packet with payload\n"); + dnbd3_drain_socket(dev, dev->sock, reply_hdr.size); + } continue; default: - dev_err(dnbd3_device_to_dev(dev), "unknown command (receive)\n"); - continue; + dev_err(dnbd3_device_to_dev(dev), "unknown command: %d (receive), aborting connection\n", (int)reply_hdr.cmd); + goto out_unlock; } } - - dev_dbg(dnbd3_device_to_dev(dev), "kthread thread_receive terminated normally\n"); - return 0; - -cleanup: - if (!atomic_read(&dev->connection_lock) && !kthread_should_stop()) { - dev_dbg(dnbd3_device_to_dev(dev), "recv thread: Triggering panic mode...\n"); - if (dev->sock) - kernel_sock_shutdown(dev->sock, SHUT_RDWR); - dev->panic = 1; - dev->discover = 1; - wake_up(&dev->process_queue_discover); - } - - if (kthread_should_stop() || ret == 0 || atomic_read(&dev->connection_lock)) - dev_dbg(dnbd3_device_to_dev(dev), "kthread %s terminated normally (cleanup)\n", __func__); - else - dev_err(dnbd3_device_to_dev(dev), "kthread %s terminated abnormally (%d)\n", __func__, ret); - - return 0; +out_unlock: + // This will check if we actually still need a new connection + dnbd3_start_discover(dev, true); + mutex_unlock(&dev->recv_mutex); } -static void set_socket_timeouts(struct socket *sock, int timeout_ms) +/** + * Set send or receive timeout of given socket + */ +static void set_socket_timeout(struct socket *sock, bool set_send, int timeout_ms) { #if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 1, 0) + int opt = set_send ? SO_SNDTIMEO_NEW : SO_RCVTIMEO_NEW; struct __kernel_sock_timeval timeout; #else + int opt = set_send ? SO_SNDTIMEO : SO_RCVTIMEO; struct timeval timeout; #endif #if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 9, 0) - sockptr_t timeout_ptr; - - timeout_ptr = KERNEL_SOCKPTR(&timeout); + sockptr_t timeout_ptr = KERNEL_SOCKPTR(&timeout); #else - char *timeout_ptr; - - timeout_ptr = (char *)&timeout; + char *timeout_ptr = (char *)&timeout; #endif timeout.tv_sec = timeout_ms / 1000; timeout.tv_usec = (timeout_ms % 1000) * 1000; - -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 1, 0) - sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout)); - sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout)); -#else - sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout)); - sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout)); -#endif + sock_setsockopt(sock, SOL_SOCKET, opt, timeout_ptr, sizeof(timeout)); } -static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr) +static int dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket **sock_out) { ktime_t start; int ret, connect_time_ms; @@ -763,7 +628,7 @@ static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage #endif if (ret < 0) { dev_err(dnbd3_device_to_dev(dev), "couldn't create socket: %d\n", ret); - return NULL; + return ret; } /* Only one retry, TCP no delay */ @@ -790,36 +655,40 @@ static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage connect_time_ms = dev->cur_server.rtt * 2 / 1000; } /* but obey a minimal configurable value, and maximum sanity check */ - if (connect_time_ms < SOCKET_TIMEOUT_CLIENT_DATA * 1000) - connect_time_ms = SOCKET_TIMEOUT_CLIENT_DATA * 1000; + if (connect_time_ms < SOCKET_TIMEOUT_SEND * 1000) + connect_time_ms = SOCKET_TIMEOUT_SEND * 1000; else if (connect_time_ms > 60000) connect_time_ms = 60000; - set_socket_timeouts(sock, connect_time_ms); + set_socket_timeout(sock, false, connect_time_ms); // recv + set_socket_timeout(sock, true, connect_time_ms); // send start = ktime_get_real(); while (--retries > 0) { ret = kernel_connect(sock, (struct sockaddr *)addr, addrlen, 0); connect_time_ms = (int)ktime_ms_delta(ktime_get_real(), start); - if (connect_time_ms > 2 * SOCKET_TIMEOUT_CLIENT_DATA * 1000) { + if (connect_time_ms > 2 * SOCKET_TIMEOUT_SEND * 1000) { /* Either I'm losing my mind or there was a specific build of kernel * 5.x where SO_RCVTIMEO didn't affect the connect call above, so * this function would hang for over a minute for unreachable hosts. * Leave in this debug check for twice the configured timeout */ - dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect call took %dms\n", - addr, connect_time_ms); + dnbd3_dev_dbg_host(dev, addr, "connect: call took %dms\n", + connect_time_ms); } if (ret != 0) { if (ret == -EINTR) - continue; - dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect failed (%d, blocked %dms)\n", - addr, ret, connect_time_ms); + dnbd3_dev_dbg_host(dev, addr, "connect: interrupted system call (blocked %dms)\n", + connect_time_ms); + else + dnbd3_dev_dbg_host(dev, addr, "connect: failed (%d, blocked %dms)\n", + ret, connect_time_ms); goto error; } - return sock; + *sock_out = sock; + return 0; } error: sock_release(sock); - return NULL; + return ret < 0 ? ret : -EIO; } #define dnbd3_err_dbg_host(...) do { \ @@ -837,37 +706,39 @@ error: * server, so we validate the filesize, rid, name against what we expect. * The server's protocol version is returned in 'remote_version' */ -static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, - struct sockaddr_storage *addr, uint16_t *remote_version) +static bool dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, + struct sockaddr_storage *addr, uint16_t *remote_version, bool copy_data) { + unsigned long irqflags; const char *name; uint64_t filesize; int mlen; - uint16_t rid, initial_connect; - struct msghdr msg; + uint16_t rid; + struct msghdr msg = { .msg_flags = MSG_NOSIGNAL | MSG_WAITALL }; struct kvec iov[2]; serialized_buffer_t *payload; - dnbd3_reply_t dnbd3_reply; - dnbd3_request_t dnbd3_request = { .magic = dnbd3_packet_magic }; + dnbd3_reply_t reply_hdr; + dnbd3_request_t request_hdr = { .magic = dnbd3_packet_magic }; payload = kmalloc(sizeof(*payload), GFP_KERNEL); if (payload == NULL) goto error; - initial_connect = (dev->reported_size == 0); - init_msghdr(msg); + if (copy_data && device_active(dev)) { + dev_warn(dnbd3_device_to_dev(dev), "Called handshake function with copy_data enabled when reported_size is not zero\n"); + } // Request filesize - dnbd3_request.cmd = CMD_SELECT_IMAGE; - iov[0].iov_base = &dnbd3_request; - iov[0].iov_len = sizeof(dnbd3_request); + request_hdr.cmd = CMD_SELECT_IMAGE; + iov[0].iov_base = &request_hdr; + iov[0].iov_len = sizeof(request_hdr); serializer_reset_write(payload); serializer_put_uint16(payload, PROTOCOL_VERSION); // DNBD3 protocol version serializer_put_string(payload, dev->imgname); // image name serializer_put_uint16(payload, dev->rid); // revision id serializer_put_uint8(payload, 0); // are we a server? (no!) iov[1].iov_base = payload; - dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(payload); - fixup_request(dnbd3_request); + request_hdr.size = iov[1].iov_len = serializer_get_written_length(payload); + fixup_request(request_hdr); mlen = iov[0].iov_len + iov[1].iov_len; if (kernel_sendmsg(sock, &msg, iov, 2, mlen) != mlen) { dnbd3_err_dbg_host(dev, addr, "requesting image size failed\n"); @@ -875,28 +746,28 @@ static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, } // receive net reply - iov[0].iov_base = &dnbd3_reply; - iov[0].iov_len = sizeof(dnbd3_reply); - if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply)) { + if (dnbd3_recv_reply(sock, &reply_hdr) != sizeof(reply_hdr)) { dnbd3_err_dbg_host(dev, addr, "receiving image size packet (header) failed\n"); goto error; } - fixup_reply(dnbd3_reply); - if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_SELECT_IMAGE || dnbd3_reply.size < 4) { + if (reply_hdr.magic != dnbd3_packet_magic + || reply_hdr.cmd != CMD_SELECT_IMAGE || reply_hdr.size < 4 + || reply_hdr.size > sizeof(*payload)) { dnbd3_err_dbg_host(dev, addr, - "corrupted CMD_SELECT_IMAGE reply\n"); + "corrupt CMD_SELECT_IMAGE reply\n"); goto error; } // receive data iov[0].iov_base = payload; - iov[0].iov_len = dnbd3_reply.size; - if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) { + iov[0].iov_len = reply_hdr.size; + if (kernel_recvmsg(sock, &msg, iov, 1, reply_hdr.size, msg.msg_flags) + != reply_hdr.size) { dnbd3_err_dbg_host(dev, addr, "receiving payload of CMD_SELECT_IMAGE reply failed\n"); goto error; } - serializer_reset_read(payload, dnbd3_reply.size); + serializer_reset_read(payload, reply_hdr.size); *remote_version = serializer_get_uint16(payload); name = serializer_get_string(payload); @@ -910,7 +781,6 @@ static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, (int)MIN_SUPPORTED_SERVER); goto error; } - if (name == NULL) { dnbd3_err_dbg_host(dev, addr, "server did not supply an image name\n"); goto error; @@ -920,20 +790,16 @@ static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, goto error; } - /* only check image name if this isn't the initial connect */ - if (initial_connect && dev->rid != 0 && strcmp(name, dev->imgname) != 0) { - dnbd3_err_dbg_host(dev, addr, "server offers image '%s', requested '%s'\n", name, dev->imgname); - goto error; - } - - if (initial_connect) { + if (copy_data) { if (filesize < DNBD3_BLOCK_SIZE) { dnbd3_err_dbg_host(dev, addr, "reported size by server is < 4096\n"); goto error; } + spin_lock_irqsave(&dev->blk_lock, irqflags); if (strlen(dev->imgname) < strlen(name)) { dev->imgname = krealloc(dev->imgname, strlen(name) + 1, GFP_KERNEL); if (dev->imgname == NULL) { + spin_unlock_irqrestore(&dev->blk_lock, irqflags); dnbd3_err_dbg_host(dev, addr, "reallocating buffer for new image name failed\n"); goto error; } @@ -942,9 +808,10 @@ static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, dev->rid = rid; // store image information dev->reported_size = filesize; + dev->update_available = 0; + spin_unlock_irqrestore(&dev->blk_lock, irqflags); set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */ dnbd3_dev_dbg_host(dev, addr, "image size: %llu\n", dev->reported_size); - dev->update_available = 0; } else { /* switching connection, sanity checks */ if (rid != dev->rid) { @@ -954,6 +821,11 @@ static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, goto error; } + if (strcmp(name, dev->imgname) != 0) { + dnbd3_err_dbg_host(dev, addr, "server offers image '%s', requested '%s'\n", name, dev->imgname); + goto error; + } + if (filesize != dev->reported_size) { dnbd3_err_dbg_host(dev, addr, "reported image size of %llu does not match expected value %llu\n", @@ -962,251 +834,287 @@ static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock, } } kfree(payload); - return 1; + return true; error: kfree(payload); - return 0; + return false; +} + +static bool dnbd3_send_request(struct socket *sock, u16 cmd, u64 handle, u64 offset, u32 size) +{ + struct msghdr msg = { .msg_flags = MSG_NOSIGNAL }; + dnbd3_request_t request_hdr = { + .magic = dnbd3_packet_magic, + .cmd = cmd, + .size = size, + .offset = offset, + .handle = handle, + }; + struct kvec iov = { .iov_base = &request_hdr, .iov_len = sizeof(request_hdr) }; + + fixup_request(request_hdr); + return kernel_sendmsg(sock, &msg, &iov, 1, sizeof(request_hdr)) == sizeof(request_hdr); +} + +/** + * Send a request with given cmd type and empty payload. + */ +static bool dnbd3_send_empty_request(dnbd3_device_t *dev, u16 cmd) +{ + int ret; + + mutex_lock(&dev->send_mutex); + ret = dev->sock + && dnbd3_send_request(dev->sock, cmd, 0, 0, 0); + mutex_unlock(&dev->send_mutex); + return ret; +} + +static int dnbd3_recv_bytes(struct socket *sock, void *buffer, size_t count) +{ + struct msghdr msg = { .msg_flags = MSG_NOSIGNAL | MSG_WAITALL }; + struct kvec iov = { .iov_base = buffer, .iov_len = count }; + + return kernel_recvmsg(sock, &msg, &iov, 1, count, msg.msg_flags); +} + +static int dnbd3_recv_reply(struct socket *sock, dnbd3_reply_t *reply_hdr) +{ + int ret = dnbd3_recv_bytes(sock, reply_hdr, sizeof(*reply_hdr)); + + fixup_reply(*reply_hdr); + return ret; } -static int dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock) +static bool dnbd3_drain_socket(dnbd3_device_t *dev, struct socket *sock, int bytes) { - dnbd3_request_t dnbd3_request = { .magic = dnbd3_packet_magic }; - dnbd3_reply_t dnbd3_reply; + int ret; struct kvec iov; - struct msghdr msg; - char *buf = NULL; - char smallbuf[256]; - int remaining, buffer_size, ret, func_return; + struct msghdr msg = { .msg_flags = MSG_NOSIGNAL }; + + while (bytes > 0) { + iov.iov_base = __garbage_mem; + iov.iov_len = sizeof(__garbage_mem); + ret = kernel_recvmsg(sock, &msg, &iov, 1, MIN(bytes, iov.iov_len), msg.msg_flags); + if (ret <= 0) { + dnbd3_dev_err_host_cur(dev, "draining payload failed (ret=%d)\n", ret); + return false; + } + bytes -= ret; + } + return true; +} - init_msghdr(msg); +static bool dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock) +{ + dnbd3_reply_t reply_hdr; - func_return = 0; // Request block - dnbd3_request.cmd = CMD_GET_BLOCK; - // Do *NOT* pick a random block as it has proven to cause severe - // cache thrashing on the server - dnbd3_request.offset = 0; - dnbd3_request.size = RTT_BLOCK_SIZE; - fixup_request(dnbd3_request); - iov.iov_base = &dnbd3_request; - iov.iov_len = sizeof(dnbd3_request); - - if (kernel_sendmsg(sock, &msg, &iov, 1, sizeof(dnbd3_request)) <= 0) { + if (!dnbd3_send_request(sock, CMD_GET_BLOCK, 0, 0, RTT_BLOCK_SIZE)) { dnbd3_err_dbg_host(dev, addr, "requesting test block failed\n"); - goto error; + return false; } // receive net reply - iov.iov_base = &dnbd3_reply; - iov.iov_len = sizeof(dnbd3_reply); - if (kernel_recvmsg(sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags) - != sizeof(dnbd3_reply)) { + if (dnbd3_recv_reply(sock, &reply_hdr) != sizeof(reply_hdr)) { dnbd3_err_dbg_host(dev, addr, "receiving test block header packet failed\n"); - goto error; + return false; } - fixup_reply(dnbd3_reply); - if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_GET_BLOCK - || dnbd3_reply.size != RTT_BLOCK_SIZE) { + if (reply_hdr.magic != dnbd3_packet_magic || reply_hdr.cmd != CMD_GET_BLOCK + || reply_hdr.size != RTT_BLOCK_SIZE || reply_hdr.handle != 0) { dnbd3_err_dbg_host(dev, addr, - "unexpected reply to block request: cmd=%d, size=%d (discover)\n", - (int)dnbd3_reply.cmd, (int)dnbd3_reply.size); - goto error; + "unexpected reply to block request: cmd=%d, size=%d, handle=%llu (discover)\n", + (int)reply_hdr.cmd, (int)reply_hdr.size, reply_hdr.handle); + return false; } // receive data - buf = kmalloc(DNBD3_BLOCK_SIZE, GFP_NOWAIT); - if (buf == NULL) { - /* fallback to stack if we're really memory constrained */ - buf = smallbuf; - buffer_size = sizeof(smallbuf); - } else { - buffer_size = DNBD3_BLOCK_SIZE; - } - remaining = RTT_BLOCK_SIZE; - /* TODO in either case we could build a large iovec that points to the same buffer over and over again */ - while (remaining > 0) { - iov.iov_base = buf; - iov.iov_len = buffer_size; - ret = kernel_recvmsg(sock, &msg, &iov, 1, MIN(remaining, buffer_size), msg.msg_flags); - if (ret <= 0) { - dnbd3_err_dbg_host(dev, addr, "receiving test block payload failed (ret=%d)\n", ret); - goto error; - } - remaining -= ret; - } - func_return = 1; - // Fallthrough! -error: - if (buf != smallbuf) - kfree(buf); - return func_return; + return dnbd3_drain_socket(dev, sock, RTT_BLOCK_SIZE); } #undef dnbd3_err_dbg_host -static int spawn_worker_thread(dnbd3_device_t *dev, struct task_struct **task, const char *name, - int (*threadfn)(void *data)) +static void replace_main_socket(dnbd3_device_t *dev, struct socket *sock, struct sockaddr_storage *addr, u16 protocol_version) { - ASSERT(*task == NULL); - *task = kthread_create(threadfn, dev, "%s-%s", dev->disk->disk_name, name); - if (!IS_ERR(*task)) { - get_task_struct(*task); - wake_up_process(*task); - return 1; + unsigned long irqflags; + + mutex_lock(&dev->send_mutex); + // First, shutdown connection, so receive worker will leave its mainloop + if (dev->sock) + kernel_sock_shutdown(dev->sock, SHUT_RDWR); + mutex_lock(&dev->recv_mutex); + // Receive worker is done, get rid of socket and replace + if (dev->sock) + sock_release(dev->sock); + dev->sock = sock; + spin_lock_irqsave(&dev->blk_lock, irqflags); + if (addr == NULL) { + memset(&dev->cur_server, 0, sizeof(dev->cur_server)); + } else { + dev->cur_server.host = *addr; + dev->cur_server.rtt = 0; + dev->cur_server.protocol_version = protocol_version; } - dev_err(dnbd3_device_to_dev(dev), "failed to create %s thread (%ld)\n", - name, PTR_ERR(*task)); - /* reset possible non-NULL error value */ - *task = NULL; - return 0; + spin_unlock_irqrestore(&dev->blk_lock, irqflags); + mutex_unlock(&dev->recv_mutex); + mutex_unlock(&dev->send_mutex); } -static void stop_worker_thread(dnbd3_device_t *dev, struct task_struct **task, const char *name, int quiet) +static void dnbd3_release_resources(dnbd3_device_t *dev) { - int ret; - - if (*task == NULL) - return; - if (!quiet) - dnbd3_dev_dbg_host_cur(dev, "stop %s thread\n", name); - ret = kthread_stop(*task); - put_task_struct(*task); - if (ret == -EINTR) { - /* thread has never been scheduled and run */ - if (!quiet) - dev_dbg(dnbd3_device_to_dev(dev), "%s thread has never run\n", name); - } else { - /* thread has run, check if it has terminated successfully */ - if (ret < 0 && !quiet) - dev_err(dnbd3_device_to_dev(dev), "%s thread was not terminated correctly\n", name); - } - *task = NULL; + if (dev->send_wq) + destroy_workqueue(dev->send_wq); + dev->send_wq = NULL; + if (dev->recv_wq) + destroy_workqueue(dev->recv_wq); + dev->recv_wq = NULL; + mutex_destroy(&dev->send_mutex); + mutex_destroy(&dev->recv_mutex); } -int dnbd3_net_connect(dnbd3_device_t *dev) +/** + * Establish new connection on a dnbd3 device. + * Return 0 on success, errno otherwise + */ +int dnbd3_new_connection(dnbd3_device_t *dev, struct sockaddr_storage *addr, bool init) { - struct request *req_alt_servers = NULL; unsigned long irqflags; + struct socket *sock = NULL; + uint16_t proto_version; + int ret; - ASSERT(atomic_read(&dev->connection_lock)); - - if (dev->use_server_provided_alts) { - req_alt_servers = kmalloc(sizeof(*req_alt_servers), GFP_KERNEL); - if (req_alt_servers == NULL) - dnbd3_dev_err_host_cur(dev, "Cannot allocate memory to request list of alt servers\n"); + ASSERT(dnbd3_flag_taken(dev->connection_lock)); + if (init && device_active(dev)) { + dnbd3_dev_err_host_cur(dev, "device already configured/connected\n"); + return -EBUSY; + } + if (!init && !device_active(dev)) { + dev_warn(dnbd3_device_to_dev(dev), "connection switch called on unconfigured device\n"); + return -ENOTCONN; } - if (dev->cur_server.host.ss_family == 0 || dev->imgname == NULL) { - dnbd3_dev_err_host_cur(dev, "connect: host or image name not set\n"); + dnbd3_dev_dbg_host(dev, addr, "connecting...\n"); + ret = dnbd3_connect(dev, addr, &sock); + if (ret != 0 || sock == NULL) goto error; - } - if (dev->sock) { - dnbd3_dev_err_host_cur(dev, "socket already connected\n"); + /* execute the "select image" handshake */ + // if init is true, reported_size will be set + if (!dnbd3_execute_handshake(dev, sock, addr, &proto_version, init)) { + ret = -EINVAL; goto error; } - ASSERT(dev->thread_send == NULL); - ASSERT(dev->thread_receive == NULL); - ASSERT(dev->thread_discover == NULL); - - if (dev->better_sock != NULL) { - // Switching server, connection is already established and size request was executed - dnbd3_dev_dbg_host_cur(dev, "on-the-fly server change\n"); - dev->sock = dev->better_sock; - dev->better_sock = NULL; - } else { - // no established connection yet from discovery thread, start new one - uint16_t proto_version; - - dnbd3_dev_dbg_host_cur(dev, "connecting\n"); - dev->sock = dnbd3_connect(dev, &dev->cur_server.host); - if (dev->sock == NULL) { - dnbd3_dev_err_host_cur(dev, "%s: Failed\n", __func__); - goto error; + if (init) { + // We're setting up the device for use - allocate resources + // Do not goto error before this + ASSERT(!dev->send_wq); + ASSERT(!dev->recv_wq); + mutex_init(&dev->send_mutex); + mutex_init(&dev->recv_mutex); + // a designated queue for sending, that allows one active task only + dev->send_wq = alloc_workqueue("dnbd%d-send", + WQ_UNBOUND | WQ_FREEZABLE | WQ_MEM_RECLAIM | WQ_HIGHPRI, + 1, dev->index); + dev->recv_wq = alloc_workqueue("dnbd%d-recv", + WQ_UNBOUND | WQ_FREEZABLE | WQ_MEM_RECLAIM | WQ_HIGHPRI | WQ_CPU_INTENSIVE, + 1, dev->index); + if (!dev->send_wq || !dev->recv_wq) { + ret = -ENOMEM; + goto error_dealloc; } - /* execute the "select image" handshake */ - if (!dnbd3_execute_handshake(dev, dev->sock, &dev->cur_server.host, &proto_version)) - goto error; - - spin_lock_irqsave(&dev->blk_lock, irqflags); - dev->cur_server.protocol_version = proto_version; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); } - /* create required threads */ - if (!spawn_worker_thread(dev, &dev->thread_send, "send", dnbd3_net_send)) - goto error; - if (!spawn_worker_thread(dev, &dev->thread_receive, "receive", dnbd3_net_receive)) - goto error; - if (!spawn_worker_thread(dev, &dev->thread_discover, "discover", dnbd3_net_discover)) - goto error; + set_socket_timeout(sock, false, SOCKET_TIMEOUT_RECV * 1000); // recv + dnbd3_set_primary_connection(dev, sock, addr, proto_version); + sock = NULL; // In case we ever goto error* after this point - dnbd3_dev_dbg_host_cur(dev, "connection established\n"); - dev->panic = 0; - dev->panic_count = 0; + spin_lock_irqsave(&dev->blk_lock, irqflags); + if (init) { + dev->discover_count = 0; + dev->discover_interval = TIMER_INTERVAL_PROBE_STARTUP; + // discovery and keepalive are not critical, use the power efficient queue + queue_delayed_work(system_power_efficient_wq, &dev->discover_work, + dev->discover_interval * HZ); + queue_delayed_work(system_power_efficient_wq, &dev->keepalive_work, + KEEPALIVE_INTERVAL * HZ); + // but the receiver is performance critical AND runs indefinitely, use the + // the cpu intensive queue, as jobs submitted there will not cound towards + // the concurrency limit of per-cpu worker threads. It still feels a little + // dirty to avoid managing our own thread, but nbd does it too. + } + spin_unlock_irqrestore(&dev->blk_lock, irqflags); + return 0; - if (req_alt_servers != NULL) { - // Enqueue request to request_queue_send for a fresh list of alt servers - dnbd3_cmd_to_priv(req_alt_servers, CMD_GET_SERVERS); - spin_lock_irqsave(&dev->blk_lock, irqflags); - list_add(&req_alt_servers->queuelist, &dev->request_queue_send); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - wake_up(&dev->process_queue_send); +error_dealloc: + if (init) { + // If anything fails during initialization, free resources again + dnbd3_release_resources(dev); } +error: + if (init) + dev->reported_size = 0; + if (sock) + sock_release(sock); + return ret < 0 ? ret : -EIO; +} - // add heartbeat timer - // Do not goto error after creating the timer - we require that the timer exists - // if dev->sock != NULL -- see dnbd3_net_disconnect - dev->heartbeat_count = 0; - timer_setup(&dev->hb_timer, dnbd3_net_heartbeat, 0); - dev->hb_timer.expires = jiffies + HZ; - add_timer(&dev->hb_timer); +void dnbd3_net_work_init(dnbd3_device_t *dev) +{ + INIT_WORK(&dev->send_work, dnbd3_send_workfn); + INIT_WORK(&dev->recv_work, dnbd3_recv_workfn); + INIT_DELAYED_WORK(&dev->discover_work, dnbd3_discover_workfn); + INIT_DELAYED_WORK(&dev->keepalive_work, dnbd3_keepalive_workfn); +} - return 0; +static int dnbd3_set_primary_connection(dnbd3_device_t *dev, struct socket *sock, struct sockaddr_storage *addr, u16 protocol_version) +{ + unsigned long irqflags; -error: - stop_worker_thread(dev, &dev->thread_send, "send", 1); - stop_worker_thread(dev, &dev->thread_receive, "receive", 1); - stop_worker_thread(dev, &dev->thread_discover, "discover", 1); - if (dev->sock) { - sock_release(dev->sock); - dev->sock = NULL; + ASSERT(dnbd3_flag_taken(dev->connection_lock)); + if (addr->ss_family == 0 || dev->imgname == NULL || sock == NULL) { + dnbd3_dev_err_host_cur(dev, "connect: host, image name or sock not set\n"); + return -EINVAL; } + + replace_main_socket(dev, sock, addr, protocol_version); spin_lock_irqsave(&dev->blk_lock, irqflags); - dev->cur_server.host.ss_family = 0; + dev->panic = false; + dev->panic_count = 0; + dev->discover_interval = TIMER_INTERVAL_PROBE_SWITCH; + queue_work(dev->recv_wq, &dev->recv_work); spin_unlock_irqrestore(&dev->blk_lock, irqflags); - kfree(req_alt_servers); + if (dev->use_server_provided_alts) { + dnbd3_send_empty_request(dev, CMD_GET_SERVERS); + } - return -1; + dnbd3_dev_dbg_host_cur(dev, "connection switched\n"); + dnbd3_blk_requeue_all_requests(dev); + return 0; } +/** + * Disconnect the device, shutting it down. + */ int dnbd3_net_disconnect(dnbd3_device_t *dev) { - unsigned long irqflags; - + ASSERT(dnbd3_flag_taken(dev->connection_lock)); + if (!device_active(dev)) + return -ENOTCONN; dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n"); - ASSERT(atomic_read(&dev->connection_lock)); - dev->discover = 0; - - if (dev->sock) { - kernel_sock_shutdown(dev->sock, SHUT_RDWR); - // clear heartbeat timer - del_timer(&dev->hb_timer); - } + dev->reported_size = 0; + /* quickly fail all requests */ + dnbd3_blk_fail_all_requests(dev); + replace_main_socket(dev, NULL, NULL, 0); - // kill sending and receiving threads - stop_worker_thread(dev, &dev->thread_send, "send", 0); - stop_worker_thread(dev, &dev->thread_receive, "receive", 0); - stop_worker_thread(dev, &dev->thread_discover, "discover", 0); - if (dev->sock) { - sock_release(dev->sock); - dev->sock = NULL; - } - spin_lock_irqsave(&dev->blk_lock, irqflags); - dev->cur_server.host.ss_family = 0; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); + cancel_delayed_work_sync(&dev->keepalive_work); + cancel_delayed_work_sync(&dev->discover_work); + cancel_work_sync(&dev->send_work); + cancel_work_sync(&dev->recv_work); + dnbd3_blk_fail_all_requests(dev); + dnbd3_release_resources(dev); + dev_dbg(dnbd3_device_to_dev(dev), "all workers shut down\n"); return 0; } |