summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorSimon Rettberg2021-03-23 16:00:55 +0100
committerSimon Rettberg2021-03-23 16:00:55 +0100
commitb9c11caeb31d1066979b8554f565f24abfe475f6 (patch)
tree8a298a7122328ddd1ff2340b74d19667e3b107c2 /src
parent[SERVER] Fix compiler warning (diff)
downloaddnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.tar.gz
dnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.tar.xz
dnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.zip
[KERNEL] Synchronous add/remove of alt-servers via IOCTL
Diffstat (limited to 'src')
-rw-r--r--src/kernel/blk.c91
-rw-r--r--src/kernel/dnbd3_main.c82
-rw-r--r--src/kernel/dnbd3_main.h14
-rw-r--r--src/kernel/net.c207
4 files changed, 229 insertions, 165 deletions
diff --git a/src/kernel/blk.c b/src/kernel/blk.c
index 63a2f8c..661bc2e 100644
--- a/src/kernel/blk.c
+++ b/src/kernel/blk.c
@@ -63,9 +63,6 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
struct request_queue *blk_queue = dev->disk->queue;
char *imgname = NULL;
dnbd3_ioctl_t *msg = NULL;
- dnbd3_server_entry_t server;
- dnbd3_server_t old_server;
- dnbd3_server_t *alt_server;
unsigned long irqflags;
int i = 0;
u8 locked = 0;
@@ -197,27 +194,33 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
}
locked = 1;
if (dev->imgname == NULL) {
- result = -ENOENT;
+ result = -ENOTCONN;
} else if (msg == NULL) {
result = -EINVAL;
} else {
- /* convert host to dnbd3-server for switching */
- memcpy(&server.host, &msg->hosts[0], sizeof(server.host));
- server.failures = 0;
+ dnbd3_server_t *alt_server;
- alt_server = get_existing_server(&server, dev);
+ mutex_lock(&dev->alt_servers_lock);
+ alt_server = get_existing_server(&msg->hosts[0], dev);
if (alt_server == NULL) {
+ mutex_unlock(&dev->alt_servers_lock);
/* specified server is not known, so do not switch */
- result = -EINVAL;
+ result = -ENOENT;
} else {
/* specified server is known, so try to switch to it */
- if (!is_same_server(&dev->cur_server, alt_server)) {
- if (alt_server->host.type == HOST_IP4)
+ dnbd3_server_t new_server = *alt_server;
+
+ new_server = *alt_server;
+ mutex_unlock(&dev->alt_servers_lock);
+ if (!is_same_server(&dev->cur_server, &new_server)) {
+ dnbd3_server_t old_server;
+
+ if (new_server.host.type == HOST_IP4)
dev_info(dnbd3_device_to_dev(dev), "manual server switch to %pI4\n",
- alt_server->host.addr);
+ new_server.host.addr);
else
dev_info(dnbd3_device_to_dev(dev), "manual server switch to [%pI6]\n",
- alt_server->host.addr);
+ new_server.host.addr);
/* save current working server */
/* lock device to get consistent copy of current working server */
spin_lock_irqsave(&dev->blk_lock, irqflags);
@@ -228,7 +231,7 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
dnbd3_net_disconnect(dev);
/* connect to new specified server (switching) */
- memcpy(&dev->cur_server, alt_server, sizeof(dev->cur_server));
+ memcpy(&dev->cur_server, &new_server, sizeof(dev->cur_server));
result = dnbd3_net_connect(dev);
if (result != 0) {
/* reconnect with old server if switching has failed */
@@ -255,24 +258,50 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
case IOCTL_ADD_SRV:
case IOCTL_REM_SRV:
- if (dev->imgname == NULL)
- result = -ENOENT;
- else if (msg == NULL)
+ if (dev->imgname == NULL) {
+ result = -ENOTCONN;
+ break;
+ }
+ if (msg == NULL) {
result = -EINVAL;
-
- /* protect access to 'new_servers_num' and 'new_servers' */
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- if (dev->new_servers_num >= NUMBER_SERVERS) {
- result = -EAGAIN;
- } else {
- /* add or remove specified server */
- memcpy(&dev->new_servers[dev->new_servers_num].host, &msg->hosts[0], sizeof(msg->hosts[0]));
- dev->new_servers[dev->new_servers_num].failures =
- (cmd == IOCTL_ADD_SRV ? 0 : 1); // 0 = ADD, 1 = REM
- ++dev->new_servers_num;
- result = 0;
+ break;
+ }
+ if (cmd == IOCTL_ADD_SRV) {
+ dnbd3_host_t *host = &msg->hosts[0];
+
+ result = dnbd3_add_server(dev, host);
+ if (result == -EEXIST) {
+ // Exists
+ if (host->type == HOST_IP4) {
+ dev_info(dnbd3_device_to_dev(dev), "alt server %pI4 already exists\n",
+ host->addr);
+ } else {
+ dev_info(dnbd3_device_to_dev(dev), "alt server [%pI6] already exists\n",
+ host->addr);
+ }
+ } else if (result == -ENOSPC) {
+ if (host->type == HOST_IP4) {
+ dev_info(dnbd3_device_to_dev(dev), "cannot add %pI4; no free slot\n",
+ host->addr);
+ } else {
+ dev_info(dnbd3_device_to_dev(dev), "cannot add [%pI6]; no free slot\n",
+ host->addr);
+ }
+ }
+ } else { // IOCTL_REM_SRV
+ dnbd3_host_t *host = &msg->hosts[0];
+
+ result = dnbd3_rem_server(dev, &msg->hosts[0]);
+ if (result == -ENOENT) {
+ if (host->type == HOST_IP4) {
+ dev_info(dnbd3_device_to_dev(dev), "alt server %pI4 not found\n",
+ host->addr);
+ } else {
+ dev_info(dnbd3_device_to_dev(dev), "alt server [%pI6] not found\n",
+ host->addr);
+ }
+ }
}
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
break;
case BLKFLSBUF:
@@ -343,6 +372,7 @@ int dnbd3_blk_add_device(dnbd3_device_t *dev, int minor)
dev->imgname = NULL;
dev->rid = 0;
dev->update_available = 0;
+ mutex_init(&dev->alt_servers_lock);
memset(dev->alt_servers, 0, sizeof(dev->alt_servers[0]) * NUMBER_SERVERS);
dev->thread_send = NULL;
dev->thread_receive = NULL;
@@ -433,6 +463,7 @@ int dnbd3_blk_del_device(dnbd3_device_t *dev)
del_gendisk(dev->disk);
blk_cleanup_queue(dev->queue);
blk_mq_free_tag_set(&dev->tag_set);
+ mutex_destroy(&dev->alt_servers_lock);
put_disk(dev->disk);
return 0;
}
diff --git a/src/kernel/dnbd3_main.c b/src/kernel/dnbd3_main.c
index 6e5b4a7..17d553e 100644
--- a/src/kernel/dnbd3_main.c
+++ b/src/kernel/dnbd3_main.c
@@ -41,21 +41,93 @@ int is_same_server(const dnbd3_server_t *const a, const dnbd3_server_t *const b)
(0 == memcmp(a->host.addr, b->host.addr, (a->host.type == HOST_IP4 ? 4 : 16)));
}
-dnbd3_server_t *get_existing_server(const dnbd3_server_entry_t *const newserver, dnbd3_device_t *const dev)
+/**
+ * Get a free slot pointer from the alt_servers list. Tries to find an
+ * entirely empty slot first, then looks for a slot with a server that
+ * wasn't reachable recently, finally returns NULL if none of the
+ * conditions match.
+ * The caller has to hold dev->alt_servers_lock.
+ */
+static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t *const dev)
{
int i;
for (i = 0; i < NUMBER_SERVERS; ++i) {
- if ((newserver->host.type == dev->alt_servers[i].host.type) &&
- (newserver->host.port == dev->alt_servers[i].host.port) &&
- (0 == memcmp(newserver->host.addr, dev->alt_servers[i].host.addr,
- (newserver->host.type == HOST_IP4 ? 4 : 16)))) {
+ if (dev->alt_servers[i].host.type == 0)
+ return &dev->alt_servers[i];
+ }
+ for (i = 0; i < NUMBER_SERVERS; ++i) {
+ if (dev->alt_servers[i].failures > 10)
+ return &dev->alt_servers[i];
+ }
+ return NULL;
+}
+
+/**
+ * Returns pointer to existing entry in alt_servers that matches the given
+ * alt server, or NULL if not found.
+ */
+dnbd3_server_t *get_existing_server(const dnbd3_host_t *const newserver, dnbd3_device_t *const dev)
+{
+ int i;
+
+ for (i = 0; i < NUMBER_SERVERS; ++i) {
+ if ((newserver->type == dev->alt_servers[i].host.type) &&
+ (newserver->port == dev->alt_servers[i].host.port) &&
+ (0 == memcmp(newserver->addr, dev->alt_servers[i].host.addr,
+ (newserver->type == HOST_IP4 ? 4 : 16)))) {
return &dev->alt_servers[i];
}
}
return NULL;
}
+int dnbd3_add_server(dnbd3_device_t *dev, dnbd3_host_t *host)
+{
+ int result;
+ dnbd3_server_t *alt_server;
+ /* protect access to 'alt_servers' */
+ mutex_lock(&dev->alt_servers_lock);
+ alt_server = get_existing_server(host, dev);
+ // ADD
+ if (alt_server != NULL) {
+ // Exists
+ result = -EEXIST;
+ } else {
+ // OK add
+ alt_server = get_free_alt_server(dev);
+ if (alt_server == NULL) {
+ result = -ENOSPC;
+ } else {
+ alt_server->host = *host;
+ alt_server->failures = 0;
+ result = 0;
+ }
+ }
+ mutex_unlock(&dev->alt_servers_lock);
+ return result;
+}
+
+int dnbd3_rem_server(dnbd3_device_t *dev, dnbd3_host_t *host)
+{
+ dnbd3_server_t *alt_server;
+ int result;
+ /* protect access to 'alt_servers' */
+ mutex_lock(&dev->alt_servers_lock);
+ alt_server = get_existing_server(host, dev);
+ // REMOVE
+ if (alt_server == NULL) {
+ // Not found
+ result = -ENOENT;
+ } else {
+ // Remove
+ alt_server->host.type = 0;
+ result = 0;
+ }
+ mutex_unlock(&dev->alt_servers_lock);
+ return result;
+}
+
static int __init dnbd3_init(void)
{
int i;
diff --git a/src/kernel/dnbd3_main.h b/src/kernel/dnbd3_main.h
index c5b0930..a69d588 100644
--- a/src/kernel/dnbd3_main.h
+++ b/src/kernel/dnbd3_main.h
@@ -27,6 +27,7 @@
#include <linux/module.h>
#include <linux/blkdev.h>
#include <linux/blk-mq.h>
+#include <linux/mutex.h>
#include <net/sock.h>
#include <dnbd3/config.h>
@@ -53,14 +54,13 @@ typedef struct {
struct kobject kobj;
// network
+ struct mutex alt_servers_lock;
char *imgname;
struct socket *sock;
dnbd3_server_t cur_server;
unsigned long cur_rtt;
serialized_buffer_t payload_buffer;
- dnbd3_server_t alt_servers[NUMBER_SERVERS]; // array of alt servers
- int new_servers_num; // number of new alt servers that are waiting to be copied to above array
- dnbd3_server_entry_t new_servers[NUMBER_SERVERS]; // pending new alt servers
+ dnbd3_server_t alt_servers[NUMBER_SERVERS]; // array of alt servers, protected by altservers_lock
uint8_t discover, panic, update_available, panic_count;
atomic_t connection_lock;
uint8_t use_server_provided_alts;
@@ -85,7 +85,11 @@ typedef struct {
extern inline struct device *dnbd3_device_to_dev(dnbd3_device_t *dev);
extern inline int is_same_server(const dnbd3_server_t *const a, const dnbd3_server_t *const b);
-extern inline dnbd3_server_t *get_existing_server(const dnbd3_server_entry_t *const newserver,
- dnbd3_device_t *const dev);
+
+extern dnbd3_server_t *get_existing_server(const dnbd3_host_t *const newserver, dnbd3_device_t *const dev);
+
+extern int dnbd3_add_server(dnbd3_device_t *dev, dnbd3_host_t *host);
+
+extern int dnbd3_rem_server(dnbd3_device_t *dev, dnbd3_host_t *host);
#endif /* DNBD_H_ */
diff --git a/src/kernel/net.c b/src/kernel/net.c
index e62a9df..6d821fc 100644
--- a/src/kernel/net.c
+++ b/src/kernel/net.c
@@ -100,21 +100,6 @@
static struct socket *dnbd3_connect(dnbd3_device_t *dev, dnbd3_host_t *host);
-static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t *const dev)
-{
- int i;
-
- for (i = 0; i < NUMBER_SERVERS; ++i) {
- if (dev->alt_servers[i].host.type == 0)
- return &dev->alt_servers[i];
- }
- for (i = 0; i < NUMBER_SERVERS; ++i) {
- if (dev->alt_servers[i].failures > 10)
- return &dev->alt_servers[i];
- }
- return NULL;
-}
-
static void dnbd3_net_heartbeat(struct timer_list *arg)
{
dnbd3_device_t *dev = (dnbd3_device_t *)container_of(arg, dnbd3_device_t, hb_timer);
@@ -168,7 +153,7 @@ static int dnbd3_net_discover(void *data)
dnbd3_request_t dnbd3_request;
dnbd3_reply_t dnbd3_reply;
- dnbd3_server_t *alt_server;
+ dnbd3_server_t host_compare, best_server;
struct msghdr msg;
struct kvec iov[2];
@@ -176,11 +161,12 @@ static int dnbd3_net_discover(void *data)
serialized_buffer_t *payload;
uint64_t filesize;
uint16_t rid;
+ uint16_t remote_version;
ktime_t start = 0, end = 0;
unsigned long rtt, best_rtt = 0;
unsigned long irqflags;
- int i, j, isize, best_server, current_server;
+ int i, j, isize;
int turn = 0;
int ready = 0, do_change = 0;
char check_order[NUMBER_SERVERS];
@@ -218,53 +204,7 @@ static int dnbd3_net_discover(void *data)
if (dev->reported_size < 4096)
continue;
- // Check if the list of alt servers needs to be updated and do so if necessary
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- if (dev->new_servers_num) {
- for (i = 0; i < dev->new_servers_num; ++i) {
- if (dev->new_servers[i].host.type != HOST_IP4 &&
- dev->new_servers[i].host.type != HOST_IP6) // Invalid entry?
- continue;
- alt_server = get_existing_server(&dev->new_servers[i], dev);
- if (alt_server != NULL) {
- // Server already known
- if (dev->new_servers[i].failures == 1) {
- // REMOVE request
- if (alt_server->host.type == HOST_IP4)
- dnbd3_dev_dbg_host_cur(dev, "removing alt server %pI4\n",
- alt_server->host.addr);
- else
- dnbd3_dev_dbg_host_cur(dev, "removing alt server [%pI6]\n",
- alt_server->host.addr);
- alt_server->host.type = 0;
- continue;
- }
- // ADD, so just reset fail counter
- alt_server->failures = 0;
- continue;
- }
- if (dev->new_servers[i].failures == 1) // REMOVE, but server is not in list anyways
- continue;
- alt_server = get_free_alt_server(dev);
- if (alt_server == NULL) // All NUMBER_SERVERS slots are taken, ignore entry
- continue;
- // Add new server entry
- alt_server->host = dev->new_servers[i].host;
- if (alt_server->host.type == HOST_IP4)
- dnbd3_dev_dbg_host_cur(dev, "adding alt server %pI4\n", alt_server->host.addr);
- else
- dnbd3_dev_dbg_host_cur(dev, "adding alt server [%pI6]\n",
- alt_server->host.addr);
- alt_server->rtts[0] = alt_server->rtts[1] = alt_server->rtts[2] = alt_server->rtts[3] =
- RTT_UNREACHABLE;
- alt_server->protocol_version = 0;
- alt_server->failures = 0;
- }
- dev->new_servers_num = 0;
- }
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
-
- current_server = best_server = -1;
+ best_server.host.type = 0;
best_rtt = 0xFFFFFFFul;
if (dev->heartbeat_count < STARTUP_MODE_DURATION || dev->panic)
@@ -285,17 +225,19 @@ static int dnbd3_net_discover(void *data)
for (j = 0; j < NUMBER_SERVERS; ++j) {
i = check_order[j];
- if (dev->alt_servers[i].host.type == 0) // Empty slot
- continue;
- if (!dev->panic && dev->alt_servers[i].failures > 50 &&
- (ktime_to_us(start) & 7) !=
- 0) // If not in panic mode, skip server if it failed too many times
- continue;
- if (isize-- <= 0 && !is_same_server(&dev->cur_server, &dev->alt_servers[i]))
- continue;
+ mutex_lock(&dev->alt_servers_lock);
+ host_compare = dev->alt_servers[i];
+ mutex_unlock(&dev->alt_servers_lock);
+ if (host_compare.host.type == 0)
+ continue; // Empty slot
+ if (!dev->panic && host_compare.failures > 50
+ && (ktime_to_us(start) & 7) != 0)
+ continue; // If not in panic mode, skip server if it failed too many times
+ if (isize-- <= 0 && !is_same_server(&dev->cur_server, &host_compare))
+ continue; // Only test isize servers plus current server
// Initialize socket and connect
- sock = dnbd3_connect(dev, &dev->alt_servers[i].host);
+ sock = dnbd3_connect(dev, &host_compare.host);
if (sock == NULL) {
dnbd3_dev_dbg_host_alt(dev, "%s: Couldn't connect\n", __func__);
goto error;
@@ -345,11 +287,11 @@ static int dnbd3_net_discover(void *data)
}
serializer_reset_read(payload, dnbd3_reply.size);
- dev->alt_servers[i].protocol_version = serializer_get_uint16(payload);
- if (dev->alt_servers[i].protocol_version < MIN_SUPPORTED_SERVER) {
+ remote_version = serializer_get_uint16(payload);
+ if (remote_version < MIN_SUPPORTED_SERVER) {
dnbd3_dev_err_host_alt(
dev, "server version too old (client: %d, server: %d, min supported: %d)\n",
- (int)PROTOCOL_VERSION, (int)dev->alt_servers[i].protocol_version,
+ (int)PROTOCOL_VERSION, (int)remote_version,
(int)MIN_SUPPORTED_SERVER);
goto error;
}
@@ -402,7 +344,9 @@ static int dnbd3_net_discover(void *data)
put_task_struct(dev->thread_discover);
dev->thread_discover = NULL;
dnbd3_net_disconnect(dev);
- memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server));
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ dev->cur_server = host_compare;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
dnbd3_net_connect(dev);
atomic_set(&dev->connection_lock, 0);
return 0;
@@ -455,16 +399,22 @@ static int dnbd3_net_discover(void *data)
end = ktime_get_real(); // end rtt measurement
- dev->alt_servers[i].rtts[turn] = (unsigned long)ktime_us_delta(end, start);
+ mutex_lock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->alt_servers[i], &host_compare)) {
+ dev->alt_servers[i].protocol_version = remote_version;
+ dev->alt_servers[i].rtts[turn] = (unsigned long)ktime_us_delta(end, start);
- rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1] + dev->alt_servers[i].rtts[2] +
- dev->alt_servers[i].rtts[3]) /
- 4;
+ rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1]
+ + dev->alt_servers[i].rtts[2] + dev->alt_servers[i].rtts[3])
+ / 4;
+ }
+ dev->alt_servers[i].failures = 0;
+ mutex_unlock(&dev->alt_servers_lock);
if (best_rtt > rtt) {
// This one is better, keep socket open in case we switch
best_rtt = rtt;
- best_server = i;
+ best_server = host_compare;
if (best_sock != NULL)
sock_release(best_sock);
best_sock = sock;
@@ -476,37 +426,36 @@ static int dnbd3_net_discover(void *data)
}
// update cur servers rtt
- if (is_same_server(&dev->cur_server, &dev->alt_servers[i])) {
+ if (is_same_server(&dev->cur_server, &host_compare))
dev->cur_rtt = rtt;
- current_server = i;
- }
-
- dev->alt_servers[i].failures = 0;
continue;
error:
- ++dev->alt_servers[i].failures;
if (sock != NULL) {
sock_release(sock);
sock = NULL;
}
- dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE;
- if (is_same_server(&dev->cur_server, &dev->alt_servers[i])) {
- dev->cur_rtt = RTT_UNREACHABLE;
- current_server = i;
+ mutex_lock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->alt_servers[i], &host_compare)) {
+ ++dev->alt_servers[i].failures;
+ dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE;
}
+ mutex_unlock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->cur_server, &host_compare))
+ dev->cur_rtt = RTT_UNREACHABLE;
+
continue;
}
if (dev->panic) {
- // After 21 retries, bail out by reporting errors to block layer
+ // If probe timeout is set, report error to block layer
if (PROBE_COUNT_TIMEOUT > 0 && dev->panic_count < 255 &&
++dev->panic_count == PROBE_COUNT_TIMEOUT + 1)
dnbd3_blk_fail_all_requests(dev);
}
- if (best_server == -1 || kthread_should_stop() || dev->thread_discover == NULL) {
+ if (best_server.host.type == 0 || kthread_should_stop() || dev->thread_discover == NULL) {
// No alt server could be reached at all or thread should stop
if (best_sock != NULL) {
// Should never happen actually
@@ -516,8 +465,8 @@ error:
continue;
}
- do_change = ready && best_server != current_server && (ktime_to_us(start) & 3) != 0 &&
- RTT_THRESHOLD_FACTOR(dev->cur_rtt) > best_rtt + 1500;
+ do_change = ready && !is_same_server(&best_server, &dev->cur_server) && (ktime_to_us(start) & 3) != 0
+ && RTT_THRESHOLD_FACTOR(dev->cur_rtt) > best_rtt + 1500;
if (ready && !do_change) {
spin_lock_irqsave(&dev->blk_lock, irqflags);
@@ -536,14 +485,16 @@ error:
// take server with lowest rtt
// if a (dis)connect is already in progress, we do nothing, this is not panic mode
if (do_change && atomic_cmpxchg(&dev->connection_lock, 0, 1) == 0) {
- dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", best_server,
+ dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", -1, // XXX
(unsigned long long)best_rtt, (unsigned long long)dev->cur_rtt);
kfree(buf);
dev->better_sock = best_sock; // Take shortcut by continuing to use open connection
put_task_struct(dev->thread_discover);
dev->thread_discover = NULL;
dnbd3_net_disconnect(dev);
- memcpy(&dev->cur_server, &dev->alt_servers[best_server], sizeof(dev->cur_server));
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ dev->cur_server = best_server;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
dev->cur_rtt = best_rtt;
dnbd3_net_connect(dev);
atomic_set(&dev->connection_lock, 0);
@@ -810,32 +761,30 @@ static int dnbd3_net_receive(void *data)
continue;
case CMD_GET_SERVERS:
- if (!dev->use_server_provided_alts) {
- remaining = dnbd3_reply.size;
- goto consume_payload;
- }
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- dev->new_servers_num = 0;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- count = MIN(NUMBER_SERVERS, dnbd3_reply.size / sizeof(dnbd3_server_entry_t));
-
- if (count != 0) {
- iov.iov_base = dev->new_servers;
- iov.iov_len = count * sizeof(dnbd3_server_entry_t);
- if (kernel_recvmsg(dev->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)),
- msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) {
- if (!atomic_read(&dev->connection_lock))
- dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n");
- ret = -EINVAL;
- goto cleanup;
+ remaining = dnbd3_reply.size;
+ if (dev->use_server_provided_alts) {
+ dnbd3_server_entry_t new_server;
+
+ while (remaining >= sizeof(dnbd3_server_entry_t)) {
+ iov.iov_base = &new_server;
+ iov.iov_len = sizeof(dnbd3_server_entry_t);
+ if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len,
+ msg.msg_flags) != sizeof(dnbd3_server_entry_t)) {
+ if (!atomic_read(&dev->connection_lock))
+ dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n");
+ ret = -EINVAL;
+ goto cleanup;
+ }
+ // TODO: Log
+ if (new_server.failures == 0) { // ADD
+ dnbd3_add_server(dev, &new_server.host);
+ } else { // REM
+ dnbd3_rem_server(dev, &new_server.host);
+ }
+ remaining -= sizeof(dnbd3_server_entry_t);
}
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- dev->new_servers_num = count;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
}
- // If there were more servers than accepted, remove the remaining data from the socket buffer
- remaining = dnbd3_reply.size - (count * sizeof(dnbd3_server_entry_t));
-consume_payload:
+ // Drain any payload still on the wire
while (remaining > 0) {
count = MIN(sizeof(dnbd3_reply),
remaining); // Abuse the reply struct as the receive buffer
@@ -1027,7 +976,7 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
dnbd3_reply_t dnbd3_reply;
struct msghdr msg;
struct kvec iov[2];
- uint16_t rid;
+ uint16_t rid, proto_version;
char *name;
int mlen;
@@ -1082,8 +1031,11 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
}
// handle/check reply payload
serializer_reset_read(&dev->payload_buffer, dnbd3_reply.size);
- dev->cur_server.protocol_version = serializer_get_uint16(&dev->payload_buffer);
- if (dev->cur_server.protocol_version < MIN_SUPPORTED_SERVER) {
+ proto_version = serializer_get_uint16(&dev->payload_buffer);
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ dev->cur_server.protocol_version = proto_version;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ if (proto_version < MIN_SUPPORTED_SERVER) {
dnbd3_dev_err_host_cur(dev, "server version is lower than min supported version\n");
goto error;
}
@@ -1205,8 +1157,10 @@ error:
sock_release(dev->sock);
dev->sock = NULL;
}
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
dev->cur_server.host.type = 0;
dev->cur_server.host.port = 0;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
kfree(req1);
return -1;
@@ -1215,6 +1169,7 @@ error:
int dnbd3_net_disconnect(dnbd3_device_t *dev)
{
struct task_struct *thread = NULL;
+ unsigned long irqflags;
int ret;
dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n");
@@ -1283,8 +1238,10 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
sock_release(dev->sock);
dev->sock = NULL;
}
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
dev->cur_server.host.type = 0;
dev->cur_server.host.port = 0;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
return 0;
}