summaryrefslogtreecommitdiffstats
path: root/src/kernel/net.c
diff options
context:
space:
mode:
authorSimon Rettberg2021-03-23 16:00:55 +0100
committerSimon Rettberg2021-03-23 16:00:55 +0100
commitb9c11caeb31d1066979b8554f565f24abfe475f6 (patch)
tree8a298a7122328ddd1ff2340b74d19667e3b107c2 /src/kernel/net.c
parent[SERVER] Fix compiler warning (diff)
downloaddnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.tar.gz
dnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.tar.xz
dnbd3-b9c11caeb31d1066979b8554f565f24abfe475f6.zip
[KERNEL] Synchronous add/remove of alt-servers via IOCTL
Diffstat (limited to 'src/kernel/net.c')
-rw-r--r--src/kernel/net.c207
1 files changed, 82 insertions, 125 deletions
diff --git a/src/kernel/net.c b/src/kernel/net.c
index e62a9df..6d821fc 100644
--- a/src/kernel/net.c
+++ b/src/kernel/net.c
@@ -100,21 +100,6 @@
static struct socket *dnbd3_connect(dnbd3_device_t *dev, dnbd3_host_t *host);
-static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t *const dev)
-{
- int i;
-
- for (i = 0; i < NUMBER_SERVERS; ++i) {
- if (dev->alt_servers[i].host.type == 0)
- return &dev->alt_servers[i];
- }
- for (i = 0; i < NUMBER_SERVERS; ++i) {
- if (dev->alt_servers[i].failures > 10)
- return &dev->alt_servers[i];
- }
- return NULL;
-}
-
static void dnbd3_net_heartbeat(struct timer_list *arg)
{
dnbd3_device_t *dev = (dnbd3_device_t *)container_of(arg, dnbd3_device_t, hb_timer);
@@ -168,7 +153,7 @@ static int dnbd3_net_discover(void *data)
dnbd3_request_t dnbd3_request;
dnbd3_reply_t dnbd3_reply;
- dnbd3_server_t *alt_server;
+ dnbd3_server_t host_compare, best_server;
struct msghdr msg;
struct kvec iov[2];
@@ -176,11 +161,12 @@ static int dnbd3_net_discover(void *data)
serialized_buffer_t *payload;
uint64_t filesize;
uint16_t rid;
+ uint16_t remote_version;
ktime_t start = 0, end = 0;
unsigned long rtt, best_rtt = 0;
unsigned long irqflags;
- int i, j, isize, best_server, current_server;
+ int i, j, isize;
int turn = 0;
int ready = 0, do_change = 0;
char check_order[NUMBER_SERVERS];
@@ -218,53 +204,7 @@ static int dnbd3_net_discover(void *data)
if (dev->reported_size < 4096)
continue;
- // Check if the list of alt servers needs to be updated and do so if necessary
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- if (dev->new_servers_num) {
- for (i = 0; i < dev->new_servers_num; ++i) {
- if (dev->new_servers[i].host.type != HOST_IP4 &&
- dev->new_servers[i].host.type != HOST_IP6) // Invalid entry?
- continue;
- alt_server = get_existing_server(&dev->new_servers[i], dev);
- if (alt_server != NULL) {
- // Server already known
- if (dev->new_servers[i].failures == 1) {
- // REMOVE request
- if (alt_server->host.type == HOST_IP4)
- dnbd3_dev_dbg_host_cur(dev, "removing alt server %pI4\n",
- alt_server->host.addr);
- else
- dnbd3_dev_dbg_host_cur(dev, "removing alt server [%pI6]\n",
- alt_server->host.addr);
- alt_server->host.type = 0;
- continue;
- }
- // ADD, so just reset fail counter
- alt_server->failures = 0;
- continue;
- }
- if (dev->new_servers[i].failures == 1) // REMOVE, but server is not in list anyways
- continue;
- alt_server = get_free_alt_server(dev);
- if (alt_server == NULL) // All NUMBER_SERVERS slots are taken, ignore entry
- continue;
- // Add new server entry
- alt_server->host = dev->new_servers[i].host;
- if (alt_server->host.type == HOST_IP4)
- dnbd3_dev_dbg_host_cur(dev, "adding alt server %pI4\n", alt_server->host.addr);
- else
- dnbd3_dev_dbg_host_cur(dev, "adding alt server [%pI6]\n",
- alt_server->host.addr);
- alt_server->rtts[0] = alt_server->rtts[1] = alt_server->rtts[2] = alt_server->rtts[3] =
- RTT_UNREACHABLE;
- alt_server->protocol_version = 0;
- alt_server->failures = 0;
- }
- dev->new_servers_num = 0;
- }
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
-
- current_server = best_server = -1;
+ best_server.host.type = 0;
best_rtt = 0xFFFFFFFul;
if (dev->heartbeat_count < STARTUP_MODE_DURATION || dev->panic)
@@ -285,17 +225,19 @@ static int dnbd3_net_discover(void *data)
for (j = 0; j < NUMBER_SERVERS; ++j) {
i = check_order[j];
- if (dev->alt_servers[i].host.type == 0) // Empty slot
- continue;
- if (!dev->panic && dev->alt_servers[i].failures > 50 &&
- (ktime_to_us(start) & 7) !=
- 0) // If not in panic mode, skip server if it failed too many times
- continue;
- if (isize-- <= 0 && !is_same_server(&dev->cur_server, &dev->alt_servers[i]))
- continue;
+ mutex_lock(&dev->alt_servers_lock);
+ host_compare = dev->alt_servers[i];
+ mutex_unlock(&dev->alt_servers_lock);
+ if (host_compare.host.type == 0)
+ continue; // Empty slot
+ if (!dev->panic && host_compare.failures > 50
+ && (ktime_to_us(start) & 7) != 0)
+ continue; // If not in panic mode, skip server if it failed too many times
+ if (isize-- <= 0 && !is_same_server(&dev->cur_server, &host_compare))
+ continue; // Only test isize servers plus current server
// Initialize socket and connect
- sock = dnbd3_connect(dev, &dev->alt_servers[i].host);
+ sock = dnbd3_connect(dev, &host_compare.host);
if (sock == NULL) {
dnbd3_dev_dbg_host_alt(dev, "%s: Couldn't connect\n", __func__);
goto error;
@@ -345,11 +287,11 @@ static int dnbd3_net_discover(void *data)
}
serializer_reset_read(payload, dnbd3_reply.size);
- dev->alt_servers[i].protocol_version = serializer_get_uint16(payload);
- if (dev->alt_servers[i].protocol_version < MIN_SUPPORTED_SERVER) {
+ remote_version = serializer_get_uint16(payload);
+ if (remote_version < MIN_SUPPORTED_SERVER) {
dnbd3_dev_err_host_alt(
dev, "server version too old (client: %d, server: %d, min supported: %d)\n",
- (int)PROTOCOL_VERSION, (int)dev->alt_servers[i].protocol_version,
+ (int)PROTOCOL_VERSION, (int)remote_version,
(int)MIN_SUPPORTED_SERVER);
goto error;
}
@@ -402,7 +344,9 @@ static int dnbd3_net_discover(void *data)
put_task_struct(dev->thread_discover);
dev->thread_discover = NULL;
dnbd3_net_disconnect(dev);
- memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server));
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ dev->cur_server = host_compare;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
dnbd3_net_connect(dev);
atomic_set(&dev->connection_lock, 0);
return 0;
@@ -455,16 +399,22 @@ static int dnbd3_net_discover(void *data)
end = ktime_get_real(); // end rtt measurement
- dev->alt_servers[i].rtts[turn] = (unsigned long)ktime_us_delta(end, start);
+ mutex_lock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->alt_servers[i], &host_compare)) {
+ dev->alt_servers[i].protocol_version = remote_version;
+ dev->alt_servers[i].rtts[turn] = (unsigned long)ktime_us_delta(end, start);
- rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1] + dev->alt_servers[i].rtts[2] +
- dev->alt_servers[i].rtts[3]) /
- 4;
+ rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1]
+ + dev->alt_servers[i].rtts[2] + dev->alt_servers[i].rtts[3])
+ / 4;
+ }
+ dev->alt_servers[i].failures = 0;
+ mutex_unlock(&dev->alt_servers_lock);
if (best_rtt > rtt) {
// This one is better, keep socket open in case we switch
best_rtt = rtt;
- best_server = i;
+ best_server = host_compare;
if (best_sock != NULL)
sock_release(best_sock);
best_sock = sock;
@@ -476,37 +426,36 @@ static int dnbd3_net_discover(void *data)
}
// update cur servers rtt
- if (is_same_server(&dev->cur_server, &dev->alt_servers[i])) {
+ if (is_same_server(&dev->cur_server, &host_compare))
dev->cur_rtt = rtt;
- current_server = i;
- }
-
- dev->alt_servers[i].failures = 0;
continue;
error:
- ++dev->alt_servers[i].failures;
if (sock != NULL) {
sock_release(sock);
sock = NULL;
}
- dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE;
- if (is_same_server(&dev->cur_server, &dev->alt_servers[i])) {
- dev->cur_rtt = RTT_UNREACHABLE;
- current_server = i;
+ mutex_lock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->alt_servers[i], &host_compare)) {
+ ++dev->alt_servers[i].failures;
+ dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE;
}
+ mutex_unlock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->cur_server, &host_compare))
+ dev->cur_rtt = RTT_UNREACHABLE;
+
continue;
}
if (dev->panic) {
- // After 21 retries, bail out by reporting errors to block layer
+ // If probe timeout is set, report error to block layer
if (PROBE_COUNT_TIMEOUT > 0 && dev->panic_count < 255 &&
++dev->panic_count == PROBE_COUNT_TIMEOUT + 1)
dnbd3_blk_fail_all_requests(dev);
}
- if (best_server == -1 || kthread_should_stop() || dev->thread_discover == NULL) {
+ if (best_server.host.type == 0 || kthread_should_stop() || dev->thread_discover == NULL) {
// No alt server could be reached at all or thread should stop
if (best_sock != NULL) {
// Should never happen actually
@@ -516,8 +465,8 @@ error:
continue;
}
- do_change = ready && best_server != current_server && (ktime_to_us(start) & 3) != 0 &&
- RTT_THRESHOLD_FACTOR(dev->cur_rtt) > best_rtt + 1500;
+ do_change = ready && !is_same_server(&best_server, &dev->cur_server) && (ktime_to_us(start) & 3) != 0
+ && RTT_THRESHOLD_FACTOR(dev->cur_rtt) > best_rtt + 1500;
if (ready && !do_change) {
spin_lock_irqsave(&dev->blk_lock, irqflags);
@@ -536,14 +485,16 @@ error:
// take server with lowest rtt
// if a (dis)connect is already in progress, we do nothing, this is not panic mode
if (do_change && atomic_cmpxchg(&dev->connection_lock, 0, 1) == 0) {
- dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", best_server,
+ dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", -1, // XXX
(unsigned long long)best_rtt, (unsigned long long)dev->cur_rtt);
kfree(buf);
dev->better_sock = best_sock; // Take shortcut by continuing to use open connection
put_task_struct(dev->thread_discover);
dev->thread_discover = NULL;
dnbd3_net_disconnect(dev);
- memcpy(&dev->cur_server, &dev->alt_servers[best_server], sizeof(dev->cur_server));
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ dev->cur_server = best_server;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
dev->cur_rtt = best_rtt;
dnbd3_net_connect(dev);
atomic_set(&dev->connection_lock, 0);
@@ -810,32 +761,30 @@ static int dnbd3_net_receive(void *data)
continue;
case CMD_GET_SERVERS:
- if (!dev->use_server_provided_alts) {
- remaining = dnbd3_reply.size;
- goto consume_payload;
- }
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- dev->new_servers_num = 0;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- count = MIN(NUMBER_SERVERS, dnbd3_reply.size / sizeof(dnbd3_server_entry_t));
-
- if (count != 0) {
- iov.iov_base = dev->new_servers;
- iov.iov_len = count * sizeof(dnbd3_server_entry_t);
- if (kernel_recvmsg(dev->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)),
- msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) {
- if (!atomic_read(&dev->connection_lock))
- dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n");
- ret = -EINVAL;
- goto cleanup;
+ remaining = dnbd3_reply.size;
+ if (dev->use_server_provided_alts) {
+ dnbd3_server_entry_t new_server;
+
+ while (remaining >= sizeof(dnbd3_server_entry_t)) {
+ iov.iov_base = &new_server;
+ iov.iov_len = sizeof(dnbd3_server_entry_t);
+ if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len,
+ msg.msg_flags) != sizeof(dnbd3_server_entry_t)) {
+ if (!atomic_read(&dev->connection_lock))
+ dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n");
+ ret = -EINVAL;
+ goto cleanup;
+ }
+ // TODO: Log
+ if (new_server.failures == 0) { // ADD
+ dnbd3_add_server(dev, &new_server.host);
+ } else { // REM
+ dnbd3_rem_server(dev, &new_server.host);
+ }
+ remaining -= sizeof(dnbd3_server_entry_t);
}
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- dev->new_servers_num = count;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
}
- // If there were more servers than accepted, remove the remaining data from the socket buffer
- remaining = dnbd3_reply.size - (count * sizeof(dnbd3_server_entry_t));
-consume_payload:
+ // Drain any payload still on the wire
while (remaining > 0) {
count = MIN(sizeof(dnbd3_reply),
remaining); // Abuse the reply struct as the receive buffer
@@ -1027,7 +976,7 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
dnbd3_reply_t dnbd3_reply;
struct msghdr msg;
struct kvec iov[2];
- uint16_t rid;
+ uint16_t rid, proto_version;
char *name;
int mlen;
@@ -1082,8 +1031,11 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
}
// handle/check reply payload
serializer_reset_read(&dev->payload_buffer, dnbd3_reply.size);
- dev->cur_server.protocol_version = serializer_get_uint16(&dev->payload_buffer);
- if (dev->cur_server.protocol_version < MIN_SUPPORTED_SERVER) {
+ proto_version = serializer_get_uint16(&dev->payload_buffer);
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ dev->cur_server.protocol_version = proto_version;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ if (proto_version < MIN_SUPPORTED_SERVER) {
dnbd3_dev_err_host_cur(dev, "server version is lower than min supported version\n");
goto error;
}
@@ -1205,8 +1157,10 @@ error:
sock_release(dev->sock);
dev->sock = NULL;
}
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
dev->cur_server.host.type = 0;
dev->cur_server.host.port = 0;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
kfree(req1);
return -1;
@@ -1215,6 +1169,7 @@ error:
int dnbd3_net_disconnect(dnbd3_device_t *dev)
{
struct task_struct *thread = NULL;
+ unsigned long irqflags;
int ret;
dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n");
@@ -1283,8 +1238,10 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
sock_release(dev->sock);
dev->sock = NULL;
}
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
dev->cur_server.host.type = 0;
dev->cur_server.host.port = 0;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
return 0;
}