diff options
author | Simon Rettberg | 2021-03-23 16:00:55 +0100 |
---|---|---|
committer | Simon Rettberg | 2021-03-23 16:00:55 +0100 |
commit | b9c11caeb31d1066979b8554f565f24abfe475f6 (patch) | |
tree | 8a298a7122328ddd1ff2340b74d19667e3b107c2 /src | |
parent | [SERVER] Fix compiler warning (diff) | |
download | dnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.tar.gz dnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.tar.xz dnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.zip |
[KERNEL] Synchronous add/remove of alt-servers via IOCTL
Diffstat (limited to 'src')
-rw-r--r-- | src/kernel/blk.c | 91 | ||||
-rw-r--r-- | src/kernel/dnbd3_main.c | 82 | ||||
-rw-r--r-- | src/kernel/dnbd3_main.h | 14 | ||||
-rw-r--r-- | src/kernel/net.c | 207 |
4 files changed, 229 insertions, 165 deletions
diff --git a/src/kernel/blk.c b/src/kernel/blk.c index 63a2f8c..661bc2e 100644 --- a/src/kernel/blk.c +++ b/src/kernel/blk.c @@ -63,9 +63,6 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int struct request_queue *blk_queue = dev->disk->queue; char *imgname = NULL; dnbd3_ioctl_t *msg = NULL; - dnbd3_server_entry_t server; - dnbd3_server_t old_server; - dnbd3_server_t *alt_server; unsigned long irqflags; int i = 0; u8 locked = 0; @@ -197,27 +194,33 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int } locked = 1; if (dev->imgname == NULL) { - result = -ENOENT; + result = -ENOTCONN; } else if (msg == NULL) { result = -EINVAL; } else { - /* convert host to dnbd3-server for switching */ - memcpy(&server.host, &msg->hosts[0], sizeof(server.host)); - server.failures = 0; + dnbd3_server_t *alt_server; - alt_server = get_existing_server(&server, dev); + mutex_lock(&dev->alt_servers_lock); + alt_server = get_existing_server(&msg->hosts[0], dev); if (alt_server == NULL) { + mutex_unlock(&dev->alt_servers_lock); /* specified server is not known, so do not switch */ - result = -EINVAL; + result = -ENOENT; } else { /* specified server is known, so try to switch to it */ - if (!is_same_server(&dev->cur_server, alt_server)) { - if (alt_server->host.type == HOST_IP4) + dnbd3_server_t new_server = *alt_server; + + new_server = *alt_server; + mutex_unlock(&dev->alt_servers_lock); + if (!is_same_server(&dev->cur_server, &new_server)) { + dnbd3_server_t old_server; + + if (new_server.host.type == HOST_IP4) dev_info(dnbd3_device_to_dev(dev), "manual server switch to %pI4\n", - alt_server->host.addr); + new_server.host.addr); else dev_info(dnbd3_device_to_dev(dev), "manual server switch to [%pI6]\n", - alt_server->host.addr); + new_server.host.addr); /* save current working server */ /* lock device to get consistent copy of current working server */ spin_lock_irqsave(&dev->blk_lock, irqflags); @@ -228,7 +231,7 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int dnbd3_net_disconnect(dev); /* connect to new specified server (switching) */ - memcpy(&dev->cur_server, alt_server, sizeof(dev->cur_server)); + memcpy(&dev->cur_server, &new_server, sizeof(dev->cur_server)); result = dnbd3_net_connect(dev); if (result != 0) { /* reconnect with old server if switching has failed */ @@ -255,24 +258,50 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int case IOCTL_ADD_SRV: case IOCTL_REM_SRV: - if (dev->imgname == NULL) - result = -ENOENT; - else if (msg == NULL) + if (dev->imgname == NULL) { + result = -ENOTCONN; + break; + } + if (msg == NULL) { result = -EINVAL; - - /* protect access to 'new_servers_num' and 'new_servers' */ - spin_lock_irqsave(&dev->blk_lock, irqflags); - if (dev->new_servers_num >= NUMBER_SERVERS) { - result = -EAGAIN; - } else { - /* add or remove specified server */ - memcpy(&dev->new_servers[dev->new_servers_num].host, &msg->hosts[0], sizeof(msg->hosts[0])); - dev->new_servers[dev->new_servers_num].failures = - (cmd == IOCTL_ADD_SRV ? 0 : 1); // 0 = ADD, 1 = REM - ++dev->new_servers_num; - result = 0; + break; + } + if (cmd == IOCTL_ADD_SRV) { + dnbd3_host_t *host = &msg->hosts[0]; + + result = dnbd3_add_server(dev, host); + if (result == -EEXIST) { + // Exists + if (host->type == HOST_IP4) { + dev_info(dnbd3_device_to_dev(dev), "alt server %pI4 already exists\n", + host->addr); + } else { + dev_info(dnbd3_device_to_dev(dev), "alt server [%pI6] already exists\n", + host->addr); + } + } else if (result == -ENOSPC) { + if (host->type == HOST_IP4) { + dev_info(dnbd3_device_to_dev(dev), "cannot add %pI4; no free slot\n", + host->addr); + } else { + dev_info(dnbd3_device_to_dev(dev), "cannot add [%pI6]; no free slot\n", + host->addr); + } + } + } else { // IOCTL_REM_SRV + dnbd3_host_t *host = &msg->hosts[0]; + + result = dnbd3_rem_server(dev, &msg->hosts[0]); + if (result == -ENOENT) { + if (host->type == HOST_IP4) { + dev_info(dnbd3_device_to_dev(dev), "alt server %pI4 not found\n", + host->addr); + } else { + dev_info(dnbd3_device_to_dev(dev), "alt server [%pI6] not found\n", + host->addr); + } + } } - spin_unlock_irqrestore(&dev->blk_lock, irqflags); break; case BLKFLSBUF: @@ -343,6 +372,7 @@ int dnbd3_blk_add_device(dnbd3_device_t *dev, int minor) dev->imgname = NULL; dev->rid = 0; dev->update_available = 0; + mutex_init(&dev->alt_servers_lock); memset(dev->alt_servers, 0, sizeof(dev->alt_servers[0]) * NUMBER_SERVERS); dev->thread_send = NULL; dev->thread_receive = NULL; @@ -433,6 +463,7 @@ int dnbd3_blk_del_device(dnbd3_device_t *dev) del_gendisk(dev->disk); blk_cleanup_queue(dev->queue); blk_mq_free_tag_set(&dev->tag_set); + mutex_destroy(&dev->alt_servers_lock); put_disk(dev->disk); return 0; } diff --git a/src/kernel/dnbd3_main.c b/src/kernel/dnbd3_main.c index 6e5b4a7..17d553e 100644 --- a/src/kernel/dnbd3_main.c +++ b/src/kernel/dnbd3_main.c @@ -41,21 +41,93 @@ int is_same_server(const dnbd3_server_t *const a, const dnbd3_server_t *const b) (0 == memcmp(a->host.addr, b->host.addr, (a->host.type == HOST_IP4 ? 4 : 16))); } -dnbd3_server_t *get_existing_server(const dnbd3_server_entry_t *const newserver, dnbd3_device_t *const dev) +/** + * Get a free slot pointer from the alt_servers list. Tries to find an + * entirely empty slot first, then looks for a slot with a server that + * wasn't reachable recently, finally returns NULL if none of the + * conditions match. + * The caller has to hold dev->alt_servers_lock. + */ +static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t *const dev) { int i; for (i = 0; i < NUMBER_SERVERS; ++i) { - if ((newserver->host.type == dev->alt_servers[i].host.type) && - (newserver->host.port == dev->alt_servers[i].host.port) && - (0 == memcmp(newserver->host.addr, dev->alt_servers[i].host.addr, - (newserver->host.type == HOST_IP4 ? 4 : 16)))) { + if (dev->alt_servers[i].host.type == 0) + return &dev->alt_servers[i]; + } + for (i = 0; i < NUMBER_SERVERS; ++i) { + if (dev->alt_servers[i].failures > 10) + return &dev->alt_servers[i]; + } + return NULL; +} + +/** + * Returns pointer to existing entry in alt_servers that matches the given + * alt server, or NULL if not found. + */ +dnbd3_server_t *get_existing_server(const dnbd3_host_t *const newserver, dnbd3_device_t *const dev) +{ + int i; + + for (i = 0; i < NUMBER_SERVERS; ++i) { + if ((newserver->type == dev->alt_servers[i].host.type) && + (newserver->port == dev->alt_servers[i].host.port) && + (0 == memcmp(newserver->addr, dev->alt_servers[i].host.addr, + (newserver->type == HOST_IP4 ? 4 : 16)))) { return &dev->alt_servers[i]; } } return NULL; } +int dnbd3_add_server(dnbd3_device_t *dev, dnbd3_host_t *host) +{ + int result; + dnbd3_server_t *alt_server; + /* protect access to 'alt_servers' */ + mutex_lock(&dev->alt_servers_lock); + alt_server = get_existing_server(host, dev); + // ADD + if (alt_server != NULL) { + // Exists + result = -EEXIST; + } else { + // OK add + alt_server = get_free_alt_server(dev); + if (alt_server == NULL) { + result = -ENOSPC; + } else { + alt_server->host = *host; + alt_server->failures = 0; + result = 0; + } + } + mutex_unlock(&dev->alt_servers_lock); + return result; +} + +int dnbd3_rem_server(dnbd3_device_t *dev, dnbd3_host_t *host) +{ + dnbd3_server_t *alt_server; + int result; + /* protect access to 'alt_servers' */ + mutex_lock(&dev->alt_servers_lock); + alt_server = get_existing_server(host, dev); + // REMOVE + if (alt_server == NULL) { + // Not found + result = -ENOENT; + } else { + // Remove + alt_server->host.type = 0; + result = 0; + } + mutex_unlock(&dev->alt_servers_lock); + return result; +} + static int __init dnbd3_init(void) { int i; diff --git a/src/kernel/dnbd3_main.h b/src/kernel/dnbd3_main.h index c5b0930..a69d588 100644 --- a/src/kernel/dnbd3_main.h +++ b/src/kernel/dnbd3_main.h @@ -27,6 +27,7 @@ #include <linux/module.h> #include <linux/blkdev.h> #include <linux/blk-mq.h> +#include <linux/mutex.h> #include <net/sock.h> #include <dnbd3/config.h> @@ -53,14 +54,13 @@ typedef struct { struct kobject kobj; // network + struct mutex alt_servers_lock; char *imgname; struct socket *sock; dnbd3_server_t cur_server; unsigned long cur_rtt; serialized_buffer_t payload_buffer; - dnbd3_server_t alt_servers[NUMBER_SERVERS]; // array of alt servers - int new_servers_num; // number of new alt servers that are waiting to be copied to above array - dnbd3_server_entry_t new_servers[NUMBER_SERVERS]; // pending new alt servers + dnbd3_server_t alt_servers[NUMBER_SERVERS]; // array of alt servers, protected by altservers_lock uint8_t discover, panic, update_available, panic_count; atomic_t connection_lock; uint8_t use_server_provided_alts; @@ -85,7 +85,11 @@ typedef struct { extern inline struct device *dnbd3_device_to_dev(dnbd3_device_t *dev); extern inline int is_same_server(const dnbd3_server_t *const a, const dnbd3_server_t *const b); -extern inline dnbd3_server_t *get_existing_server(const dnbd3_server_entry_t *const newserver, - dnbd3_device_t *const dev); + +extern dnbd3_server_t *get_existing_server(const dnbd3_host_t *const newserver, dnbd3_device_t *const dev); + +extern int dnbd3_add_server(dnbd3_device_t *dev, dnbd3_host_t *host); + +extern int dnbd3_rem_server(dnbd3_device_t *dev, dnbd3_host_t *host); #endif /* DNBD_H_ */ diff --git a/src/kernel/net.c b/src/kernel/net.c index e62a9df..6d821fc 100644 --- a/src/kernel/net.c +++ b/src/kernel/net.c @@ -100,21 +100,6 @@ static struct socket *dnbd3_connect(dnbd3_device_t *dev, dnbd3_host_t *host); -static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t *const dev) -{ - int i; - - for (i = 0; i < NUMBER_SERVERS; ++i) { - if (dev->alt_servers[i].host.type == 0) - return &dev->alt_servers[i]; - } - for (i = 0; i < NUMBER_SERVERS; ++i) { - if (dev->alt_servers[i].failures > 10) - return &dev->alt_servers[i]; - } - return NULL; -} - static void dnbd3_net_heartbeat(struct timer_list *arg) { dnbd3_device_t *dev = (dnbd3_device_t *)container_of(arg, dnbd3_device_t, hb_timer); @@ -168,7 +153,7 @@ static int dnbd3_net_discover(void *data) dnbd3_request_t dnbd3_request; dnbd3_reply_t dnbd3_reply; - dnbd3_server_t *alt_server; + dnbd3_server_t host_compare, best_server; struct msghdr msg; struct kvec iov[2]; @@ -176,11 +161,12 @@ static int dnbd3_net_discover(void *data) serialized_buffer_t *payload; uint64_t filesize; uint16_t rid; + uint16_t remote_version; ktime_t start = 0, end = 0; unsigned long rtt, best_rtt = 0; unsigned long irqflags; - int i, j, isize, best_server, current_server; + int i, j, isize; int turn = 0; int ready = 0, do_change = 0; char check_order[NUMBER_SERVERS]; @@ -218,53 +204,7 @@ static int dnbd3_net_discover(void *data) if (dev->reported_size < 4096) continue; - // Check if the list of alt servers needs to be updated and do so if necessary - spin_lock_irqsave(&dev->blk_lock, irqflags); - if (dev->new_servers_num) { - for (i = 0; i < dev->new_servers_num; ++i) { - if (dev->new_servers[i].host.type != HOST_IP4 && - dev->new_servers[i].host.type != HOST_IP6) // Invalid entry? - continue; - alt_server = get_existing_server(&dev->new_servers[i], dev); - if (alt_server != NULL) { - // Server already known - if (dev->new_servers[i].failures == 1) { - // REMOVE request - if (alt_server->host.type == HOST_IP4) - dnbd3_dev_dbg_host_cur(dev, "removing alt server %pI4\n", - alt_server->host.addr); - else - dnbd3_dev_dbg_host_cur(dev, "removing alt server [%pI6]\n", - alt_server->host.addr); - alt_server->host.type = 0; - continue; - } - // ADD, so just reset fail counter - alt_server->failures = 0; - continue; - } - if (dev->new_servers[i].failures == 1) // REMOVE, but server is not in list anyways - continue; - alt_server = get_free_alt_server(dev); - if (alt_server == NULL) // All NUMBER_SERVERS slots are taken, ignore entry - continue; - // Add new server entry - alt_server->host = dev->new_servers[i].host; - if (alt_server->host.type == HOST_IP4) - dnbd3_dev_dbg_host_cur(dev, "adding alt server %pI4\n", alt_server->host.addr); - else - dnbd3_dev_dbg_host_cur(dev, "adding alt server [%pI6]\n", - alt_server->host.addr); - alt_server->rtts[0] = alt_server->rtts[1] = alt_server->rtts[2] = alt_server->rtts[3] = - RTT_UNREACHABLE; - alt_server->protocol_version = 0; - alt_server->failures = 0; - } - dev->new_servers_num = 0; - } - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - - current_server = best_server = -1; + best_server.host.type = 0; best_rtt = 0xFFFFFFFul; if (dev->heartbeat_count < STARTUP_MODE_DURATION || dev->panic) @@ -285,17 +225,19 @@ static int dnbd3_net_discover(void *data) for (j = 0; j < NUMBER_SERVERS; ++j) { i = check_order[j]; - if (dev->alt_servers[i].host.type == 0) // Empty slot - continue; - if (!dev->panic && dev->alt_servers[i].failures > 50 && - (ktime_to_us(start) & 7) != - 0) // If not in panic mode, skip server if it failed too many times - continue; - if (isize-- <= 0 && !is_same_server(&dev->cur_server, &dev->alt_servers[i])) - continue; + mutex_lock(&dev->alt_servers_lock); + host_compare = dev->alt_servers[i]; + mutex_unlock(&dev->alt_servers_lock); + if (host_compare.host.type == 0) + continue; // Empty slot + if (!dev->panic && host_compare.failures > 50 + && (ktime_to_us(start) & 7) != 0) + continue; // If not in panic mode, skip server if it failed too many times + if (isize-- <= 0 && !is_same_server(&dev->cur_server, &host_compare)) + continue; // Only test isize servers plus current server // Initialize socket and connect - sock = dnbd3_connect(dev, &dev->alt_servers[i].host); + sock = dnbd3_connect(dev, &host_compare.host); if (sock == NULL) { dnbd3_dev_dbg_host_alt(dev, "%s: Couldn't connect\n", __func__); goto error; @@ -345,11 +287,11 @@ static int dnbd3_net_discover(void *data) } serializer_reset_read(payload, dnbd3_reply.size); - dev->alt_servers[i].protocol_version = serializer_get_uint16(payload); - if (dev->alt_servers[i].protocol_version < MIN_SUPPORTED_SERVER) { + remote_version = serializer_get_uint16(payload); + if (remote_version < MIN_SUPPORTED_SERVER) { dnbd3_dev_err_host_alt( dev, "server version too old (client: %d, server: %d, min supported: %d)\n", - (int)PROTOCOL_VERSION, (int)dev->alt_servers[i].protocol_version, + (int)PROTOCOL_VERSION, (int)remote_version, (int)MIN_SUPPORTED_SERVER); goto error; } @@ -402,7 +344,9 @@ static int dnbd3_net_discover(void *data) put_task_struct(dev->thread_discover); dev->thread_discover = NULL; dnbd3_net_disconnect(dev); - memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server)); + spin_lock_irqsave(&dev->blk_lock, irqflags); + dev->cur_server = host_compare; + spin_unlock_irqrestore(&dev->blk_lock, irqflags); dnbd3_net_connect(dev); atomic_set(&dev->connection_lock, 0); return 0; @@ -455,16 +399,22 @@ static int dnbd3_net_discover(void *data) end = ktime_get_real(); // end rtt measurement - dev->alt_servers[i].rtts[turn] = (unsigned long)ktime_us_delta(end, start); + mutex_lock(&dev->alt_servers_lock); + if (is_same_server(&dev->alt_servers[i], &host_compare)) { + dev->alt_servers[i].protocol_version = remote_version; + dev->alt_servers[i].rtts[turn] = (unsigned long)ktime_us_delta(end, start); - rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1] + dev->alt_servers[i].rtts[2] + - dev->alt_servers[i].rtts[3]) / - 4; + rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1] + + dev->alt_servers[i].rtts[2] + dev->alt_servers[i].rtts[3]) + / 4; + } + dev->alt_servers[i].failures = 0; + mutex_unlock(&dev->alt_servers_lock); if (best_rtt > rtt) { // This one is better, keep socket open in case we switch best_rtt = rtt; - best_server = i; + best_server = host_compare; if (best_sock != NULL) sock_release(best_sock); best_sock = sock; @@ -476,37 +426,36 @@ static int dnbd3_net_discover(void *data) } // update cur servers rtt - if (is_same_server(&dev->cur_server, &dev->alt_servers[i])) { + if (is_same_server(&dev->cur_server, &host_compare)) dev->cur_rtt = rtt; - current_server = i; - } - - dev->alt_servers[i].failures = 0; continue; error: - ++dev->alt_servers[i].failures; if (sock != NULL) { sock_release(sock); sock = NULL; } - dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE; - if (is_same_server(&dev->cur_server, &dev->alt_servers[i])) { - dev->cur_rtt = RTT_UNREACHABLE; - current_server = i; + mutex_lock(&dev->alt_servers_lock); + if (is_same_server(&dev->alt_servers[i], &host_compare)) { + ++dev->alt_servers[i].failures; + dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE; } + mutex_unlock(&dev->alt_servers_lock); + if (is_same_server(&dev->cur_server, &host_compare)) + dev->cur_rtt = RTT_UNREACHABLE; + continue; } if (dev->panic) { - // After 21 retries, bail out by reporting errors to block layer + // If probe timeout is set, report error to block layer if (PROBE_COUNT_TIMEOUT > 0 && dev->panic_count < 255 && ++dev->panic_count == PROBE_COUNT_TIMEOUT + 1) dnbd3_blk_fail_all_requests(dev); } - if (best_server == -1 || kthread_should_stop() || dev->thread_discover == NULL) { + if (best_server.host.type == 0 || kthread_should_stop() || dev->thread_discover == NULL) { // No alt server could be reached at all or thread should stop if (best_sock != NULL) { // Should never happen actually @@ -516,8 +465,8 @@ error: continue; } - do_change = ready && best_server != current_server && (ktime_to_us(start) & 3) != 0 && - RTT_THRESHOLD_FACTOR(dev->cur_rtt) > best_rtt + 1500; + do_change = ready && !is_same_server(&best_server, &dev->cur_server) && (ktime_to_us(start) & 3) != 0 + && RTT_THRESHOLD_FACTOR(dev->cur_rtt) > best_rtt + 1500; if (ready && !do_change) { spin_lock_irqsave(&dev->blk_lock, irqflags); @@ -536,14 +485,16 @@ error: // take server with lowest rtt // if a (dis)connect is already in progress, we do nothing, this is not panic mode if (do_change && atomic_cmpxchg(&dev->connection_lock, 0, 1) == 0) { - dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", best_server, + dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", -1, // XXX (unsigned long long)best_rtt, (unsigned long long)dev->cur_rtt); kfree(buf); dev->better_sock = best_sock; // Take shortcut by continuing to use open connection put_task_struct(dev->thread_discover); dev->thread_discover = NULL; dnbd3_net_disconnect(dev); - memcpy(&dev->cur_server, &dev->alt_servers[best_server], sizeof(dev->cur_server)); + spin_lock_irqsave(&dev->blk_lock, irqflags); + dev->cur_server = best_server; + spin_unlock_irqrestore(&dev->blk_lock, irqflags); dev->cur_rtt = best_rtt; dnbd3_net_connect(dev); atomic_set(&dev->connection_lock, 0); @@ -810,32 +761,30 @@ static int dnbd3_net_receive(void *data) continue; case CMD_GET_SERVERS: - if (!dev->use_server_provided_alts) { - remaining = dnbd3_reply.size; - goto consume_payload; - } - spin_lock_irqsave(&dev->blk_lock, irqflags); - dev->new_servers_num = 0; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); - count = MIN(NUMBER_SERVERS, dnbd3_reply.size / sizeof(dnbd3_server_entry_t)); - - if (count != 0) { - iov.iov_base = dev->new_servers; - iov.iov_len = count * sizeof(dnbd3_server_entry_t); - if (kernel_recvmsg(dev->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)), - msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) { - if (!atomic_read(&dev->connection_lock)) - dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n"); - ret = -EINVAL; - goto cleanup; + remaining = dnbd3_reply.size; + if (dev->use_server_provided_alts) { + dnbd3_server_entry_t new_server; + + while (remaining >= sizeof(dnbd3_server_entry_t)) { + iov.iov_base = &new_server; + iov.iov_len = sizeof(dnbd3_server_entry_t); + if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, + msg.msg_flags) != sizeof(dnbd3_server_entry_t)) { + if (!atomic_read(&dev->connection_lock)) + dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n"); + ret = -EINVAL; + goto cleanup; + } + // TODO: Log + if (new_server.failures == 0) { // ADD + dnbd3_add_server(dev, &new_server.host); + } else { // REM + dnbd3_rem_server(dev, &new_server.host); + } + remaining -= sizeof(dnbd3_server_entry_t); } - spin_lock_irqsave(&dev->blk_lock, irqflags); - dev->new_servers_num = count; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); } - // If there were more servers than accepted, remove the remaining data from the socket buffer - remaining = dnbd3_reply.size - (count * sizeof(dnbd3_server_entry_t)); -consume_payload: + // Drain any payload still on the wire while (remaining > 0) { count = MIN(sizeof(dnbd3_reply), remaining); // Abuse the reply struct as the receive buffer @@ -1027,7 +976,7 @@ int dnbd3_net_connect(dnbd3_device_t *dev) dnbd3_reply_t dnbd3_reply; struct msghdr msg; struct kvec iov[2]; - uint16_t rid; + uint16_t rid, proto_version; char *name; int mlen; @@ -1082,8 +1031,11 @@ int dnbd3_net_connect(dnbd3_device_t *dev) } // handle/check reply payload serializer_reset_read(&dev->payload_buffer, dnbd3_reply.size); - dev->cur_server.protocol_version = serializer_get_uint16(&dev->payload_buffer); - if (dev->cur_server.protocol_version < MIN_SUPPORTED_SERVER) { + proto_version = serializer_get_uint16(&dev->payload_buffer); + spin_lock_irqsave(&dev->blk_lock, irqflags); + dev->cur_server.protocol_version = proto_version; + spin_unlock_irqrestore(&dev->blk_lock, irqflags); + if (proto_version < MIN_SUPPORTED_SERVER) { dnbd3_dev_err_host_cur(dev, "server version is lower than min supported version\n"); goto error; } @@ -1205,8 +1157,10 @@ error: sock_release(dev->sock); dev->sock = NULL; } + spin_lock_irqsave(&dev->blk_lock, irqflags); dev->cur_server.host.type = 0; dev->cur_server.host.port = 0; + spin_unlock_irqrestore(&dev->blk_lock, irqflags); kfree(req1); return -1; @@ -1215,6 +1169,7 @@ error: int dnbd3_net_disconnect(dnbd3_device_t *dev) { struct task_struct *thread = NULL; + unsigned long irqflags; int ret; dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n"); @@ -1283,8 +1238,10 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) sock_release(dev->sock); dev->sock = NULL; } + spin_lock_irqsave(&dev->blk_lock, irqflags); dev->cur_server.host.type = 0; dev->cur_server.host.port = 0; + spin_unlock_irqrestore(&dev->blk_lock, irqflags); return 0; } |